├── README.md ├── TemporalAlignment ├── copy.sh ├── dataset.py ├── models │ ├── mocogan_discriminator.py │ ├── mocoganhd_content_disc.py │ ├── mocoganhd_losses.py │ ├── mocoganhd_models.py │ ├── mocoganhd_video_disc.py │ └── video_discriminator.py ├── perturbations.py └── ranges.py ├── bad_mp4s.json ├── bash_scripts ├── train_videovqvae.sh ├── train_videovqvae_perturbations.sh └── train_videovqvae_resume_ckpt.sh ├── command.txt ├── config.py ├── copy.sh ├── datasets ├── face_translation_videos3_utils.py └── face_translation_videos3_utils_bb.py ├── disc_trainers ├── train_vqvae_mocogan_disc.py ├── train_vqvae_mocogan_disc_perceptual.py ├── train_vqvae_mocoganhd_disc.py ├── train_vqvae_mocoganhd_disc_single.py └── train_vqvae_perceptual_mocoganhd_disc.py ├── distributed ├── __init__.py ├── distributed.py └── launch.py ├── environment.yml ├── loss.py ├── models ├── actnorm.py ├── discriminator.py ├── lpips.py └── vqvae_conv3d_latent.py ├── preprocessing ├── landmark_generation.py └── preprocess_dataset.py ├── results ├── inference_pipeline.gif ├── training_pipeline.gif ├── v2v_comparisons1.gif ├── v2v_comparisons31.gif ├── v2v_face_swapping.gif ├── v2v_faceswapping_looped2.gif ├── v2v_more_result.gif ├── v2v_results1.gif ├── v2v_results2.gif ├── v2v_results3.gif ├── v2v_results4.gif ├── v2v_same_identity1.gif ├── v2v_same_identity2.gif └── v2v_same_identity3.gif ├── sample └── .gitignore ├── scheduler.py ├── train_faceoff.py ├── train_faceoff_perceptual.py ├── utils.py ├── valid_folders_ft.json └── valid_videos.txt /README.md: -------------------------------------------------------------------------------- 1 | # FaceOff: A Video-to-Video Face Swapping System 2 | 3 | [Aditya Agarwal](http://skymanaditya1.github.io/)\*1, 4 | [Bipasha Sen](https://bipashasen.github.io/)\*1, 5 | [Rudrabha Mukhopadhyay](https://rudrabha.github.io/)1, 6 | [Vinay Namboodiri](https://vinaypn.github.io/)2, 7 | [C V Jawahar](https://faculty.iiit.ac.in/~jawahar/)1
8 | 1International Institute of Information Technology, Hyderabad, 2University of Bath 9 | 10 | \*denotes equal contribution 11 | 12 | This is the official implementation of the paper "FaceOff: A Video-to-Video Face Swapping System" **published** at WACV 2023. 13 | 14 | 15 | 16 | 17 | 18 | For more results, information, and details visit our [**project page**](http://cvit.iiit.ac.in/research/projects/cvit-projects/faceoff) and read our [**paper**](https://openaccess.thecvf.com/content/WACV2023/papers/Agarwal_FaceOff_A_Video-to-Video_Face_Swapping_System_WACV_2023_paper.pdf). Following are some outputs from our network on the V2V Face Swapping Task. 19 | 20 | 21 | ## Results on same identity 22 | 23 | 24 | 30 | 31 | ## Getting started 32 | 33 | 1. Set up a conda environment with all dependencies using the following commands: 34 | 35 | ``` 36 | conda env create -f environment.yml 37 | conda activate faceoff 38 | ``` 39 | 40 | ## Training FaceOff 41 | 42 | 43 | 44 | The following command trains the V2V Face Swapping network. At set intervals, it will generate the data on the validation dataset. 45 | 46 | ``` 47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_faceoff_perceptual.py 48 | ``` 49 | **Parameters**
50 | Below is the full list of parameters: 51 |  52 | ```--dist_url``` - port used to initialize distributed training<br>
53 | ```--batch_size``` - batch size, default is ```32```
54 | ```--size``` - image size, default is ```256```
55 | ```--epoch``` - number of epochs to train for
56 | ```--lr``` - learning rate, default is ```3e-4```<br> 57 | ```--sched``` - learning rate scheduler to use<br>
58 | ```--checkpoint_suffix``` - folder where checkpoints are saved; by default, a random folder name is generated<br>
59 | ```--validate_at``` - number of steps after which validation is performed, default is ```1024```
60 | ```--ckpt``` - path to a pretrained checkpoint to resume from, default is ```None```<br>
61 | ```--test``` - whether to run the model in test mode<br>
62 | ```--gray``` - whether to test on grayscale inputs<br>
63 | ```--colorjit``` - type of color jitter to apply; the possible options are ```const```, ```random```, or ```empty```<br>
64 | ```--crossid``` - whether cross-identity swapping is performed during validation, default is ```True```<br>
65 | ```--custom_validation``` - used to test FaceOff on a custom pair of source and target videos, default is ```False```<br>
66 | ```--sample_folder``` - path where the validation videos are stored
67 | ```--checkpoint_dir``` - directory path where checkpoints are saved<br>
68 | ```--validation_folder``` - directory path where validation samples are saved<br>
69 | 70 | All the values can be left at their default values to train FaceOff in the vanilla setting. An example is given below: 71 | 72 | ``` 73 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_faceoff_perceptual.py 74 | ``` 75 | 76 | ## Checkpoints 77 | Pretrained checkpoints will be released soon! 78 | 79 | ## We would love your contributions to improve FaceOff 80 | 81 | FaceOff introduces the novel task of Video-to-Video Face Swapping that tackles a pressing challenge in the movie-making industry: swapping the actor's face and expressions on the face of their body double. Existing face-swapping methods swap only the identity of the source face without swapping the source (actor) expressions, which is undesirable as the starring actor's expressions are paramount. In video-to-video face swapping, we swap the source's facial expressions along with the identity on the target's background and pose. Our method retains the face and expressions of the source actor and the pose and background information of the target actor. Currently, our model has a few limitations. ***We would like to strongly encourage contributions and spur further research into some of the limitations listed below.*** 82 | 83 | 1. **Video Quality**: FaceOff is based on combining the temporal motion of the source and the target face videos in the reduced space of a vector quantized variational autoencoder. Consequently, it suffers from a few quality issues, and the output resolution is limited to 256x256. Generating very high-quality samples is typically required in the movie-making industry. 84 | 85 | 2. **Temporal Jitter**: Although the 3D conv modules remove most of the temporal jitter in the blended output, a few noticeable temporal jitters remain and require attention. The temporal jitters occur as we try to photo-realistically blend two different motions (source and target) in a temporally coherent manner. 86 | 87 | 3. **Extreme Poses**: FaceOff was designed to face-swap the actor's face and expressions with the double's pose and background information. Consequently, it is expected that the pose difference between the source and the target actors won't be extreme. While FaceOff can handle **roll**-related rotations in the 2D space, it would be worth investigating fixing rotations due to **yaw** and **pitch** in the 3D space and rendering the output back to the 2D space. 88 | 89 | 90 | ## Thanks 91 | 92 | ### VQVAE2 93 | 94 | We would like to thank the authors of [VQVAE2](https://arxiv.org/pdf/1906.00446.pdf) ([https://github.com/rosinality/vq-vae-2-pytorch](https://github.com/rosinality/vq-vae-2-pytorch)) for releasing the code. We build on top of their codebase to perform V2V Face Swapping. 95 | 96 | ## Citation 97 | If you find our work useful in your research, please cite: 98 | ``` 99 | @InProceedings{Agarwal_2023_WACV, 100 | author = {Agarwal, Aditya and Sen, Bipasha and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C. V.}, 101 | title = {FaceOff: A Video-to-Video Face Swapping System}, 102 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 103 | month = {January}, 104 | year = {2023}, 105 | pages = {3495-3504} 106 | } 107 | ``` 108 | 109 | ## Contact 110 | If you have any questions, please feel free to email the authors. 111 | 112 | Aditya Agarwal: aditya.ag@research.iiit.ac.in<br>
113 | Bipasha Sen: bipasha.sen@research.iiit.ac.in
114 | Rudrabha Mukhopadhyay: radrabha.m@research.iiit.ac.in
115 | Vinay Namboodiri: vpn22@bath.ac.uk
116 | C V Jawahar: jawahar@iiit.ac.in
-------------------------------------------------------------------------------- /TemporalAlignment/copy.sh: -------------------------------------------------------------------------------- 1 | cp -r /home2/bipasha31/python_scripts/CurrentWork/SLP/VQVAE2-Refact/TemporalAlignment/* . -------------------------------------------------------------------------------- /TemporalAlignment/models/mocogan_discriminator.py: -------------------------------------------------------------------------------- 1 | # modified versions of mocogan's image and video discriminators 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.utils.data 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | 11 | if torch.cuda.is_available(): 12 | T = torch.cuda 13 | else: 14 | T = torch 15 | 16 | class Noise(nn.Module): 17 | def __init__(self, use_noise, sigma=0.2): 18 | super(Noise, self).__init__() 19 | self.use_noise = use_noise 20 | self.sigma = sigma 21 | 22 | def forward(self, x): 23 | if self.use_noise: 24 | return x + self.sigma * Variable(T.FloatTensor(x.size()).normal_(), requires_grad=False) 25 | return x 26 | 27 | # modified version of mocogan's image discriminator 28 | # input dimension -> batch_size x channels x 256 x 256, output dimension -> batch_size 29 | class ImageDiscriminator(nn.Module): 30 | def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None): 31 | super(ImageDiscriminator, self).__init__() 32 | 33 | self.use_noise = use_noise 34 | 35 | self.main = nn.Sequential( 36 | Noise(use_noise, sigma=noise_sigma), 37 | nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False), 38 | nn.LeakyReLU(0.2, inplace=True), 39 | 40 | Noise(use_noise, sigma=noise_sigma), 41 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 42 | nn.BatchNorm2d(ndf * 2), 43 | nn.LeakyReLU(0.2, inplace=True), 44 | 45 | Noise(use_noise, sigma=noise_sigma), 46 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 47 | nn.BatchNorm2d(ndf * 4), 48 | nn.LeakyReLU(0.2, inplace=True), 49 | 50 | Noise(use_noise, sigma=noise_sigma), 51 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 52 | nn.BatchNorm2d(ndf * 8), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | 55 | Noise(use_noise, sigma=noise_sigma), 56 | nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False), 57 | nn.BatchNorm2d(ndf * 16), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | Noise(use_noise, sigma=noise_sigma), 61 | nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False), 62 | nn.BatchNorm2d(ndf * 32), 63 | nn.LeakyReLU(0.2, inplace=True), 64 | 65 | nn.Conv2d(ndf * 32, 1, 4, 1, 0, bias=False), 66 | ) 67 | 68 | def forward(self, input): 69 | h = self.main(input).squeeze() 70 | return h, None 71 | 72 | # modified version of mocogan's patch image discriminator 73 | # input dimension -> batch_size x channels x 256 x 256, output dimension -> batch_size x 4 x 4 74 | class PatchImageDiscriminator(nn.Module): 75 | def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None): 76 | super(PatchImageDiscriminator, self).__init__() 77 | 78 | self.use_noise = use_noise 79 | 80 | self.main = nn.Sequential( 81 | Noise(use_noise, sigma=noise_sigma), 82 | nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False), 83 | nn.LeakyReLU(0.2, inplace=True), 84 | 85 | Noise(use_noise, sigma=noise_sigma), 86 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 87 | nn.BatchNorm2d(ndf * 2), 88 | nn.LeakyReLU(0.2, inplace=True), 89 | 90 | Noise(use_noise, sigma=noise_sigma), 91 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 92 | nn.BatchNorm2d(ndf * 4), 93 | 
nn.LeakyReLU(0.2, inplace=True), 94 | 95 | Noise(use_noise, sigma=noise_sigma), 96 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 97 | nn.BatchNorm2d(ndf * 8), 98 | nn.LeakyReLU(0.2, inplace=True), 99 | 100 | Noise(use_noise, sigma=noise_sigma), 101 | nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False), 102 | nn.BatchNorm2d(ndf * 16), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | 105 | Noise(use_noise, sigma=noise_sigma), 106 | nn.Conv2d(ndf * 16, 1, 4, 2, 1, bias=False), 107 | ) 108 | 109 | def forward(self, input): 110 | h = self.main(input).squeeze() 111 | return h, None 112 | 113 | # modified version of mocogan's video discriminator 114 | # input dimension -> batch_size x channels x 16 x 256 x 256 115 | # output idmension -> batch_size x 16 116 | class VideoDiscriminator(nn.Module): 117 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64): 118 | super(VideoDiscriminator, self).__init__() 119 | 120 | self.n_channels = n_channels 121 | self.n_output_neurons = n_output_neurons 122 | self.use_noise = use_noise 123 | self.bn_use_gamma = bn_use_gamma 124 | 125 | self.main = nn.Sequential( 126 | Noise(use_noise, sigma=noise_sigma), 127 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 128 | nn.LeakyReLU(0.2, inplace=True), 129 | 130 | Noise(use_noise, sigma=noise_sigma), 131 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 132 | nn.BatchNorm3d(ndf * 2), 133 | nn.LeakyReLU(0.2, inplace=True), 134 | 135 | Noise(use_noise, sigma=noise_sigma), 136 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 137 | nn.BatchNorm3d(ndf * 4), 138 | nn.LeakyReLU(0.2, inplace=True), 139 | 140 | Noise(use_noise, sigma=noise_sigma), 141 | nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 142 | nn.BatchNorm3d(ndf * 8), 143 | nn.LeakyReLU(0.2, inplace=True), 144 | 145 | Noise(use_noise, sigma=noise_sigma), 146 | nn.Conv3d(ndf * 8, ndf * 16, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 147 | nn.BatchNorm3d(ndf * 16), 148 | nn.LeakyReLU(0.2, inplace=True), 149 | 150 | Noise(use_noise, sigma=noise_sigma), 151 | nn.Conv3d(ndf * 16, ndf * 32, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 152 | nn.BatchNorm3d(ndf * 32), 153 | nn.LeakyReLU(0.2, inplace=True), 154 | 155 | nn.Conv3d(ndf * 32, n_output_neurons, (1, 4, 4), 1, 0, bias=False), # cannot apply kernel size of 4 on input image of size 1 156 | ) 157 | 158 | def forward(self, input): 159 | h = self.main(input).squeeze() 160 | 161 | return h, None 162 | 163 | # modified version of mocogan's patch video disctiminator 164 | # input dimension -> batch_size x channels x 16 x 256 x 256 165 | # output dimension -> batch_size x 4 x 4 x 4 166 | class PatchVideoDiscriminator(nn.Module): 167 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64): 168 | super(PatchVideoDiscriminator, self).__init__() 169 | 170 | self.n_channels = n_channels 171 | self.n_output_neurons = n_output_neurons 172 | self.use_noise = use_noise 173 | self.bn_use_gamma = bn_use_gamma 174 | 175 | self.main = nn.Sequential( 176 | Noise(use_noise, sigma=noise_sigma), 177 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 178 | nn.LeakyReLU(0.2, inplace=True), 179 | 180 | Noise(use_noise, sigma=noise_sigma), 181 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 182 | nn.BatchNorm3d(ndf 
* 2), 183 | nn.LeakyReLU(0.2, inplace=True), 184 | 185 | Noise(use_noise, sigma=noise_sigma), 186 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 187 | nn.BatchNorm3d(ndf * 4), 188 | nn.LeakyReLU(0.2, inplace=True), 189 | 190 | Noise(use_noise, sigma=noise_sigma), 191 | nn.Conv3d(ndf * 4, ndf * 8, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 192 | nn.BatchNorm3d(ndf * 8), 193 | nn.LeakyReLU(0.2, inplace=True), 194 | 195 | Noise(use_noise, sigma=noise_sigma), 196 | nn.Conv3d(ndf * 8, ndf * 16, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 197 | nn.BatchNorm3d(ndf * 16), 198 | nn.LeakyReLU(0.2, inplace=True), 199 | 200 | nn.Conv3d(ndf * 16, 1, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 201 | ) 202 | 203 | def forward(self, input): 204 | h = self.main(input).squeeze() 205 | 206 | return h, None -------------------------------------------------------------------------------- /TemporalAlignment/models/mocoganhd_content_disc.py: -------------------------------------------------------------------------------- 1 | # This is the mocoganhd content discriminator code 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import functools 6 | 7 | # this is the image discriminator - for classifying the content as real/fake 8 | class ModelD_img(nn.Module): 9 | def __init__(self, nc, norm_D_3d, num_D, lr): 10 | super(ModelD_img, self).__init__() 11 | nc = nc * 2 12 | 13 | self.netD = MultiscaleDiscriminator(input_nc=nc, 14 | norm_layer=get_norm_layer( 15 | norm_D_3d), 16 | num_D=num_D) 17 | self.netD.apply(weights_init) 18 | 19 | self.optim = torch.optim.Adam(self.netD.parameters(), 20 | lr=lr, 21 | betas=(0.5, 0.999)) 22 | 23 | def forward(self, x): 24 | return self.netD.forward(x) 25 | 26 | 27 | def weights_init(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv') != -1 and hasattr(m, 'weight'): 30 | m.weight.data.normal_(0.0, 0.02) 31 | elif classname.find('BatchNorm2d') != -1: 32 | m.weight.data.normal_(1.0, 0.02) 33 | m.bias.data.fill_(0) 34 | 35 | 36 | def get_norm_layer(norm_type='instance'): 37 | if norm_type == 'batch': 38 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 39 | elif norm_type == 'instance': 40 | norm_layer = functools.partial(nn.InstanceNorm2d, 41 | affine=False, 42 | track_running_stats=True) 43 | else: 44 | raise NotImplementedError('normalization layer [%s] is not found' % 45 | norm_type) 46 | return norm_layer 47 | 48 | 49 | class MultiscaleDiscriminator(nn.Module): 50 | def __init__(self, 51 | input_nc, 52 | ndf=64, 53 | n_layers=3, 54 | n_frames=16, 55 | norm_layer=nn.InstanceNorm2d, 56 | num_D=2, 57 | getIntermFeat=True): 58 | super(MultiscaleDiscriminator, self).__init__() 59 | self.num_D = num_D 60 | self.n_layers = n_layers 61 | self.getIntermFeat = getIntermFeat 62 | ndf_max = 64 63 | 64 | for i in range(num_D): 65 | netD = NLayerDiscriminator( 66 | input_nc, min(ndf_max, ndf * (2**(num_D - 1 - i))), n_layers, 67 | norm_layer, getIntermFeat) 68 | if getIntermFeat: 69 | for j in range(n_layers + 2): 70 | setattr(self, 'scale' + str(i) + '_layer' + str(j), 71 | getattr(netD, 'model' + str(j))) 72 | else: 73 | setattr(self, 'layer' + str(i), netD.model) 74 | self.downsample = nn.AvgPool2d(3, 75 | stride=2, 76 | padding=[1, 1], 77 | count_include_pad=False) 78 | 79 | def singleD_forward(self, model, input): 80 | if self.getIntermFeat: 81 | result = [input] 82 | for i in range(len(model)): 83 | result.append(model[i](result[-1])) 84 | return result[1:] 
85 | else: 86 | return [model(input)] 87 | 88 | def forward(self, input): 89 | num_D = self.num_D 90 | result = [] 91 | input_downsampled = input 92 | for i in range(num_D): 93 | if self.getIntermFeat: 94 | model = [ 95 | getattr(self, 96 | 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) 97 | for j in range(self.n_layers + 2) 98 | ] 99 | else: 100 | model = getattr(self, 'layer' + str(num_D - 1 - i)) 101 | # the model needs to be transferred to the GPU first 102 | result.append(self.singleD_forward(model, input_downsampled)) # the model is not on the GPU, that's why the error is being thrown 103 | if i != (num_D - 1): 104 | input_downsampled = self.downsample(input_downsampled) 105 | return result 106 | 107 | 108 | class NLayerDiscriminator(nn.Module): 109 | def __init__(self, 110 | input_nc, 111 | ndf=64, 112 | n_layers=3, 113 | norm_layer=nn.InstanceNorm2d, 114 | getIntermFeat=True): 115 | super(NLayerDiscriminator, self).__init__() 116 | self.getIntermFeat = getIntermFeat 117 | self.n_layers = n_layers 118 | 119 | kw = 4 120 | padw = int(np.ceil((kw - 1.0) / 2)) 121 | sequence = [[ 122 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 123 | nn.LeakyReLU(0.2, True) 124 | ]] 125 | 126 | nf = ndf 127 | for n in range(1, n_layers): 128 | nf_prev = nf 129 | nf = min(nf * 2, 512) 130 | sequence += [[ 131 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 132 | norm_layer(nf), 133 | nn.LeakyReLU(0.2, True) 134 | ]] 135 | 136 | nf_prev = nf 137 | nf = min(nf * 2, 512) 138 | sequence += [[ 139 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 140 | norm_layer(nf), 141 | nn.LeakyReLU(0.2, True) 142 | ]] 143 | 144 | sequence += [[ 145 | nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw) 146 | ]] 147 | 148 | if getIntermFeat: 149 | for n in range(len(sequence)): 150 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) 151 | else: 152 | sequence_stream = [] 153 | for n in range(len(sequence)): 154 | sequence_stream += sequence[n] 155 | self.model = nn.Sequential(*sequence_stream) 156 | 157 | def forward(self, input): 158 | if self.getIntermFeat: 159 | res = [input] 160 | for n in range(self.n_layers + 2): 161 | model = getattr(self, 'model' + str(n)) 162 | res.append(model(res[-1])) 163 | return res[1:] 164 | else: 165 | return self.model(input) 166 | -------------------------------------------------------------------------------- /TemporalAlignment/models/mocoganhd_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | import sys 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.nn as nn 14 | 15 | 16 | def loss_hinge_dis(dis_fake, dis_real): 17 | loss_real = torch.mean(F.relu(1. - dis_real)) 18 | loss_fake = torch.mean(F.relu(1. 
+ dis_fake)) 19 | return loss_real, loss_fake 20 | 21 | 22 | def loss_hinge_gen(dis_fake): 23 | loss = -torch.mean(dis_fake) 24 | return loss 25 | 26 | 27 | def compute_gradient_penalty_T(real_B, fake_B, modelD, opt): 28 | alpha = torch.rand(list(real_B.size())[0], 1, 1, 1, 1) 29 | alpha = alpha.expand(real_B.size()).cuda(real_B.get_device()) 30 | 31 | interpolates = alpha * real_B.data + (1 - alpha) * fake_B.data 32 | interpolates = torch.tensor(interpolates, requires_grad=True) 33 | 34 | pred_interpolates = modelD(interpolates) 35 | 36 | gradient_penalty = 0 37 | if isinstance(pred_interpolates, list): 38 | for cur_pred in pred_interpolates: 39 | gradients = torch.autograd.grad(outputs=cur_pred[-1], 40 | inputs=interpolates, 41 | grad_outputs=torch.ones( 42 | cur_pred[-1].size()).cuda( 43 | real_B.get_device()), 44 | create_graph=True, 45 | retain_graph=True, 46 | only_inputs=True)[0] 47 | 48 | gradient_penalty += ((gradients.norm(2, dim=1) - 1)**2).mean() 49 | else: 50 | sys.exit('output is not list!') 51 | 52 | gradient_penalty = (gradient_penalty / opt.num_D) * 10 53 | return gradient_penalty 54 | 55 | 56 | class GANLoss(nn.Module): 57 | def __init__(self, 58 | use_lsgan=True, 59 | target_real_label=1.0, 60 | target_fake_label=0.0, 61 | tensor=torch.FloatTensor): 62 | super(GANLoss, self).__init__() 63 | self.real_label = target_real_label 64 | self.fake_label = target_fake_label 65 | self.real_label_var = None 66 | self.fake_label_var = None 67 | self.Tensor = tensor 68 | if use_lsgan: 69 | self.loss = nn.MSELoss() 70 | else: 71 | self.loss = nn.BCELoss() 72 | 73 | def get_target_tensor(self, input, target_is_real): 74 | target_tensor = None 75 | if target_is_real: 76 | create_label = ((self.real_label_var is None) 77 | or (self.real_label_var.numel() != input.numel())) 78 | if create_label: 79 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 80 | self.real_label_var = torch.tensor(real_tensor, 81 | requires_grad=False) 82 | target_tensor = self.real_label_var 83 | else: 84 | create_label = ((self.fake_label_var is None) 85 | or (self.fake_label_var.numel() != input.numel())) 86 | if create_label: 87 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 88 | self.fake_label_var = torch.tensor(fake_tensor, 89 | requires_grad=False) 90 | target_tensor = self.fake_label_var 91 | 92 | if input.is_cuda: 93 | target_tensor = target_tensor.cuda() 94 | return target_tensor 95 | 96 | def __call__(self, input, target_is_real): 97 | if isinstance(input[0], list): 98 | loss = 0 99 | for input_i in input: 100 | pred = input_i[-1] 101 | target_tensor = self.get_target_tensor(pred, target_is_real) 102 | loss += self.loss(pred, target_tensor) 103 | return loss 104 | else: 105 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 106 | return self.loss(input[-1], target_tensor) 107 | 108 | 109 | class Relativistic_Average_LSGAN(GANLoss): 110 | ''' 111 | Relativistic average LSGAN 112 | ''' 113 | 114 | def __call__(self, input_1, input_2, target_is_real): 115 | if isinstance(input_1[0], list): 116 | loss = 0 117 | for input_i, _input_i in zip(input_1, input_2): 118 | pred = input_i[-1] 119 | _pred = _input_i[-1] 120 | target_tensor = self.get_target_tensor(pred, target_is_real) 121 | loss += self.loss(pred - torch.mean(_pred), target_tensor) 122 | return loss 123 | else: 124 | target_tensor = self.get_target_tensor(input_1[-1], target_is_real) 125 | return self.loss(input_1[-1] - torch.mean(input_2[-1]), 126 | target_tensor) 127 | 
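The hinge-loss helpers defined above (`loss_hinge_dis`, `loss_hinge_gen`) are the adversarial objectives used with the MoCoGAN-HD discriminators. Below is a minimal usage sketch, not part of the repository, showing how they are typically combined into discriminator and generator losses; the score shapes are illustrative and the import assumes the repository root is on `PYTHONPATH`.

```python
# Minimal illustrative sketch (assumed shapes; repository root assumed on PYTHONPATH).
import torch
from TemporalAlignment.models.mocoganhd_losses import loss_hinge_dis, loss_hinge_gen

dis_real = torch.randn(8)  # discriminator scores for real clips
dis_fake = torch.randn(8)  # discriminator scores for generated clips

# Discriminator update: push real scores above +1 and fake scores below -1.
loss_real, loss_fake = loss_hinge_dis(dis_fake, dis_real)
d_loss = loss_real + loss_fake

# Generator update: raise the discriminator's score on generated clips.
g_loss = loss_hinge_gen(dis_fake)
```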
-------------------------------------------------------------------------------- /TemporalAlignment/models/mocoganhd_models.py: -------------------------------------------------------------------------------- 1 | # class that has additional functionalities provided in the mocogan_hd repository 2 | """ 3 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 4 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 5 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 6 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 7 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 8 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 9 | """ 10 | import os 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | 17 | def load_checkpoints(path, gpu): 18 | if gpu is None: 19 | ckpt = torch.load(path) 20 | else: 21 | loc = 'cuda:{}'.format(gpu) 22 | ckpt = torch.load(path, map_location=loc) 23 | return ckpt 24 | 25 | 26 | def model_to_gpu(model, isTrain, gpu): 27 | if isTrain: 28 | if gpu is not None: 29 | model.cuda(gpu) 30 | model = DDP(model, 31 | device_ids=[gpu], 32 | find_unused_parameters=True) 33 | else: 34 | model.cuda() 35 | model = DDP(model, find_unused_parameters=True) 36 | else: 37 | model.cuda() 38 | model = nn.DataParallel(model) 39 | 40 | return model -------------------------------------------------------------------------------- /TemporalAlignment/models/mocoganhd_video_disc.py: -------------------------------------------------------------------------------- 1 | # This is the mocoganhd motion (video) discriminator code 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import functools 6 | 7 | # this is the discriminator for classifying the real/fake video motion (video code) 8 | class ModelD_3d(nn.Module): 9 | def __init__(self, nc, norm_D_3d, num_D, lr, cross_domain, n_frames_G): 10 | super(ModelD_3d, self).__init__() 11 | if cross_domain: 12 | nc = nc 13 | n_frames_G = n_frames_G 14 | else: 15 | nc = nc * 2 16 | n_frames_G = n_frames_G - 1 17 | 18 | self.netD = MultiscaleDiscriminator(input_nc=nc, 19 | n_frames=n_frames_G, 20 | norm_layer=get_norm_layer( 21 | norm_D_3d), 22 | num_D=num_D) 23 | self.netD.apply(weights_init) 24 | 25 | self.optim = torch.optim.Adam(self.netD.parameters(), 26 | lr=lr, 27 | betas=(0.5, 0.999)) 28 | 29 | def forward(self, x): 30 | return self.netD.forward(x) 31 | 32 | 33 | def weights_init(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Conv') != -1 and hasattr(m, 'weight'): 36 | m.weight.data.normal_(0.0, 0.02) 37 | elif classname.find('BatchNorm3d') != -1: 38 | m.weight.data.normal_(1.0, 0.02) 39 | m.bias.data.fill_(0) 40 | 41 | 42 | def get_norm_layer(norm_type='instance'): 43 | if norm_type == 'batch': 44 | norm_layer = functools.partial(nn.BatchNorm3d, affine=True) 45 | elif norm_type == 'instance': 46 | norm_layer = functools.partial(nn.InstanceNorm3d, 47 | affine=False, 48 | track_running_stats=True) 49 | else: 50 | raise NotImplementedError('normalization layer [%s] is not found' % 51 | norm_type) 52 | return norm_layer 53 | 54 | 55 | class 
MultiscaleDiscriminator(nn.Module): 56 | def __init__(self, 57 | input_nc, 58 | ndf=64, 59 | n_layers=3, 60 | n_frames=16, 61 | norm_layer=nn.InstanceNorm3d, 62 | num_D=2, 63 | getIntermFeat=True): 64 | super(MultiscaleDiscriminator, self).__init__() 65 | self.num_D = num_D 66 | self.n_layers = n_layers 67 | self.getIntermFeat = getIntermFeat 68 | ndf_max = 64 69 | 70 | for i in range(num_D): 71 | netD = NLayerDiscriminator( 72 | input_nc, min(ndf_max, ndf * (2**(num_D - 1 - i))), n_layers, 73 | norm_layer, getIntermFeat) 74 | if getIntermFeat: 75 | for j in range(n_layers + 2): 76 | setattr(self, 'scale' + str(i) + '_layer' + str(j), 77 | getattr(netD, 'model' + str(j))) 78 | else: 79 | setattr(self, 'layer' + str(i), netD.model) 80 | if n_frames > 16: 81 | self.downsample = nn.AvgPool3d(3, 82 | stride=2, 83 | padding=[1, 1, 1], 84 | count_include_pad=False) 85 | else: 86 | self.downsample = nn.AvgPool3d(3, 87 | stride=[1, 2, 2], 88 | padding=[1, 1, 1], 89 | count_include_pad=False) 90 | 91 | def singleD_forward(self, model, input): 92 | if self.getIntermFeat: 93 | result = [input] 94 | for i in range(len(model)): 95 | result.append(model[i](result[-1])) 96 | return result[1:] 97 | else: 98 | return [model(input)] 99 | 100 | def forward(self, input): 101 | num_D = self.num_D 102 | result = [] 103 | input_downsampled = input 104 | for i in range(num_D): 105 | if self.getIntermFeat: 106 | model = [ 107 | getattr(self, 108 | 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) 109 | for j in range(self.n_layers + 2) 110 | ] 111 | else: 112 | model = getattr(self, 'layer' + str(num_D - 1 - i)) 113 | result.append(self.singleD_forward(model, input_downsampled)) 114 | if i != (num_D - 1): 115 | input_downsampled = self.downsample(input_downsampled) 116 | return result 117 | 118 | 119 | class NLayerDiscriminator(nn.Module): 120 | def __init__(self, 121 | input_nc, 122 | ndf=64, 123 | n_layers=3, 124 | norm_layer=nn.InstanceNorm3d, 125 | getIntermFeat=True): 126 | super(NLayerDiscriminator, self).__init__() 127 | self.getIntermFeat = getIntermFeat 128 | self.n_layers = n_layers 129 | 130 | kw = 4 131 | padw = int(np.ceil((kw - 1.0) / 2)) 132 | sequence = [[ 133 | nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 134 | nn.LeakyReLU(0.2, True) 135 | ]] 136 | 137 | nf = ndf 138 | for n in range(1, n_layers): 139 | nf_prev = nf 140 | nf = min(nf * 2, 512) 141 | sequence += [[ 142 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 143 | norm_layer(nf), 144 | nn.LeakyReLU(0.2, True) 145 | ]] 146 | 147 | nf_prev = nf 148 | nf = min(nf * 2, 512) 149 | sequence += [[ 150 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 151 | norm_layer(nf), 152 | nn.LeakyReLU(0.2, True) 153 | ]] 154 | 155 | sequence += [[ 156 | nn.Conv3d(nf, 1, kernel_size=kw, stride=1, padding=padw) 157 | ]] 158 | 159 | if getIntermFeat: 160 | for n in range(len(sequence)): 161 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) 162 | else: 163 | sequence_stream = [] 164 | for n in range(len(sequence)): 165 | sequence_stream += sequence[n] 166 | self.model = nn.Sequential(*sequence_stream) 167 | 168 | def forward(self, input): 169 | if self.getIntermFeat: 170 | res = [input] 171 | for n in range(self.n_layers + 2): 172 | model = getattr(self, 'model' + str(n)) 173 | res.append(model(res[-1])) 174 | return res[1:] 175 | else: 176 | return self.model(input) 177 | -------------------------------------------------------------------------------- 
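For reference, here is a small sketch, not part of the repository, of how the multi-scale 3D discriminator above can be instantiated and queried. The hyperparameter values and the toy input resolution are assumptions, and the repository root is assumed to be on `PYTHONPATH`.

```python
# Illustrative only: hyperparameters and input sizes below are assumed toy values.
import torch
from TemporalAlignment.models.mocoganhd_video_disc import ModelD_3d

netD_3d = ModelD_3d(nc=3, norm_D_3d='instance', num_D=2, lr=1e-4,
                    cross_domain=True, n_frames_G=16)

video = torch.randn(1, 3, 16, 64, 64)  # B x C x T x H x W (toy resolution)
outputs = netD_3d(video)               # one entry per scale (num_D entries)

for scale_idx, feats in enumerate(outputs):
    # With getIntermFeat=True, each entry is a list of intermediate feature
    # maps; the last element is that scale's real/fake prediction map.
    print(scale_idx, feats[-1].shape)
```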
/TemporalAlignment/models/video_discriminator.py: -------------------------------------------------------------------------------- 1 | # This class implements the Video Discriminator 2 | # Discriminator based on MocoGAN's discriminator 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.utils.data 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | 11 | class Noise(nn.Module): 12 | def __init__(self, use_noise, sigma=0.2): 13 | super(Noise, self).__init__() 14 | self.use_noise = use_noise 15 | self.sigma = sigma 16 | 17 | def forward(self, x): 18 | if self.use_noise: 19 | return x + self.sigma * Variable(T.FloatTensor(x.size()).normal_(), requires_grad=False) 20 | return x 21 | 22 | class VideoDiscriminator(nn.Module): 23 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64): 24 | super(VideoDiscriminator, self).__init__() 25 | 26 | self.n_channels = n_channels 27 | self.n_output_neurons = n_output_neurons 28 | self.use_noise = use_noise 29 | self.bn_use_gamma = bn_use_gamma 30 | 31 | self.main = nn.Sequential( 32 | Noise(use_noise, sigma=noise_sigma), 33 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 34 | nn.LeakyReLU(0.2, inplace=True), 35 | 36 | Noise(use_noise, sigma=noise_sigma), 37 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 38 | nn.BatchNorm3d(ndf * 2), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | 41 | Noise(use_noise, sigma=noise_sigma), 42 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 43 | nn.BatchNorm3d(ndf * 4), 44 | nn.LeakyReLU(0.2, inplace=True), 45 | 46 | Noise(use_noise, sigma=noise_sigma), 47 | nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 48 | nn.BatchNorm3d(ndf * 8), 49 | nn.LeakyReLU(0.2, inplace=True), 50 | 51 | nn.Conv3d(ndf * 8, n_output_neurons, 4, 1, 0, bias=False), 52 | ) 53 | 54 | self.linear = nn.Linear(13*13, 1) # add linear layer to deal with different resolution 55 | 56 | def forward(self, input): 57 | h = self.main(input).view(-1) 58 | h = self.linear(h) 59 | return h -------------------------------------------------------------------------------- /TemporalAlignment/perturbations.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file generates the paired data by introducing motion errors (aka perturbations) in the source face 3 | The imperfectly blended face is a 6 channel image which the model has to generate to denoise in a temporal fashion 4 | ''' 5 | 6 | import sys 7 | import os 8 | import os.path as osp 9 | import random 10 | from glob import glob 11 | from tqdm import tqdm 12 | from enum import Enum 13 | 14 | from TemporalAlignment.ranges import * 15 | 16 | import torch 17 | import torchvision.transforms as transforms 18 | 19 | import matplotlib.pyplot as plt 20 | from PIL import Image 21 | from wand.image import Image as WandImage 22 | 23 | import cv2 24 | import numpy as np 25 | 26 | ''' 27 | The following perturbation functions are added -- 28 | 1. Affine transformations - translation, rotation, resize 29 | 2. Color transformations - change hue, saturation etc 30 | 3. 
Non-linear transformations aka distortions - arc, barrel, barrel_inverse 31 | ''' 32 | 33 | ''' 34 | types of distortions 35 | ''' 36 | class Distortion(Enum): 37 | ARC = 1, 38 | BARREL = 2, 39 | BARREL_INVERSE = 3, 40 | 41 | 42 | ''' 43 | horizontal image translation -- affine transformation 44 | ''' 45 | def translate_horizontal(x, image): 46 | M = np.float32([ 47 | [1, 0, x], 48 | [0, 1, 0] 49 | ]) 50 | 51 | shifted = cv2.warpAffine(image, M, (image.shape[1], image.shape[0])) 52 | return shifted 53 | 54 | ''' 55 | vertical image translation -- affine transformation 56 | ''' 57 | def translate_vertical(y, image): 58 | M = np.float32([ 59 | [1, 0, 0], 60 | [0, 1, y] 61 | ]) 62 | 63 | shifted = cv2.warpAffine(image, M, (image.shape[1], image.shape[0])) 64 | 65 | return shifted 66 | 67 | ''' 68 | image rotation (clockwise or anti-clockwise) -- affine transformation 69 | ''' 70 | def rotate_image(rotation, image, center=None): 71 | # rotates the image by the center of the image 72 | h, w = image.shape[:2] 73 | 74 | if center is None: 75 | cX, cY = (w//2, h//2) 76 | M = cv2.getRotationMatrix2D((cX, cY), rotation, 1.0) 77 | else: 78 | M = cv2.getRotationMatrix2D(center, rotation, 1.0) 79 | 80 | rotated = cv2.warpAffine(image, M, (w, h)) 81 | 82 | return rotated 83 | 84 | ''' 85 | image resize operations (zoom in and zoom out) -- not used currently 86 | ''' 87 | def resize_image(magnification, image): 88 | res = cv2.resize(image, None, fx=magnification, fy=magnification, interpolation=cv2.INTER_CUBIC) 89 | h, w = image.shape[:2] 90 | 91 | if magnification >= 1: 92 | cX, cY = res.shape[1] // 2, res.shape[0] // 2 93 | left_index = cX - w // 2 94 | upper_index = cY - h // 2 95 | modified_image = res[upper_index : upper_index + h, left_index : left_index + w] 96 | else: 97 | modified_image = np.zeros((image.shape), dtype=np.uint8) 98 | hs, ws = res.shape[:2] 99 | difference_h = h - hs 100 | difference_w = w - ws 101 | left_index = difference_w // 2 102 | upper_index = difference_h // 2 103 | modified_image[upper_index : upper_index + hs, left_index : left_index + ws] = res 104 | 105 | return modified_image 106 | 107 | ''' 108 | applies shear transformation along both the axes -- not used currently 109 | ''' 110 | def shear_image(shear, image): 111 | shear_x, shear_y = shear, shear 112 | M = np.float32([ 113 | [1, shear_x, 0], 114 | [shear_y, 1, 0] 115 | ]) 116 | 117 | sheared = cv2.warpAffine(image, M, (image.shape[1], image.shape[0])) 118 | 119 | return sheared 120 | 121 | ''' 122 | flips the imagein the horizontal direction -- not used currently 123 | ''' 124 | def image_flip(flip_code, image): 125 | flipped_image = cv2.flip(image, int(flip_code)) 126 | return flipped_image 127 | 128 | ''' 129 | applies non-linear transformation (distortion) to the image 130 | ''' 131 | def distort_image(distortion_type, image): 132 | img_file = np.array(image) 133 | with WandImage.from_array(img_file) as img: 134 | ''' 135 | three type of image distortions are supported - i) arc, ii) barrel, and iii) inverse-barrel 136 | ''' 137 | if distortion_type == Distortion.ARC.value: 138 | distortion_angle = random.randint(0, 30) 139 | img.distort('arc', (distortion_angle,)) 140 | img.resize(image.shape[0], image.shape[1]) 141 | 142 | opencv_image = np.array(img) 143 | 144 | elif distortion_type == Distortion.BARREL.value: 145 | a = random.randint(0, 10)/10 146 | b = random.randint(2, 7)/10 147 | c = random.randint(0, 5)/10 148 | d = random.randint(10, 10)/10 149 | args = (a, b, c, d) 150 | img.distort('barrel', (args)) 
151 | img.resize(image.shape[0], image.shape[1]) 152 | 153 | opencv_image = np.array(img) 154 | 155 | else: 156 | b = random.randint(0, 2)/10 157 | c = random.randint(-5, 0)/10 158 | d = random.randint(10, 10)/10 159 | args = (0.0, b, c, d) 160 | img.distort('barrel_inverse', (args)) 161 | img.resize(image.shape[0], image.shape[1]) 162 | 163 | opencv_image = np.array(img) 164 | 165 | return opencv_image 166 | 167 | ''' 168 | blend the perturbed image and the masked image 169 | ''' 170 | def combine_images(face_mask, perturbed_image, generate_mask=True): 171 | image_masked = face_mask.copy() 172 | if generate_mask: 173 | mask = perturbed_image[..., 0] != 0 174 | image_masked[mask] = 0 175 | 176 | combined_image = image_masked + perturbed_image 177 | 178 | return combined_image 179 | 180 | ''' 181 | finds the center of the eyes from the face landmarks 182 | ''' 183 | def find_eye_center(shape): 184 | # landmark coordinates corresponding to the eyes 185 | lStart, lEnd = 36, 41 186 | rStart, rEnd = 42, 47 187 | 188 | # landmarks for the left and right eyes 189 | leftEyePoints = shape[lStart:lEnd] 190 | rightEyePoints = shape[rStart:rEnd] 191 | 192 | # compute the center of mass for each of the eyes 193 | leftEyeCenter = leftEyePoints.mean(axis=0).astype("int") 194 | rightEyeCenter = rightEyePoints.mean(axis=0).astype("int") 195 | 196 | # compute the angle between the eye centroids 197 | dY = rightEyeCenter[1] - leftEyeCenter[1] 198 | dX = rightEyeCenter[0] - leftEyeCenter[0] 199 | angle = np.degrees(np.arctan2(dY, dX)) 200 | 201 | eyesCenter = ((leftEyeCenter[0] + rightEyeCenter[0]) / 2, 202 | (leftEyeCenter[1] + rightEyeCenter[1]) / 2) 203 | 204 | ''' 205 | applies composite perturbations to a single image 206 | the perturbations to apply are selected randomly 207 | ''' 208 | def perturb_image_composite(face_image, landmark): 209 | perturbation_functions = [ 210 | translate_horizontal, 211 | translate_vertical, 212 | rotate_image, 213 | resize_image, 214 | # shear_image, 215 | # image_flip, 216 | distort_image, 217 | ] 218 | 219 | perturbation_function_map = { 220 | translate_horizontal : [-translation_range, translation_range, 1], 221 | translate_vertical : [-translation_range, translation_range, 1], 222 | rotate_image : [-rotation_range, rotation_range, 1], 223 | resize_image : [scale_ranges[0], scale_ranges[1], 100], 224 | # shear_image : [-10, 10, 100], 225 | # image_flip : [1, 1, 1], 226 | distort_image : [0, len(Distortion), 1], 227 | } 228 | 229 | gt_transformations = { 230 | 'translate_horizontal': 0, 231 | 'translate_vertical': 0, 232 | 'rotate_image': 0 233 | } 234 | 235 | eyes_center = find_eye_center(landmark) 236 | 237 | # maintains the perturbations to apply to the image 238 | composite_perturbations = list() 239 | # ensures atleast one perturbation is produced 240 | while len(composite_perturbations) == 0: 241 | for i, perturbation_function in enumerate(perturbation_functions): 242 | if random.randint(0, 1): 243 | composite_perturbations.append(perturbation_function) 244 | 245 | # print(f'Perturbations applied : {composite_perturbations}', flush=True) 246 | 247 | for perturbation_function in composite_perturbations: 248 | perturbation_map = perturbation_function_map[perturbation_function] 249 | perturbation_value = random.randint(perturbation_map[0], perturbation_map[1])/perturbation_map[2] 250 | normalized_value = perturbation_value/perturbation_map[1] 251 | 252 | if perturbation_function == translate_horizontal: 253 | gt_transformations['translate_horizontal'] = 
perturbation_value 254 | elif perturbation_function == translate_vertical: 255 | gt_transformations['translate_vertical'] = perturbation_value 256 | else: 257 | gt_transformations['rotate_image'] = perturbation_value 258 | 259 | if perturbation_function == rotate_image: 260 | face_image = perturbation_function(perturbation_value, face_image, center=eyes_center) 261 | else: 262 | face_image = perturbation_function(perturbation_value, face_image) 263 | 264 | return face_image, gt_transformations 265 | 266 | ''' 267 | the function specifies the perturbation functions and the amounts by which corresponding perturbations are applied 268 | the perturbed image (foreground) is combined with the face_mask (background) to reproduce blending seen due to two different flows 269 | applies multiple perturbations (composite perturbations) to generate complex perturbations 270 | ''' 271 | def perturb_image(face_image): 272 | perturbation_functions = [ 273 | translate_horizontal, 274 | translate_vertical, 275 | rotate_image, 276 | resize_image 277 | ] 278 | 279 | perturbation_function_map = { 280 | translate_horizontal : [-20, 20, 1], 281 | translate_vertical : [-20, 20, 1], 282 | rotate_image : [-25, 25, 1], 283 | resize_image : [90, 110, 100] 284 | } 285 | 286 | random_perturbation_index = random.randint(0, len(perturbation_functions)-1) 287 | # random_perturbation_index = 0 # used for debugging 288 | perturbation_function = perturbation_functions[random_perturbation_index] 289 | perturbation_map = perturbation_function_map[perturbation_function] 290 | perturbation_value = random.randint(perturbation_map[0], perturbation_map[1])/perturbation_map[2] 291 | # print(f'Using perturbation : {random_perturbation_index}, with value : {perturbation_value}', flush=True) 292 | intermediate_perturbed_image = perturbation_function(perturbation_value, face_image) 293 | # perturbed_image = combine_images(face_mask, intermediate_perturbed_image) 294 | 295 | return intermediate_perturbed_image 296 | 297 | ''' 298 | def test_sample(): 299 | file = '/ssd_scratch/cvit/aditya1/CelebA-HQ-img/13842.jpg' 300 | gpu_id = 0 301 | parsing, image = generate_segmentation(file, gpu_id) 302 | for i in range(PERTURBATIONS_PER_IDENTITY): 303 | face_image, background_image = generate_segmented_face(parsing, image) 304 | perturbed_image = perturb_image_composite(face_image, background_image) 305 | perturbed_filename, extension = osp.basename(file).split('.') 306 | perturbed_image_path = osp.join(perturbed_image_dir, perturbed_filename + '_' + str(i) + '.' 
+ extension) 307 | save_image(perturbed_image_path, perturbed_image) 308 | ''' -------------------------------------------------------------------------------- /TemporalAlignment/ranges.py: -------------------------------------------------------------------------------- 1 | translation_range = 3 2 | rotation_range = 3 3 | scale_ranges = 90,110 -------------------------------------------------------------------------------- /bad_mp4s.json: -------------------------------------------------------------------------------- 1 | ["iOqgDmaYdqA/00028.mp4", "iOqgDmaYdqA/00190.mp4", "iOqgDmaYdqA/00027.mp4", "iOqgDmaYdqA/00000.mp4", "y9lSgYW2ObQ/00036.mp4", "y9lSgYW2ObQ/00041.mp4", "XGZMe5DsnRU/00000.mp4", "yYmSVsMXve0/00017.mp4", "yYmSVsMXve0/00028.mp4", "yYmSVsMXve0/00013.mp4", "yYmSVsMXve0/00047.mp4", "yYmSVsMXve0/00008.mp4", "yYmSVsMXve0/00007.mp4", "yYmSVsMXve0/00015.mp4", "yYmSVsMXve0/00009.mp4", "yYmSVsMXve0/00010.mp4", "TZpP05GHWnw/00191.mp4", "TZpP05GHWnw/00175.mp4", "TZpP05GHWnw/00196.mp4", "TZpP05GHWnw/00192.mp4", "TZpP05GHWnw/00170.mp4", "TZpP05GHWnw/00246.mp4", "TZpP05GHWnw/00206.mp4", "TZpP05GHWnw/00189.mp4", "TZpP05GHWnw/00211.mp4", "TZpP05GHWnw/00188.mp4", "TZpP05GHWnw/00187.mp4", "TZpP05GHWnw/00247.mp4", "TZpP05GHWnw/00186.mp4", "TZpP05GHWnw/00245.mp4", "TZpP05GHWnw/00207.mp4", "TZpP05GHWnw/00143.mp4", "TZpP05GHWnw/00171.mp4", "TZpP05GHWnw/00214.mp4", "TZpP05GHWnw/00227.mp4", "TZpP05GHWnw/00212.mp4", "TZpP05GHWnw/00205.mp4", "TZpP05GHWnw/00152.mp4", "TZpP05GHWnw/00193.mp4", "IGl4w4a_pcw/00073.mp4", "IGl4w4a_pcw/00108.mp4", "IGl4w4a_pcw/00055.mp4", "IGl4w4a_pcw/00047.mp4", "IGl4w4a_pcw/00066.mp4", "IGl4w4a_pcw/00107.mp4", "IGl4w4a_pcw/00057.mp4", "IGl4w4a_pcw/00048.mp4", "IGl4w4a_pcw/00051.mp4", "IGl4w4a_pcw/00049.mp4", "IGl4w4a_pcw/00056.mp4", "IGl4w4a_pcw/00000.mp4", "IGl4w4a_pcw/00050.mp4", "IGl4w4a_pcw/00212.mp4", "UjX-4h6UahQ/00005.mp4", "UjX-4h6UahQ/00002.mp4", "UjX-4h6UahQ/00001.mp4", "UjX-4h6UahQ/00000.mp4", "WOEcMOEmbC8/00037.mp4", "WOEcMOEmbC8/00031.mp4", "ozEAkq3qSi0/00063.mp4", "ozEAkq3qSi0/00273.mp4", "ozEAkq3qSi0/00175.mp4", "ozEAkq3qSi0/00122.mp4", "ozEAkq3qSi0/00275.mp4", "ozEAkq3qSi0/00095.mp4", "ozEAkq3qSi0/00100.mp4", "ozEAkq3qSi0/00153.mp4", "ozEAkq3qSi0/00277.mp4", "ozEAkq3qSi0/00204.mp4", "ozEAkq3qSi0/00278.mp4", "ozEAkq3qSi0/00272.mp4", "ozEAkq3qSi0/00174.mp4", "ozEAkq3qSi0/00276.mp4", "ozEAkq3qSi0/00205.mp4", "ozEAkq3qSi0/00152.mp4", "ozEAkq3qSi0/00202.mp4", "0uNEEfQKfwk/00112.mp4", "0uNEEfQKfwk/00298.mp4", "0uNEEfQKfwk/00106.mp4", "0uNEEfQKfwk/00318.mp4", "0uNEEfQKfwk/00174.mp4", "0uNEEfQKfwk/00276.mp4", "0uNEEfQKfwk/00297.mp4", "0uNEEfQKfwk/00000.mp4", "c3tTE3hzON8/00174.mp4", "dUrBaT_ChGc/00007.mp4", "dUrBaT_ChGc/00029.mp4", "C93trWdL3Rg/00222.mp4", "C93trWdL3Rg/00191.mp4", "C93trWdL3Rg/00200.mp4", "C93trWdL3Rg/00017.mp4", "C93trWdL3Rg/00135.mp4", "C93trWdL3Rg/00072.mp4", "C93trWdL3Rg/00185.mp4", "C93trWdL3Rg/00131.mp4", "C93trWdL3Rg/00067.mp4", "C93trWdL3Rg/00016.mp4", "C93trWdL3Rg/00073.mp4", "C93trWdL3Rg/00175.mp4", "C93trWdL3Rg/00108.mp4", "C93trWdL3Rg/00215.mp4", "C93trWdL3Rg/00055.mp4", "C93trWdL3Rg/00104.mp4", "C93trWdL3Rg/00047.mp4", "C93trWdL3Rg/00181.mp4", "C93trWdL3Rg/00094.mp4", "C93trWdL3Rg/00122.mp4", "C93trWdL3Rg/00074.mp4", "C93trWdL3Rg/00161.mp4", "C93trWdL3Rg/00225.mp4", "C93trWdL3Rg/00142.mp4", "C93trWdL3Rg/00224.mp4", "C93trWdL3Rg/00066.mp4", "C93trWdL3Rg/00130.mp4", "C93trWdL3Rg/00071.mp4", "C93trWdL3Rg/00125.mp4", "C93trWdL3Rg/00095.mp4", "C93trWdL3Rg/00190.mp4", "C93trWdL3Rg/00037.mp4", "C93trWdL3Rg/00206.mp4", "C93trWdL3Rg/00110.mp4", 
"C93trWdL3Rg/00134.mp4", "C93trWdL3Rg/00132.mp4", "C93trWdL3Rg/00042.mp4", "C93trWdL3Rg/00179.mp4", "C93trWdL3Rg/00219.mp4", "C93trWdL3Rg/00189.mp4", "C93trWdL3Rg/00029.mp4", "C93trWdL3Rg/00209.mp4", "C93trWdL3Rg/00048.mp4", "C93trWdL3Rg/00041.mp4", "C93trWdL3Rg/00211.mp4", "C93trWdL3Rg/00051.mp4", "C93trWdL3Rg/00201.mp4", "C93trWdL3Rg/00176.mp4", "C93trWdL3Rg/00147.mp4", "C93trWdL3Rg/00081.mp4", "C93trWdL3Rg/00083.mp4", "C93trWdL3Rg/00006.mp4", "C93trWdL3Rg/00154.mp4", "C93trWdL3Rg/00140.mp4", "C93trWdL3Rg/00045.mp4", "C93trWdL3Rg/00024.mp4", "C93trWdL3Rg/00168.mp4", "C93trWdL3Rg/00049.mp4", "C93trWdL3Rg/00156.mp4", "C93trWdL3Rg/00136.mp4", "C93trWdL3Rg/00188.mp4", "C93trWdL3Rg/00054.mp4", "C93trWdL3Rg/00009.mp4", "C93trWdL3Rg/00230.mp4", "C93trWdL3Rg/00160.mp4", "C93trWdL3Rg/00027.mp4", "C93trWdL3Rg/00164.mp4", "C93trWdL3Rg/00084.mp4", "C93trWdL3Rg/00187.mp4", "C93trWdL3Rg/00150.mp4", "C93trWdL3Rg/00091.mp4", "C93trWdL3Rg/00121.mp4", "C93trWdL3Rg/00204.mp4", "C93trWdL3Rg/00207.mp4", "C93trWdL3Rg/00031.mp4", "C93trWdL3Rg/00022.mp4", "C93trWdL3Rg/00059.mp4", "C93trWdL3Rg/00174.mp4", "C93trWdL3Rg/00010.mp4", "C93trWdL3Rg/00109.mp4", "C93trWdL3Rg/00098.mp4", "C93trWdL3Rg/00220.mp4", "C93trWdL3Rg/00143.mp4", "C93trWdL3Rg/00118.mp4", "C93trWdL3Rg/00228.mp4", "C93trWdL3Rg/00018.mp4", "C93trWdL3Rg/00119.mp4", "C93trWdL3Rg/00033.mp4", "C93trWdL3Rg/00056.mp4", "C93trWdL3Rg/00167.mp4", "C93trWdL3Rg/00138.mp4", "C93trWdL3Rg/00169.mp4", "C93trWdL3Rg/00025.mp4", "C93trWdL3Rg/00208.mp4", "C93trWdL3Rg/00030.mp4", "C93trWdL3Rg/00082.mp4", "C93trWdL3Rg/00014.mp4", "C93trWdL3Rg/00068.mp4", "C93trWdL3Rg/00114.mp4", "C93trWdL3Rg/00182.mp4", "C93trWdL3Rg/00226.mp4", "C93trWdL3Rg/00078.mp4", "C93trWdL3Rg/00145.mp4", "C93trWdL3Rg/00227.mp4", "C93trWdL3Rg/00127.mp4", "C93trWdL3Rg/00050.mp4", "C93trWdL3Rg/00178.mp4", "C93trWdL3Rg/00103.mp4", "C93trWdL3Rg/00039.mp4", "C93trWdL3Rg/00026.mp4", "C93trWdL3Rg/00124.mp4", "C93trWdL3Rg/00123.mp4", "C93trWdL3Rg/00229.mp4", "C93trWdL3Rg/00035.mp4", "C93trWdL3Rg/00205.mp4", "C93trWdL3Rg/00180.mp4", "JGCg3ARv14U/00058.mp4", "JGCg3ARv14U/00029.mp4", "JGCg3ARv14U/00059.mp4", "JGCg3ARv14U/00000.mp4", "MTgyYqY3XTE/00181.mp4", "MTgyYqY3XTE/00119.mp4", "MTgyYqY3XTE/00184.mp4", "MTgyYqY3XTE/00001.mp4", "MTgyYqY3XTE/00000.mp4", "wuS8Hq1Yhro/00017.mp4", "wuS8Hq1Yhro/00023.mp4", "wuS8Hq1Yhro/00046.mp4", "wuS8Hq1Yhro/00016.mp4", "wuS8Hq1Yhro/00044.mp4", "wuS8Hq1Yhro/00028.mp4", "wuS8Hq1Yhro/00013.mp4", "wuS8Hq1Yhro/00047.mp4", "wuS8Hq1Yhro/00012.mp4", "wuS8Hq1Yhro/00008.mp4", "wuS8Hq1Yhro/00032.mp4", "wuS8Hq1Yhro/00037.mp4", "wuS8Hq1Yhro/00036.mp4", "wuS8Hq1Yhro/00042.mp4", "wuS8Hq1Yhro/00007.mp4", "wuS8Hq1Yhro/00034.mp4", "wuS8Hq1Yhro/00029.mp4", "wuS8Hq1Yhro/00015.mp4", "wuS8Hq1Yhro/00048.mp4", "wuS8Hq1Yhro/00020.mp4", "wuS8Hq1Yhro/00004.mp4", "wuS8Hq1Yhro/00041.mp4", "wuS8Hq1Yhro/00006.mp4", "wuS8Hq1Yhro/00005.mp4", "wuS8Hq1Yhro/00045.mp4", "wuS8Hq1Yhro/00024.mp4", "wuS8Hq1Yhro/00019.mp4", "wuS8Hq1Yhro/00009.mp4", "wuS8Hq1Yhro/00027.mp4", "wuS8Hq1Yhro/00043.mp4", "wuS8Hq1Yhro/00031.mp4", "wuS8Hq1Yhro/00022.mp4", "wuS8Hq1Yhro/00040.mp4", "wuS8Hq1Yhro/00010.mp4", "wuS8Hq1Yhro/00018.mp4", "wuS8Hq1Yhro/00002.mp4", "wuS8Hq1Yhro/00033.mp4", "wuS8Hq1Yhro/00011.mp4", "wuS8Hq1Yhro/00025.mp4", "wuS8Hq1Yhro/00030.mp4", "wuS8Hq1Yhro/00014.mp4", "wuS8Hq1Yhro/00001.mp4", "wuS8Hq1Yhro/00038.mp4", "wuS8Hq1Yhro/00000.mp4", "wuS8Hq1Yhro/00039.mp4", "wuS8Hq1Yhro/00026.mp4", "wuS8Hq1Yhro/00035.mp4", "wuS8Hq1Yhro/00021.mp4", "wuS8Hq1Yhro/00003.mp4"] 
-------------------------------------------------------------------------------- /bash_scripts/train_videovqvae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=acm_abl_wo_temporalmodule 4 | #SBATCH --mem-per-cpu=2048 5 | #SBATCH --partition long 6 | #SBATCH --account research 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --mincpus=38 9 | #SBATCH --nodes=1 10 | #SBATCH --time 4-00:00:00 11 | #SBATCH --signal=B:HUP@600 12 | #SBATCH -w gnode031 13 | 14 | cd /ssd_scratch/cvit/aditya1/acm_rebuttal/video_vqvae/VQVAE2-Refact 15 | 16 | source /home2/aditya1/miniconda3/bin/activate base 17 | 18 | CUDA_VISIBLE_DEVICES=0 python train_vqvae_perceptual.py --epoch 1000 --colorjit const --batch_size 1 -------------------------------------------------------------------------------- /bash_scripts/train_videovqvae_perturbations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=acmabl_wo_translation 4 | #SBATCH --mem-per-cpu=2048 5 | #SBATCH --partition long 6 | #SBATCH --account cvit_bhaasha 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --mincpus=38 9 | #SBATCH --nodes=1 10 | #SBATCH --time 10-00:00:00 11 | #SBATCH --signal=B:HUP@600 12 | #SBATCH -w gnode061 13 | 14 | cd /ssd_scratch/cvit/aditya1/acm_rebuttal/video_vqvae/VQVAE2-Refact 15 | 16 | source /home2/aditya1/miniconda3/bin/activate base 17 | 18 | CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 \ 19 | python train_vqvae_perceptual.py --epoch 1000 \ 20 | --colorjit const --batch_size 1 \ 21 | --sample_folder /ssd_scratch/cvit/aditya1/video_vqvae2_results/ablation_translation_disabled_samples \ 22 | --checkpoint_dir /ssd_scratch/cvit/aditya1/video_vqvae2_results/ablation_translation_disabled_checkpoints -------------------------------------------------------------------------------- /bash_scripts/train_videovqvae_resume_ckpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=acmwotemporal_abl 4 | #SBATCH --mem-per-cpu=2048 5 | #SBATCH --partition long 6 | #SBATCH --account research 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --mincpus=38 9 | #SBATCH --nodes=1 10 | #SBATCH --time 4-00:00:00 11 | #SBATCH --signal=B:HUP@600 12 | #SBATCH -w gnode031 13 | 14 | cd /ssd_scratch/cvit/aditya1/acm_rebuttal/video_vqvae/VQVAE2-Refact 15 | 16 | source /home2/aditya1/miniconda3/bin/activate base 17 | 18 | CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 python train_vqvae_perceptual.py \ 19 | --epoch 1000 --colorjit const --batch_size 1 \ 20 | --ckpt /ssd_scratch/cvit/aditya1/video_vqvae2_results/checkpoint_iquff/vqvae_7_5121.pt \ 21 | --sample_folder /ssd_scratch/cvit/aditya1/video_vqvae2_results/without_temporal_resume_samples \ 22 | --checkpoint_dir /ssd_scratch/cvit/aditya1/video_vqvae2_results/without_temporal_resume_checkpoints -------------------------------------------------------------------------------- /command.txt: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python train_video_vqvae.py --epoch 20000 --ckpt vqvae_068.pt 2 | 3 | CUDA_VISIBLE_DEVICES=1 python train_video_vqvae.py --epoch 20000 --max_frame_len 8 --ckpt checkpoint_target_all5losses/vqvae_501.pt,checkpoint_target_all5losses/adversarial_501.pt --validate_at 2048 --epoch 200000 4 | 5 | CUDA_VISIBLE_DEVICES=0 python train_video_vqvae.py --epoch 20000 --max_frame_len 8 --ckpt 
checkpoint_target_all5losses/vqvae_5301.pt,checkpoint_target_all5losses/adversarial_5301.pt --validate_at 2048 --epoch 200000 6 | 7 | CUDA_VISIBLE_DEVICES=2 python train_video_vqvae.py --max_frame_len 8 --ckpt checkpoint_vlog_all5losses/vqvae_5401.pt,checkpoint_vlog_all5losses/adversarial_5401.pt 8 | 9 | 10 | CUDA_VISIBLE_DEVICES=2 python train_video_vqvae_nta.py --max_frame_len 8 --ckpt multi_005.pt --validate_at 512 --epoch 200000 -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Config file containing the hyperparameters 3 | ''' 4 | DATASET = 11 5 | LATENT_LOSS_WEIGHT = 1 # hyperparameter weight for the latent loss 6 | PERCEPTUAL_LOSS_WEIGHT = 1 # hyperparameter weight for the perceptual loss 7 | 8 | # weights for the mocoganhd discriminator 9 | G_LOSS_2D_WEIGHT = 0.25 10 | G_LOSS_3D_WEIGHT = 0.25 11 | 12 | image_disc_weight = 0.5 13 | video_disc_weight = 0.5 14 | 15 | D_LOSS_WEIGHT = 0.1 16 | 17 | SAMPLE_SIZE_FOR_VISUALIZATION = 8 18 | DISC_LOSS_WEIGHT = 0.25 # TODO - modify -------------------------------------------------------------------------------- /copy.sh: -------------------------------------------------------------------------------- 1 | cp /home2/aditya1/cvit/content_sync/FaceOff/*.py . 2 | cp /home2/aditya1/cvit/content_sync/FaceOff/datasets/*.py datasets/. 3 | cp /home2/aditya1/cvit/content_sync/FaceOff/disc_trainers/*.py disc_trainers/. 4 | cp /home2/aditya1/cvit/content_sync/FaceOff/models/*.py models/. 5 | cp -r /home2/aditya1/cvit/content_sync/FaceOff/TemporalAlignment/* TemporalAlignment/. -------------------------------------------------------------------------------- /datasets/face_translation_videos3_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | from glob import glob 4 | import os 5 | import json 6 | import os.path as osp 7 | 8 | from skimage import io 9 | from skimage.transform import resize 10 | from scipy.ndimage import laplace 11 | import numpy as np 12 | import cv2 13 | from skimage import transform as tf 14 | 15 | from scipy.ndimage import laplace 16 | 17 | target_without_face_apply = True 18 | 19 | def resize_frame(frame, resize_dim=256): 20 | h, w, _ = frame.shape 21 | 22 | if h > w: 23 | padw, padh = (h-w)//2, 0 24 | else: 25 | padw, padh = 0, (w-h)//2 26 | 27 | padded = cv2.copyMakeBorder(frame, padh, padh, padw, padw, cv2.BORDER_CONSTANT, value=0) 28 | padded = cv2.resize(padded, (resize_dim, resize_dim), interpolation=cv2.INTER_LINEAR) 29 | 30 | return padded 31 | 32 | def readPoints(nparray) : 33 | points = [] 34 | 35 | for row in nparray: 36 | x, y = row[0], row[1] 37 | points.append((int(x), int(y))) 38 | 39 | return points 40 | 41 | def generate_convex_hull(img, points): 42 | # points = np.load(landmark_path, allow_pickle=True)['landmark'].astype(np.uint8) 43 | points = readPoints(points) 44 | 45 | hull = [] 46 | hullIndex = cv2.convexHull(np.array(points), returnPoints = False) 47 | 48 | for i in range(0, len(hullIndex)): 49 | hull.append(points[int(hullIndex[i])]) 50 | 51 | sizeImg = img.shape 52 | rect = (0, 0, sizeImg[1], sizeImg[0]) 53 | 54 | hull8U = [] 55 | for i in range(0, len(hull)): 56 | hull8U.append((hull[i][0], hull[i][1])) 57 | 58 | mask = np.zeros(img.shape, dtype = img.dtype) 59 | 60 | cv2.fillConvexPoly(mask, np.int32(hull8U), (255, 255, 255)) 61 | 62 | # convex_face = ((mask/255.) 
* img).astype(np.uint8) 63 | 64 | return mask 65 | 66 | def enlarge_mask(img_mask, enlargement=5): 67 | img1 = img_mask.copy() 68 | img = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) 69 | 70 | ret, thresh = cv2.threshold(img,50,255,0) 71 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 72 | 73 | for i in range(len(contours)): 74 | hull = cv2.convexHull(contours[i]) 75 | cv2.drawContours(img1, [hull], -1, (255, 255, 255), enlargement) 76 | 77 | return img1 78 | 79 | def poisson_blend(target_img, src_img, mask_img, iter: int = 1024): 80 | for _ in range(iter): 81 | target_img = target_img + 0.25 * mask_img * laplace(target_img - src_img) 82 | return target_img.clip(0, 1) 83 | 84 | # -- Face Transformation 85 | def warp_img(src, dst, img, std_size): 86 | tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix 87 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) # wrap the frame image 88 | warped = warped * 255 # note output from wrap is double image (value range [0,1]) 89 | warped = warped.astype('uint8') 90 | return warped, tform 91 | 92 | def apply_transform(transform, img, std_size): 93 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size) 94 | warped = warped * 255 # note output from wrap is double image (value range [0,1]) 95 | warped = warped.astype('uint8') 96 | return warped 97 | 98 | # Method to combine the face segmented with the face segmentation mask 99 | def combine_images(face_mask, face_image, generate_mask=True): 100 | image_masked = face_mask.copy() 101 | if generate_mask: 102 | mask = face_image[..., 0] != 0 103 | image_masked[mask] = 0 104 | 105 | combined_image = image_masked + face_image 106 | 107 | return combined_image 108 | 109 | # computes the rotation of the face using the angle of the line connecting the eye centroids 110 | def compute_rotation(shape): 111 | # landmark coordinates corresponding to the eyes 112 | lStart, lEnd = 36, 41 113 | rStart, rEnd = 42, 47 114 | 115 | # landmarks for the left and right eyes 116 | leftEyePoints = shape[lStart:lEnd] 117 | rightEyePoints = shape[rStart:rEnd] 118 | 119 | # compute the center of mass for each of the eyes 120 | leftEyeCenter = leftEyePoints.mean(axis=0).astype("int") 121 | rightEyeCenter = rightEyePoints.mean(axis=0).astype("int") 122 | 123 | # compute the angle between the eye centroids 124 | dY = rightEyeCenter[1] - leftEyeCenter[1] 125 | dX = rightEyeCenter[0] - leftEyeCenter[0] 126 | angle = np.degrees(np.arctan2(dY, dX)) 127 | 128 | eyesCenter = ((leftEyeCenter[0] + rightEyeCenter[0]) / 2, 129 | (leftEyeCenter[1] + rightEyeCenter[1]) / 2) 130 | 131 | dist = np.sqrt((dX ** 2) + (dY ** 2)) # this indicates the distance between the two eyes 132 | 133 | return angle, eyesCenter, dist 134 | 135 | def apply_mask(mask, image): 136 | return ((mask / 255.) 
* image).astype(np.uint8) 137 | 138 | # code to generate the alignment between the source and the target image 139 | def generate_warped_image(source_landmark_npz, target_landmark_npz, 140 | source_image_path, target_image_path, 141 | poisson_blend_required = False, 142 | require_full_mask = False): 143 | 144 | stablePoints = [33, 36, 39, 42, 45] 145 | std_size = (256, 256) 146 | 147 | source_image = resize_frame(io.imread(source_image_path)) 148 | target_image = resize_frame(io.imread(target_image_path)) 149 | 150 | source_landmarks = np.load(source_landmark_npz)['landmark'] 151 | 152 | target_landmarks = np.load(target_landmark_npz)['landmark'] 153 | 154 | if require_full_mask: 155 | source_convex_mask = generate_convex_hull(source_image, source_landmarks) 156 | source_convex_mask_no_enlargement = source_convex_mask.copy() 157 | else: 158 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[17:]) 159 | # enlarge the convex mask 160 | source_convex_mask_no_enlargement = source_convex_mask.copy() 161 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=10) 162 | 163 | # apply the convex mask to the face 164 | source_face_segmented = apply_mask(source_convex_mask, source_image) 165 | source_face_transformed, transformation = warp_img(source_landmarks[stablePoints, :], 166 | target_landmarks[stablePoints, :], 167 | source_face_segmented, 168 | std_size) 169 | 170 | source_convex_mask_transformed = apply_transform(transformation, source_convex_mask, std_size) 171 | source_convex_mask_no_enlargement_transformed = apply_transform(transformation, source_convex_mask_no_enlargement, std_size) 172 | source_image_transformed = apply_transform(transformation, source_image, std_size) 173 | 174 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks)) 175 | # target_background = apply_mask(target_convex_mask, target_image) 176 | 177 | target_convex_mask_without_jaw = generate_convex_hull(target_image, target_landmarks[17:]) 178 | target_convex_mask_without_jaw = enlarge_mask(target_convex_mask_without_jaw, enlargement=10) 179 | 180 | target_convex_mask_without_jaw = np.invert(target_convex_mask_without_jaw) 181 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image) 182 | target_without_face = apply_mask(target_convex_mask, target_image) 183 | 184 | if poisson_blend_required: 185 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.) 
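    # Note: poisson_blend operates on float images scaled to [0, 1] and uses the warped
    # source face itself as a soft blending mask, so this branch returns a float image;
    # the combine_images fallback below instead pastes the warped uint8 face directly onto
    # either the target with its face region blanked out or the full target frame,
    # depending on target_without_face_apply.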
186 | else: 187 | if target_without_face_apply: 188 | combined_image = combine_images(target_without_face, source_face_transformed) 189 | else: 190 | combined_image = combine_images(target_image, source_face_transformed) 191 | # apply the transformed convex mask to the target face for sanity 192 | # target_masked = apply_mask(source_convex_mask_transformed, target_image) 193 | 194 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, source_convex_mask_no_enlargement, target_image, target_convex_mask, combined_image, target_without_face_features, source_image 195 | 196 | # code to generate the alignment between the source and the target image 197 | def generate_aligned_image(source_landmark_npz, target_landmark_npz, 198 | source_image_path, target_image_path, 199 | poisson_blend_required = False, 200 | require_full_mask = False): 201 | 202 | source_image = resize_frame(io.imread(source_image_path)) 203 | target_image = resize_frame(io.imread(target_image_path)) 204 | 205 | source_landmarks = np.load(source_landmark_npz)['landmark'] 206 | source_rotation, source_center, source_distance = compute_rotation(source_landmarks) 207 | 208 | target_landmarks = np.load(target_landmark_npz)['landmark'] 209 | target_rotation, target_center, target_distance = compute_rotation(target_landmarks) 210 | 211 | # rotation of the target conditioned on the source orientation 212 | target_conditioned_source_rotation = source_rotation - target_rotation 213 | 214 | # calculate the scaling that needs to be applied on the source image 215 | scaling = target_distance / source_distance 216 | 217 | # apply the rotation on the source image 218 | height, width = 256, 256 219 | # print(f'Angle of rotation is : {target_conditioned_source_rotation}') 220 | rotate_matrix = cv2.getRotationMatrix2D(center=source_center, angle=target_conditioned_source_rotation, scale=scaling) 221 | 222 | # calculate the translation component of the matrix M 223 | rotate_matrix[0, 2] += (target_center[0] - source_center[0]) 224 | rotate_matrix[1, 2] += (target_center[1] - source_center[1]) 225 | 226 | if require_full_mask: 227 | source_convex_mask = generate_convex_hull(source_image, source_landmarks) 228 | else: 229 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[17:]) 230 | # enlarge the convex mask using the enlargement 231 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=5) 232 | 233 | # apply the convex mask to the face 234 | source_face_segmented = apply_mask(source_convex_mask, source_image) 235 | source_face_transformed = cv2.warpAffine(source_face_segmented, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC) 236 | source_convex_mask_transformed = cv2.warpAffine(source_convex_mask, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC) 237 | source_image_transformed = cv2.warpAffine(source_image, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC) 238 | 239 | # used for computing the target background 240 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks)) 241 | # target_background = ((target_convex_mask/255.)*target_image).astype(np.uint8) 242 | target_convex_mask_without_jaw = np.invert(generate_convex_hull(target_image, target_landmarks[17:])) 243 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image) 244 | target_without_face = apply_mask(target_convex_mask, target_image) 245 | 246 | if poisson_blend_required: 247 | combined_image = poisson_blend(target_image/255., 
source_image/255., source_face_transformed/255.) 248 | else: 249 | if target_without_face_apply: 250 | combined_image = combine_images(target_without_face, source_face_transformed) 251 | else: 252 | combined_image = combine_images(target_image, source_face_transformed) 253 | 254 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, target_image, target_convex_mask, combined_image 255 | -------------------------------------------------------------------------------- /datasets/face_translation_videos3_utils_bb.py: -------------------------------------------------------------------------------- 1 | ''' 2 | For generating the masked region, use the bounding box around the landmarks 3 | ''' 4 | 5 | import sys 6 | import random 7 | from glob import glob 8 | import os 9 | import json 10 | import os.path as osp 11 | 12 | import cv2 13 | import numpy as np 14 | from skimage import io 15 | from skimage.transform import resize 16 | from scipy.ndimage import laplace 17 | from skimage import transform as tf 18 | 19 | from scipy.ndimage import laplace 20 | 21 | target_without_face_apply = True 22 | 23 | requires_bb = False 24 | extract_lip_region = False 25 | 26 | ''' 27 | coordinates of the lip region 28 | ''' 29 | if extract_lip_region: 30 | start_idx = 49 31 | end_idx = 61 32 | else: 33 | start_idx = 0 34 | end_idx = 67 35 | 36 | ''' 37 | resize the frame 38 | the learned transformation is a combination of affine, non-linear, and color transformations 39 | ''' 40 | def resize_frame(frame, resize_dim=256): 41 | h, w, _ = frame.shape 42 | 43 | if h > w: 44 | padw, padh = (h-w)//2, 0 45 | else: 46 | padw, padh = 0, (w-h)//2 47 | 48 | padded = cv2.copyMakeBorder(frame, padh, padh, padw, padw, cv2.BORDER_CONSTANT, value=0) 49 | padded = cv2.resize(padded, (resize_dim, resize_dim), interpolation=cv2.INTER_LINEAR) 50 | 51 | return padded 52 | 53 | def readPoints(nparray) : 54 | points = [] 55 | 56 | for row in nparray: 57 | x, y = row[0], row[1] 58 | points.append((int(x), int(y))) 59 | 60 | return points 61 | 62 | ''' 63 | generate the convex hull based on the coordinates of the bounding box 64 | ''' 65 | def generate_convex_hull_bb(img, points): 66 | mask = np.zeros(img.shape, dtype=img.dtype) 67 | min_x, max_x, min_y, max_y = estimate_bb_coordinates(points) 68 | mask[min_y:max_y, min_x:max_x] = 255 69 | 70 | return mask 71 | 72 | ''' 73 | generates the convex hull mask of the face 74 | ''' 75 | def generate_convex_hull(img, points): 76 | # points = np.load(landmark_path, allow_pickle=True)['landmark'].astype(np.uint8) 77 | points = readPoints(points) 78 | 79 | hull = [] 80 | hullIndex = cv2.convexHull(np.array(points), returnPoints = False) 81 | 82 | for i in range(0, len(hullIndex)): 83 | hull.append(points[int(hullIndex[i])]) 84 | 85 | sizeImg = img.shape 86 | rect = (0, 0, sizeImg[1], sizeImg[0]) 87 | 88 | hull8U = [] 89 | for i in range(0, len(hull)): 90 | hull8U.append((hull[i][0], hull[i][1])) 91 | 92 | mask = np.zeros(img.shape, dtype = img.dtype) 93 | 94 | cv2.fillConvexPoly(mask, np.int32(hull8U), (255, 255, 255)) 95 | 96 | return mask 97 | 98 | ''' 99 | Enlarge the face mask to accomodate the region close to the face boundary 100 | ''' 101 | def enlarge_mask(img_mask, enlargement=5): 102 | img1 = img_mask.copy() 103 | img = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) 104 | 105 | ret, thresh = cv2.threshold(img,50,255,0) 106 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 107 | 108 | for i in range(len(contours)): 109 | 
hull = cv2.convexHull(contours[i]) 110 | cv2.drawContours(img1, [hull], -1, (255, 255, 255), enlargement) 111 | 112 | return img1 113 | 114 | ''' 115 | Heuristic blending of the face images using poisson blend 116 | ''' 117 | def poisson_blend(target_img, src_img, mask_img, iter: int = 1024): 118 | for _ in range(iter): 119 | target_img = target_img + 0.25 * mask_img * laplace(target_img - src_img) 120 | return target_img.clip(0, 1) 121 | 122 | # -- Face Transformation 123 | def warp_img(src, dst, img, std_size): 124 | tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix 125 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) # wrap the frame image 126 | warped = warped * 255 # note output from wrap is double image (value range [0,1]) 127 | warped = warped.astype('uint8') 128 | return warped, tform 129 | 130 | 131 | def apply_transform(transform, img, std_size): 132 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size) 133 | warped = warped * 255 # note output from wrap is double image (value range [0,1]) 134 | warped = warped.astype('uint8') 135 | return warped 136 | 137 | ''' 138 | Combines the segmented face image (foreground) with the segmented face mask (background) 139 | ''' 140 | # Method to combine the face segmented with the face segmentation mask 141 | def combine_images(face_mask, face_image, generate_mask=True): 142 | image_masked = face_mask.copy() 143 | if generate_mask: 144 | mask = face_image[..., 0] != 0 145 | image_masked[mask] = 0 146 | 147 | combined_image = image_masked + face_image 148 | 149 | return combined_image 150 | 151 | 152 | ''' 153 | Computes the rotation of the face w.r.t the horizontal 154 | The angle of the line connecting the eye centroids is estimated 155 | ''' 156 | def compute_rotation(shape): 157 | # landmark coordinates corresponding to the eyes 158 | lStart, lEnd = 36, 41 159 | rStart, rEnd = 42, 47 160 | 161 | # landmarks for the left and right eyes 162 | leftEyePoints = shape[lStart:lEnd] 163 | rightEyePoints = shape[rStart:rEnd] 164 | 165 | # compute the center of mass for each of the eyes 166 | leftEyeCenter = leftEyePoints.mean(axis=0).astype("int") 167 | rightEyeCenter = rightEyePoints.mean(axis=0).astype("int") 168 | 169 | # compute the angle between the eye centroids 170 | dY = rightEyeCenter[1] - leftEyeCenter[1] 171 | dX = rightEyeCenter[0] - leftEyeCenter[0] 172 | angle = np.degrees(np.arctan2(dY, dX)) 173 | 174 | eyesCenter = ((leftEyeCenter[0] + rightEyeCenter[0]) / 2, 175 | (leftEyeCenter[1] + rightEyeCenter[1]) / 2) 176 | 177 | dist = np.sqrt((dX ** 2) + (dY ** 2)) # this indicates the distance between the two eyes 178 | 179 | return angle, eyesCenter, dist 180 | 181 | ''' 182 | Apply the mask on the input face image 183 | ''' 184 | def apply_mask(mask, image): 185 | return ((mask / 255.) 
* image).astype(np.uint8) 186 | 187 | ''' 188 | Estimate the coordinates of the bounding box from the face landmarks 189 | ''' 190 | def estimate_bb_coordinates(landmarks, epsilon=10): 191 | min_x, max_x, min_y, max_y = int(np.min(landmarks[:, 0])), int(np.max(landmarks[:, 0])), \ 192 | int(np.min(landmarks[:, 1])), int(np.max(landmarks[:, 1])) 193 | 194 | return min_x - epsilon, max_x + epsilon, min_y - epsilon, max_y + epsilon 195 | 196 | # code to generate the alignment between the source and the target image 197 | def generate_warped_image(source_landmark_npz, target_landmark_npz, 198 | source_image_path, target_image_path, 199 | poisson_blend_required = False, 200 | require_full_mask = False): 201 | 202 | stablePoints = [33, 36, 39, 42, 45] 203 | std_size = (256, 256) 204 | 205 | source_image = resize_frame(io.imread(source_image_path)) 206 | target_image = resize_frame(io.imread(target_image_path)) 207 | 208 | source_landmarks = np.load(source_landmark_npz)['landmark'] 209 | 210 | target_landmarks = np.load(target_landmark_npz)['landmark'] 211 | 212 | if require_full_mask: 213 | source_convex_mask = generate_convex_hull(source_image, source_landmarks) 214 | source_convex_mask_no_enlargement = source_convex_mask.copy() 215 | else: 216 | if requires_bb: 217 | source_convex_mask = generate_convex_hull_bb(source_image, source_landmarks[start_idx:end_idx]) 218 | else: 219 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[start_idx:end_idx]) 220 | 221 | # enlarge the convex mask 222 | source_convex_mask_no_enlargement = source_convex_mask.copy() 223 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=10) 224 | 225 | # apply the convex mask to the face 226 | source_face_segmented = apply_mask(source_convex_mask, source_image) 227 | source_face_transformed, transformation = warp_img(source_landmarks[stablePoints, :], 228 | target_landmarks[stablePoints, :], 229 | source_face_segmented, 230 | std_size) 231 | 232 | source_convex_mask_transformed = apply_transform(transformation, source_convex_mask, std_size) 233 | source_convex_mask_no_enlargement_transformed = apply_transform(transformation, source_convex_mask_no_enlargement, std_size) 234 | source_image_transformed = apply_transform(transformation, source_image, std_size) 235 | 236 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks)) 237 | # target_background = apply_mask(target_convex_mask, target_image) 238 | 239 | if requires_bb: 240 | target_convex_mask_without_jaw = generate_convex_hull_bb(target_image, target_landmarks[start_idx:end_idx]) 241 | else: 242 | target_convex_mask_without_jaw = generate_convex_hull(target_image, target_landmarks[start_idx:end_idx]) 243 | 244 | target_convex_mask_without_jaw = enlarge_mask(target_convex_mask_without_jaw, enlargement=10) 245 | 246 | target_convex_mask_without_jaw = np.invert(target_convex_mask_without_jaw) 247 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image) 248 | target_without_face = apply_mask(target_convex_mask, target_image) 249 | 250 | if poisson_blend_required: 251 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.) 
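    # Note: the two branches yield different dtypes/ranges -- poisson_blend returns a float
    # image clipped to [0, 1], while combine_images below returns a uint8 composite -- so
    # downstream code sees a different value range depending on poisson_blend_required.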
252 | else: 253 | if target_without_face_apply: 254 | combined_image = combine_images(target_without_face, source_face_transformed) 255 | else: 256 | combined_image = combine_images(target_image, source_face_transformed) 257 | # apply the transformed convex mask to the target face for sanity 258 | # target_masked = apply_mask(source_convex_mask_transformed, target_image) 259 | 260 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, source_convex_mask_no_enlargement, target_image, target_convex_mask, combined_image, target_without_face_features, source_image 261 | 262 | # code to generate the alignment between the source and the target image 263 | def generate_aligned_image(source_landmark_npz, target_landmark_npz, 264 | source_image_path, target_image_path, 265 | poisson_blend_required = False, 266 | require_full_mask = False): 267 | 268 | source_image = resize_frame(io.imread(source_image_path)) 269 | target_image = resize_frame(io.imread(target_image_path)) 270 | 271 | source_landmarks = np.load(source_landmark_npz)['landmark'] 272 | source_rotation, source_center, source_distance = compute_rotation(source_landmarks) 273 | 274 | target_landmarks = np.load(target_landmark_npz)['landmark'] 275 | target_rotation, target_center, target_distance = compute_rotation(target_landmarks) 276 | 277 | # rotation of the source conditioned on the source orientation 278 | target_conditioned_source_rotation = source_rotation - target_rotation 279 | 280 | # calculate the scaling that needs to be applied on the source image 281 | scaling = target_distance / source_distance 282 | 283 | # apply the rotation on the source image 284 | height, width = 256, 256 285 | # print(f'Angle of rotation is : {target_conditioned_source_rotation}') 286 | rotate_matrix = cv2.getRotationMatrix2D(center=source_center, angle=target_conditioned_source_rotation, scale=scaling) 287 | 288 | # calculate the translation component of the matrix M 289 | rotate_matrix[0, 2] += (target_center[0] - source_center[0]) 290 | rotate_matrix[1, 2] += (target_center[1] - source_center[1]) 291 | 292 | if require_full_mask: 293 | source_convex_mask = generate_convex_hull(source_image, source_landmarks) 294 | else: 295 | if requires_bb: 296 | source_convex_mask = generate_convex_hull_bb(source_image, source_landmarks[start_idx:end_idx]) 297 | else: 298 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[start_idx:end_idx]) 299 | 300 | # enlarge the convex mask using the enlargement 301 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=5) 302 | 303 | # apply the convex mask to the face 304 | source_face_segmented = apply_mask(source_convex_mask, source_image) 305 | source_face_transformed = cv2.warpAffine(source_face_segmented, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC) 306 | source_convex_mask_transformed = cv2.warpAffine(source_convex_mask, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC) 307 | source_image_transformed = cv2.warpAffine(source_image, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC) 308 | 309 | # used for computing the target background 310 | if requires_bb: 311 | target_convex_mask = np.invert(generate_convex_hull_bb(target_image, target_landmarks)) 312 | target_convex_mask_without_jaw = np.invert(generate_convex_hull_bb(target_image, target_landmarks[start_idx:end_idx])) 313 | else: 314 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks)) 315 | target_convex_mask_without_jaw = 
np.invert(generate_convex_hull_bb(target_image, target_landmarks[start_idx:end_idx])) 316 | 317 | # target_background = ((target_convex_mask/255.)*target_image).astype(np.uint8) 318 | # target_convex_mask_without_jaw = np.invert(generate_convex_hull(target_image, target_landmarks[start_idx:end_idx])) 319 | 320 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image) 321 | target_without_face = apply_mask(target_convex_mask, target_image) 322 | 323 | if poisson_blend_required: 324 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.) 325 | else: 326 | if target_without_face_apply: 327 | combined_image = combine_images(target_without_face, source_face_transformed) 328 | else: 329 | combined_image = combine_images(target_image, source_face_transformed) 330 | 331 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, target_image, target_convex_mask, combined_image 332 | -------------------------------------------------------------------------------- /disc_trainers/train_vqvae_mocogan_disc.py: -------------------------------------------------------------------------------- 1 | # This code uses the mocogan content and temporal discriminators 2 | 3 | import argparse 4 | import sys 5 | import os 6 | import random 7 | import os.path as osp 8 | 9 | import torch 10 | from torch import nn, optim 11 | import torch.nn.functional as F 12 | 13 | from torchvision import datasets, transforms, utils 14 | 15 | from TemporalAlignment.models import mocogan_discriminator 16 | 17 | from tqdm import tqdm 18 | 19 | from scheduler import CycleScheduler 20 | 21 | from utils import * 22 | from config import DATASET, LATENT_LOSS_WEIGHT, image_disc_weight, video_disc_weight, SAMPLE_SIZE_FOR_VISUALIZATION 23 | 24 | criterion = nn.MSELoss() 25 | 26 | gan_criterion = nn.BCEWithLogitsLoss() 27 | 28 | global_step = 0 29 | 30 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION 31 | 32 | dataset = DATASET 33 | 34 | CONST_FRAMES_TO_CHECK = 16 35 | 36 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results' 37 | # sample_folder = '/home2/bipasha31/python_scripts/CurrentWork/samples/{}' 38 | 39 | 40 | def run_step(model, data, device, run='train'): 41 | img, S, ground_truth, source_image = process_data(data, device, dataset) 42 | 43 | out, latent_loss = model(img) 44 | 45 | out = out[:, :3] 46 | 47 | recon_loss = criterion(out, ground_truth) 48 | latent_loss = latent_loss.mean() 49 | 50 | if run == 'train': 51 | return recon_loss, latent_loss, S, out, ground_truth 52 | else: 53 | return ground_truth, img, out, source_image 54 | 55 | def run_step_custom(model, data, device, run='train'): 56 | img, S, ground_truth = process_data(data, device, dataset) 57 | 58 | out, latent_loss = model(img) 59 | 60 | out = out[:, :3] # first 3 channels of the prediction 61 | 62 | return out, ground_truth 63 | 64 | 65 | def jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder): 66 | for i, data in enumerate(tqdm(val_loader)): 67 | with torch.no_grad(): 68 | source_images, input, prediction, source_images_original = run_step(model, data, device, run='val') 69 | 70 | source_hulls = input[:, :3] 71 | background = input[:, 3:] 72 | 73 | saves = { 74 | 'source': source_hulls, 75 | 'background': background, 76 | 'prediction': prediction, 77 | 'source_images': source_images, 78 | 'source_original': source_images_original 79 | } 80 | 81 | # if i % (len(val_loader) // 10) == 0 or run_type != 'train': 82 | if True: 83 | 
def denormalize(x): 84 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 85 | 86 | for name in saves: 87 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4" 88 | frames = saves[name].detach().cpu() 89 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames] 90 | 91 | # os.makedirs(sample_folder, exist_ok=True) 92 | save_frames_as_video(frames, saveas, fps=25) 93 | 94 | def base_validation(model, val_loader, device, epoch, i, run_type, sample_folder): 95 | def get_proper_shape(x): 96 | shape = x.shape 97 | return x.view(shape[0], -1, 3, shape[2], shape[3]).view(-1, 3, shape[2], shape[3]) 98 | 99 | for val_i, data in enumerate(tqdm(val_loader)): 100 | with torch.no_grad(): 101 | sample, _, out, source_img = run_step(model, data, device, 'val') 102 | 103 | if val_i % (len(val_loader)//10) == 0: # save 10 results 104 | 105 | def denormalize(x): 106 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 107 | 108 | save_as_sample = f"{sample_folder}/{val_i}_sample.mp4" 109 | save_as_out = f"{sample_folder}/{val_i}_out.mp4" 110 | 111 | sample = sample.detach().cpu() 112 | sample = [denormalize(x).permute(1, 2, 0).numpy() for x in sample] 113 | 114 | out = out.detach().cpu() 115 | out = [denormalize(x).permute(1, 2, 0).numpy() for x in out] 116 | 117 | save_frames_as_video(sample, save_as_sample, fps=25) 118 | save_frames_as_video(out, save_as_out, fps=25) 119 | 120 | def validation(model, val_loader, device, epoch, i, sample_folder, run_type='train'): 121 | if dataset >= 6: 122 | jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder) 123 | else: 124 | base_validation(model, val_loader, device, epoch, i, run_type, sample_folder) 125 | 126 | 127 | def flip_video(x): 128 | num = random.randint(0, 1) 129 | if num == 0: 130 | return torch.flip(x, [2]) 131 | else: 132 | return x 133 | 134 | 135 | # inside the train discriminator - disc, real, fake is required 136 | def train_discriminator(opt, discriminator, real_batch, fake_batch): 137 | opt.zero_grad() 138 | 139 | real_disc_preds, _ = discriminator(real_batch) 140 | fake_disc_preds, _ = discriminator(fake_batch.detach()) 141 | 142 | ones = torch.ones_like(real_disc_preds) 143 | zeros = torch.zeros_like(fake_disc_preds) 144 | 145 | l_discriminator = (gan_criterion(real_disc_preds, ones) + gan_criterion(fake_disc_preds, zeros))/2 146 | 147 | l_discriminator.backward() 148 | opt.step() 149 | 150 | return l_discriminator 151 | 152 | def train_generator(opt, image_discriminator, video_discriminator, 153 | fake_batch, recon_loss, latent_loss): 154 | opt.zero_grad() 155 | 156 | # image disc predictions 157 | fake_image_disc_preds, _ = image_discriminator(fake_batch) 158 | all_ones = torch.ones_like(fake_image_disc_preds) 159 | fake_image_disc_loss = gan_criterion(fake_image_disc_preds, all_ones) 160 | 161 | # video disc predictions 162 | fake_video_batch = fake_batch.unsqueeze(0).permute(0, 2, 1, 3, 4) 163 | fake_video_disc_preds, _ = video_discriminator(fake_video_batch) 164 | all_ones = torch.ones_like(fake_video_disc_preds) 165 | fake_video_disc_loss = gan_criterion(fake_video_disc_preds, all_ones) 166 | 167 | gen_loss = recon_loss \ 168 | + LATENT_LOSS_WEIGHT * latent_loss \ 169 | + image_disc_weight * fake_image_disc_loss \ 170 | + video_disc_weight * fake_video_disc_loss 171 | gen_loss.backward(retain_graph=True) 172 | opt.step() 173 | 174 | return gen_loss 175 | 176 | # training is done in an alternating fashion 177 | # disc and gen alternate their training 178 | def train(model, patch_image_disc, 
patch_video_disc, 179 | gen_optim, image_disc_optim, video_disc_optim, 180 | loader, val_loader, scheduler, device, 181 | epoch, validate_at, checkpoint_dir, sample_folder): 182 | 183 | SAMPLE_FRAMES = 16 # sample 16 frames for the discriminator 184 | 185 | for i, data in enumerate(loader): 186 | 187 | global global_step 188 | 189 | global_step += 1 190 | 191 | model.train() 192 | patch_image_disc.train() 193 | patch_video_disc.train() 194 | 195 | model.zero_grad() 196 | patch_image_disc.zero_grad() 197 | patch_video_disc.zero_grad() 198 | 199 | # train video generator 200 | recon_loss, latent_loss, S, out, ground_truth = run_step(model, data, device) 201 | 202 | # print(f'Frames : {out.shape[0]}') 203 | 204 | # skip if the number of frames is less than SAMPLE_FRAMES 205 | if out.shape[0] < SAMPLE_FRAMES: 206 | print(f'Encountered {out.shape[0]} frames which is less than {SAMPLE_FRAMES}. Continuing ...') 207 | continue 208 | 209 | # sample SAMPLE_FRAMES frames from out 210 | fake_sampled = out[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256 211 | real_sampled = ground_truth[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256 212 | 213 | # print(f'fake sampled : {fake_sampled.shape}') 214 | gen_loss = train_generator(gen_optim, patch_image_disc, patch_video_disc, fake_sampled, recon_loss, latent_loss) 215 | 216 | # train image discriminator 217 | l_image_dis = train_discriminator(image_disc_optim, patch_image_disc, real_sampled, fake_sampled) 218 | 219 | # train video discriminator - adding the batch dimension 220 | l_video_dis = train_discriminator(video_disc_optim, patch_video_disc, 221 | real_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4), 222 | fake_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4)) 223 | 224 | lr = gen_optim.param_groups[0]["lr"] 225 | 226 | # indicates that both the generator and discriminator steps would have been performed 227 | print(f'Epoch : {epoch+1}, step : {global_step}, gen loss : {gen_loss.item():.5f}, image disc loss : {l_image_dis.item():.5f}, video disc loss : {l_video_dis.item():.5f}, lr : {lr:.5f}') 228 | 229 | # check if validation is required 230 | if i%validate_at == 0: 231 | # set the model to eval and generate the predictions 232 | model.eval() 233 | 234 | validation(model, val_loader, device, epoch, i, sample_folder) 235 | 236 | os.makedirs(checkpoint_dir, exist_ok=True) 237 | 238 | # save the vqvae2 generator weights 239 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(global_step).zfill(4)}.pt") 240 | 241 | # save the discriminator weights 242 | # save the video disc weights 243 | torch.save(patch_video_disc.state_dict(), f"{checkpoint_dir}/patch_video_disc_{epoch+1}_{str(global_step).zfill(4)}.pt") 244 | 245 | # save the image/content disc weights 246 | torch.save(patch_image_disc.state_dict(), f"{checkpoint_dir}/patch_image_disc_{epoch+1}_{str(global_step).zfill(4)}.pt") 247 | 248 | # reverse the training state of the model 249 | model.train() 250 | 251 | 252 | def main(args): 253 | device = "cuda" 254 | 255 | default_transform = transforms.Compose( 256 | [ 257 | transforms.ToPILImage(), 258 | transforms.Resize(args.size), 259 | transforms.ToTensor(), 260 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 261 | ] 262 | ) 263 | 264 | train_loader, val_loader, model, patch_image_disc, patch_video_disc, \ 265 | image_disc_optim, video_disc_optim = \ 266 | get_loaders_and_models(args, dataset, default_transform, device, test=args.test) 267 | 268 | 269 | # load the models on the gpu 270 | model = model.to(device) 271 
| patch_image_disc = patch_image_disc.to(device) 272 | patch_video_disc = patch_video_disc.to(device) 273 | 274 | # loading the pretrained generator (vqvae2) model weights 275 | if args.ckpt: 276 | print(f'Loading pretrained generator model : {args.ckpt}') 277 | state_dict = torch.load(args.ckpt) 278 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() } 279 | try: 280 | model.module.load_state_dict(state_dict) 281 | except: 282 | model.load_state_dict(state_dict) 283 | 284 | # load the image and video disc models if required 285 | if args.load_disc: 286 | # image_disc_path = 'patch_image_disc_' + args.ckpt.split('_', 1)[1] 287 | # video_disc_path = 'patch_video_disc_' + args.ckpt.split('_', 1)[1] 288 | 289 | image_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_image_disc_' + osp.basename(args.ckpt).split('_', 1)[1]) 290 | video_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_video_disc_' + osp.basename(args.ckpt).split('_', 1)[1]) 291 | 292 | print(f'Loading pretrained disc models : {image_disc_path}, {video_disc_path}') 293 | 294 | image_disc_state_dict = torch.load(image_disc_path) 295 | image_disc_state_dict = { k.replace('module.', ''): v for k, v in image_disc_state_dict.items() } 296 | 297 | video_disc_state_dict = torch.load(video_disc_path) 298 | video_disc_state_dict = { k.replace('module.', ''): v for k, v in video_disc_state_dict.items() } 299 | 300 | try: 301 | patch_image_disc.module.load_state_dict(image_disc_state_dict) 302 | except: 303 | patch_image_disc.load_state_dict(image_disc_state_dict) 304 | 305 | try: 306 | patch_video_disc.module.load_state_dict(video_disc_state_dict) 307 | except: 308 | patch_video_disc.load_state_dict(video_disc_state_dict) 309 | 310 | if args.test: 311 | validation(model, val_loader, device, 0, 0, args.sample_folder, 'val') 312 | else: 313 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 314 | 315 | scheduler = None 316 | 317 | if args.sched == "cycle": 318 | scheduler = CycleScheduler( 319 | optimizer, 320 | args.lr, 321 | n_iter=len(train_loader) * args.epoch, 322 | momentum=None, 323 | warmup_proportion=0.05, 324 | ) 325 | 326 | for i in range(args.epoch): 327 | 328 | train(model, patch_image_disc, patch_video_disc, 329 | optimizer, image_disc_optim, video_disc_optim, 330 | train_loader, val_loader, scheduler, device, i, 331 | args.validate_at, args.checkpoint_dir, args.sample_folder) 332 | 333 | 334 | def get_random_name(cipher_length=5): 335 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 336 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)]) 337 | 338 | if __name__ == "__main__": 339 | parser = argparse.ArgumentParser() 340 | parser.add_argument("--n_gpu", type=int, default=1) 341 | 342 | port = random.randint(51000, 52000) 343 | 344 | # parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}") 345 | parser.add_argument("--batch_size", type=int, default=32) 346 | parser.add_argument("--size", type=int, default=256) 347 | parser.add_argument("--epoch", type=int, default=560) 348 | parser.add_argument("--lr", type=float, default=3e-4) 349 | parser.add_argument("--sched", type=str) 350 | parser.add_argument("--checkpoint_suffix", type=str, default='') 351 | parser.add_argument("--validate_at", type=int, default=1024) 352 | parser.add_argument("--ckpt", required=False) 353 | parser.add_argument("--test", action='store_true', required=False) 354 | parser.add_argument("--gray", action='store_true', required=False) 355 | parser.add_argument("--colorjit", 
type=str, default='', help='const or random or empty') 356 | parser.add_argument("--crossid", action='store_true', required=False) 357 | parser.add_argument("--sample_folder", type=str, default='samples') 358 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint') 359 | parser.add_argument("--validation_folder", type=str, default=None) 360 | parser.add_argument("--custom_validation", action='store_true', required=False) 361 | parser.add_argument("--load_disc", action='store_true', required=False) 362 | 363 | args = parser.parse_args() 364 | 365 | # args.n_gpu = torch.cuda.device_count() 366 | current_run = get_random_name() 367 | 368 | # sample_folder = sample_folder.format(args.checkpoint_suffix) 369 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run) 370 | os.makedirs(args.sample_folder, exist_ok=True) 371 | 372 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run) 373 | # os.makedirs(args.checkpoint_dir, exist_ok=True) 374 | 375 | # checkpoint_dir = checkpoint_dir.format(args.checkpoint_suffix) 376 | 377 | print(args, flush=True) 378 | 379 | print(f'Weight configuration used, latent loss : {LATENT_LOSS_WEIGHT}, image disc weight : {image_disc_weight}, video disc weight : {video_disc_weight}') 380 | 381 | # print(f'Weight configuration used : + \ 382 | # {LATENT_LOSS_WEIGHT}, 2d disc weight : {G_LOSS_2D_WEIGHT}, + \ 383 | # temporal disc weight : {G_LOSS_3D_WEIGHT}') 384 | 385 | # dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,)) 386 | main(args) -------------------------------------------------------------------------------- /disc_trainers/train_vqvae_mocogan_disc_perceptual.py: -------------------------------------------------------------------------------- 1 | # This code uses the mocogan content and temporal discriminators with perceptual loss 2 | 3 | import argparse 4 | import sys 5 | import os 6 | import random 7 | import os.path as osp 8 | 9 | import torch 10 | from torch import nn, optim 11 | import torch.nn.functional as F 12 | 13 | from torchvision import datasets, transforms, utils 14 | 15 | from TemporalAlignment.models import mocogan_discriminator 16 | 17 | from tqdm import tqdm 18 | 19 | from scheduler import CycleScheduler 20 | 21 | from utils import * 22 | from config import DATASET, LATENT_LOSS_WEIGHT, PERCEPTUAL_LOSS_WEIGHT, image_disc_weight, video_disc_weight, SAMPLE_SIZE_FOR_VISUALIZATION 23 | 24 | criterion = nn.MSELoss() 25 | 26 | gan_criterion = nn.BCEWithLogitsLoss() 27 | 28 | global_step = 0 29 | 30 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION 31 | 32 | dataset = DATASET 33 | 34 | CONST_FRAMES_TO_CHECK = 16 35 | 36 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results' 37 | # sample_folder = '/home2/bipasha31/python_scripts/CurrentWork/samples/{}' 38 | 39 | 40 | def run_step(model, vqlpips, data, device, run='train'): 41 | img, S, ground_truth, source_image = process_data(data, device, dataset) 42 | 43 | out, latent_loss = model(img) 44 | 45 | out = out[:, :3] 46 | 47 | recon_loss = criterion(out, ground_truth) 48 | latent_loss = latent_loss.mean() 49 | 50 | perceptual_loss = vqlpips(ground_truth, out) 51 | 52 | if run == 'train': 53 | return recon_loss, latent_loss, S, out, ground_truth, perceptual_loss 54 | else: 55 | return ground_truth, img, out, source_image 56 | 57 | def run_step_custom(model, data, device, run='train'): 58 | img, S, ground_truth = process_data(data, device, dataset) 59 | 60 | out, latent_loss = model(img) 61 | 62 | out = out[:, :3] # first 3 channels 
of the prediction 63 | 64 | return out, ground_truth 65 | 66 | 67 | def jitter_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder): 68 | for i, data in enumerate(tqdm(val_loader)): 69 | with torch.no_grad(): 70 | source_images, input, prediction, source_images_original = run_step(model, vqlpips, data, device, run='val') 71 | 72 | source_hulls = input[:, :3] 73 | background = input[:, 3:] 74 | 75 | saves = { 76 | 'source': source_hulls, 77 | 'background': background, 78 | 'prediction': prediction, 79 | 'source_images': source_images, 80 | 'source_original': source_images_original 81 | } 82 | 83 | # if i % (len(val_loader) // 10) == 0 or run_type != 'train': 84 | if True: 85 | def denormalize(x): 86 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 87 | 88 | for name in saves: 89 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4" 90 | frames = saves[name].detach().cpu() 91 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames] 92 | 93 | # os.makedirs(sample_folder, exist_ok=True) 94 | save_frames_as_video(frames, saveas, fps=25) 95 | 96 | def base_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder): 97 | def get_proper_shape(x): 98 | shape = x.shape 99 | return x.view(shape[0], -1, 3, shape[2], shape[3]).view(-1, 3, shape[2], shape[3]) 100 | 101 | for val_i, data in enumerate(tqdm(val_loader)): 102 | with torch.no_grad(): 103 | sample, _, out, source_img = run_step(model, vqlpips, data, device, 'val') 104 | 105 | if val_i % (len(val_loader)//10) == 0: # save 10 results 106 | 107 | def denormalize(x): 108 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 109 | 110 | save_as_sample = f"{sample_folder}/{val_i}_sample.mp4" 111 | save_as_out = f"{sample_folder}/{val_i}_out.mp4" 112 | 113 | sample = sample.detach().cpu() 114 | sample = [denormalize(x).permute(1, 2, 0).numpy() for x in sample] 115 | 116 | out = out.detach().cpu() 117 | out = [denormalize(x).permute(1, 2, 0).numpy() for x in out] 118 | 119 | save_frames_as_video(sample, save_as_sample, fps=25) 120 | save_frames_as_video(out, save_as_out, fps=25) 121 | 122 | def validation(model, vqlpips, val_loader, device, epoch, i, sample_folder, run_type='train'): 123 | if dataset >= 6: 124 | jitter_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder) 125 | else: 126 | base_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder) 127 | 128 | 129 | def flip_video(x): 130 | num = random.randint(0, 1) 131 | if num == 0: 132 | return torch.flip(x, [2]) 133 | else: 134 | return x 135 | 136 | 137 | # inside the train discriminator - disc, real, fake is required 138 | def train_discriminator(opt, discriminator, real_batch, fake_batch): 139 | opt.zero_grad() 140 | 141 | real_disc_preds, _ = discriminator(real_batch) 142 | fake_disc_preds, _ = discriminator(fake_batch.detach()) 143 | 144 | ones = torch.ones_like(real_disc_preds) 145 | zeros = torch.zeros_like(fake_disc_preds) 146 | 147 | l_discriminator = (gan_criterion(real_disc_preds, ones) + gan_criterion(fake_disc_preds, zeros))/2 148 | 149 | l_discriminator.backward() 150 | opt.step() 151 | 152 | return l_discriminator 153 | 154 | def train_generator(opt, image_discriminator, video_discriminator, 155 | fake_batch, recon_loss, latent_loss, perceptual_loss): 156 | opt.zero_grad() 157 | 158 | # image disc predictions 159 | fake_image_disc_preds, _ = image_discriminator(fake_batch) 160 | all_ones = torch.ones_like(fake_image_disc_preds) 161 | fake_image_disc_loss = 
gan_criterion(fake_image_disc_preds, all_ones) 162 | 163 | # video disc predictions 164 | fake_video_batch = fake_batch.unsqueeze(0).permute(0, 2, 1, 3, 4) 165 | fake_video_disc_preds, _ = video_discriminator(fake_video_batch) 166 | all_ones = torch.ones_like(fake_video_disc_preds) 167 | fake_video_disc_loss = gan_criterion(fake_video_disc_preds, all_ones) 168 | 169 | gen_loss = recon_loss \ 170 | + LATENT_LOSS_WEIGHT * latent_loss \ 171 | + image_disc_weight * fake_image_disc_loss \ 172 | + video_disc_weight * fake_video_disc_loss \ 173 | + PERCEPTUAL_LOSS_WEIGHT * perceptual_loss 174 | 175 | gen_loss.backward(retain_graph=True) 176 | opt.step() 177 | 178 | return gen_loss 179 | 180 | # training is done in an alternating fashion 181 | # disc and gen alternate their training 182 | def train(model, patch_image_disc, patch_video_disc, 183 | gen_optim, image_disc_optim, video_disc_optim, 184 | loader, val_loader, scheduler, device, 185 | epoch, validate_at, checkpoint_dir, sample_folder, vqlpips): 186 | 187 | SAMPLE_FRAMES = 16 # sample 16 frames for the discriminator 188 | 189 | for i, data in enumerate(loader): 190 | 191 | global global_step 192 | 193 | global_step += 1 194 | 195 | model.train() 196 | patch_image_disc.train() 197 | patch_video_disc.train() 198 | 199 | model.zero_grad() 200 | patch_image_disc.zero_grad() 201 | patch_video_disc.zero_grad() 202 | 203 | # train video generator 204 | recon_loss, latent_loss, S, out, ground_truth, perceptual_loss = run_step(model, vqlpips, data, device) 205 | 206 | # print(f'Frames : {out.shape[0]}') 207 | 208 | # skip if the number of frames is less than SAMPLE_FRAMES 209 | if out.shape[0] < SAMPLE_FRAMES: 210 | print(f'Encountered {out.shape[0]} frames which is less than {SAMPLE_FRAMES}. Continuing ...') 211 | continue 212 | 213 | # sample SAMPLE_FRAMES frames from out 214 | fake_sampled = out[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256 215 | real_sampled = ground_truth[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256 216 | 217 | # print(f'fake sampled : {fake_sampled.shape}') 218 | gen_loss = train_generator(gen_optim, patch_image_disc, patch_video_disc, fake_sampled, recon_loss, latent_loss, perceptual_loss) 219 | 220 | # train image discriminator 221 | l_image_dis = train_discriminator(image_disc_optim, patch_image_disc, real_sampled, fake_sampled) 222 | 223 | # train video discriminator - adding the batch dimension 224 | l_video_dis = train_discriminator(video_disc_optim, patch_video_disc, 225 | real_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4), 226 | fake_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4)) 227 | 228 | lr = gen_optim.param_groups[0]["lr"] 229 | 230 | # indicates that both the generator and discriminator steps would have been performed 231 | print(f'Epoch : {epoch+1}, step : {global_step}, gen loss : {gen_loss.item():.5f}, image disc loss : {l_image_dis.item():.5f}, video disc loss : {l_video_dis.item():.5f}, lr : {lr:.5f}') 232 | 233 | # check if validation is required 234 | if i%validate_at == 0: 235 | # set the model to eval and generate the predictions 236 | model.eval() 237 | 238 | validation(model, vqlpips, val_loader, device, epoch, i, sample_folder) 239 | 240 | os.makedirs(checkpoint_dir, exist_ok=True) 241 | 242 | # save the vqvae2 generator weights 243 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(global_step).zfill(4)}.pt") 244 | 245 | # save the discriminator weights 246 | # save the video disc weights 247 | torch.save(patch_video_disc.state_dict(), 
f"{checkpoint_dir}/patch_video_disc_{epoch+1}_{str(global_step).zfill(4)}.pt") 248 | 249 | # save the image/content disc weights 250 | torch.save(patch_image_disc.state_dict(), f"{checkpoint_dir}/patch_image_disc_{epoch+1}_{str(global_step).zfill(4)}.pt") 251 | 252 | # reverse the training state of the model 253 | model.train() 254 | 255 | 256 | def main(args): 257 | device = "cuda" 258 | 259 | default_transform = transforms.Compose( 260 | [ 261 | transforms.ToPILImage(), 262 | transforms.Resize(args.size), 263 | transforms.ToTensor(), 264 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 265 | ] 266 | ) 267 | 268 | train_loader, val_loader, model, patch_image_disc, patch_video_disc, \ 269 | image_disc_optim, video_disc_optim, vqlpips = \ 270 | get_loaders_and_models(args, dataset, default_transform, device, test=args.test) 271 | 272 | 273 | # load the models on the gpu 274 | model = model.to(device) 275 | patch_image_disc = patch_image_disc.to(device) 276 | patch_video_disc = patch_video_disc.to(device) 277 | vqlpips = vqlpips.to(device) 278 | 279 | # loading the pretrained generator (vqvae2) model weights 280 | if args.ckpt: 281 | print(f'Loading pretrained generator model : {args.ckpt}') 282 | state_dict = torch.load(args.ckpt) 283 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() } 284 | try: 285 | model.module.load_state_dict(state_dict) 286 | except: 287 | model.load_state_dict(state_dict) 288 | 289 | # load the image and video disc models if required 290 | if args.load_disc: 291 | # image_disc_path = 'patch_image_disc_' + args.ckpt.split('_', 1)[1] 292 | # video_disc_path = 'patch_video_disc_' + args.ckpt.split('_', 1)[1] 293 | 294 | image_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_image_disc_' + osp.basename(args.ckpt).split('_', 1)[1]) 295 | video_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_video_disc_' + osp.basename(args.ckpt).split('_', 1)[1]) 296 | 297 | print(f'Loading pretrained disc models : {image_disc_path}, {video_disc_path}') 298 | 299 | image_disc_state_dict = torch.load(image_disc_path) 300 | image_disc_state_dict = { k.replace('module.', ''): v for k, v in image_disc_state_dict.items() } 301 | 302 | video_disc_state_dict = torch.load(video_disc_path) 303 | video_disc_state_dict = { k.replace('module.', ''): v for k, v in video_disc_state_dict.items() } 304 | 305 | try: 306 | patch_image_disc.module.load_state_dict(image_disc_state_dict) 307 | except: 308 | patch_image_disc.load_state_dict(image_disc_state_dict) 309 | 310 | try: 311 | patch_video_disc.module.load_state_dict(video_disc_state_dict) 312 | except: 313 | patch_video_disc.load_state_dict(video_disc_state_dict) 314 | 315 | if args.test: 316 | validation(model, vqlpips, val_loader, device, 0, 0, args.sample_folder, 'val') 317 | else: 318 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 319 | 320 | scheduler = None 321 | 322 | if args.sched == "cycle": 323 | scheduler = CycleScheduler( 324 | optimizer, 325 | args.lr, 326 | n_iter=len(train_loader) * args.epoch, 327 | momentum=None, 328 | warmup_proportion=0.05, 329 | ) 330 | 331 | for i in range(args.epoch): 332 | 333 | train(model, patch_image_disc, patch_video_disc, 334 | optimizer, image_disc_optim, video_disc_optim, 335 | train_loader, val_loader, scheduler, device, i, 336 | args.validate_at, args.checkpoint_dir, args.sample_folder, vqlpips) 337 | 338 | 339 | def get_random_name(cipher_length=5): 340 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 341 | return ''.join([chars[random.randint(0, 
len(chars)-1)] for i in range(cipher_length)]) 342 | 343 | if __name__ == "__main__": 344 | parser = argparse.ArgumentParser() 345 | parser.add_argument("--n_gpu", type=int, default=1) 346 | 347 | port = random.randint(51000, 52000) 348 | 349 | # parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}") 350 | parser.add_argument("--batch_size", type=int, default=32) 351 | parser.add_argument("--size", type=int, default=256) 352 | parser.add_argument("--epoch", type=int, default=560) 353 | parser.add_argument("--lr", type=float, default=3e-4) 354 | parser.add_argument("--sched", type=str) 355 | parser.add_argument("--checkpoint_suffix", type=str, default='') 356 | parser.add_argument("--validate_at", type=int, default=1024) 357 | parser.add_argument("--ckpt", required=False) 358 | parser.add_argument("--test", action='store_true', required=False) 359 | parser.add_argument("--gray", action='store_true', required=False) 360 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty') 361 | parser.add_argument("--crossid", action='store_true', required=False) 362 | parser.add_argument("--sample_folder", type=str, default='samples') 363 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint') 364 | parser.add_argument("--validation_folder", type=str, default=None) 365 | parser.add_argument("--custom_validation", action='store_true', required=False) 366 | parser.add_argument("--load_disc", action='store_true', required=False) 367 | 368 | args = parser.parse_args() 369 | 370 | # args.n_gpu = torch.cuda.device_count() 371 | current_run = get_random_name() 372 | 373 | # sample_folder = sample_folder.format(args.checkpoint_suffix) 374 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run) 375 | os.makedirs(args.sample_folder, exist_ok=True) 376 | 377 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run) 378 | # os.makedirs(args.checkpoint_dir, exist_ok=True) 379 | 380 | # checkpoint_dir = checkpoint_dir.format(args.checkpoint_suffix) 381 | 382 | print(args, flush=True) 383 | 384 | print(f'Weight configuration used, latent loss : {LATENT_LOSS_WEIGHT}, \ 385 | image disc weight : {image_disc_weight}, video disc weight : {video_disc_weight} \ 386 | perceptual loss weight : {PERCEPTUAL_LOSS_WEIGHT}') 387 | 388 | main(args) -------------------------------------------------------------------------------- /distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import ( 2 | get_rank, 3 | get_local_rank, 4 | is_primary, 5 | synchronize, 6 | get_world_size, 7 | all_reduce, 8 | all_gather, 9 | reduce_dict, 10 | data_sampler, 11 | LOCAL_PROCESS_GROUP, 12 | ) 13 | from .launch import launch 14 | -------------------------------------------------------------------------------- /distributed/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils import data 7 | 8 | 9 | LOCAL_PROCESS_GROUP = None 10 | 11 | 12 | def is_primary(): 13 | return get_rank() == 0 14 | 15 | 16 | def get_rank(): 17 | if not dist.is_available(): 18 | return 0 19 | 20 | if not dist.is_initialized(): 21 | return 0 22 | 23 | return dist.get_rank() 24 | 25 | 26 | def get_local_rank(): 27 | if not dist.is_available(): 28 | return 0 29 | 30 | if not dist.is_initialized(): 31 | return 0 32 | 33 | if 
LOCAL_PROCESS_GROUP is None: 34 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") 35 | 36 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 37 | 38 | 39 | def synchronize(): 40 | if not dist.is_available(): 41 | return 42 | 43 | if not dist.is_initialized(): 44 | return 45 | 46 | world_size = dist.get_world_size() 47 | 48 | if world_size == 1: 49 | return 50 | 51 | dist.barrier() 52 | 53 | 54 | def get_world_size(): 55 | if not dist.is_available(): 56 | return 1 57 | 58 | if not dist.is_initialized(): 59 | return 1 60 | 61 | return dist.get_world_size() 62 | 63 | 64 | def all_reduce(tensor, op=dist.ReduceOp.SUM): 65 | world_size = get_world_size() 66 | 67 | if world_size == 1: 68 | return tensor 69 | 70 | dist.all_reduce(tensor, op=op) 71 | 72 | return tensor 73 | 74 | 75 | def all_gather(data): 76 | world_size = get_world_size() 77 | 78 | if world_size == 1: 79 | return [data] 80 | 81 | buffer = pickle.dumps(data) 82 | storage = torch.ByteStorage.from_buffer(buffer) 83 | tensor = torch.ByteTensor(storage).to("cuda") 84 | 85 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 86 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] 87 | dist.all_gather(size_list, local_size) 88 | size_list = [int(size.item()) for size in size_list] 89 | max_size = max(size_list) 90 | 91 | tensor_list = [] 92 | for _ in size_list: 93 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 94 | 95 | if local_size != max_size: 96 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 97 | tensor = torch.cat((tensor, padding), 0) 98 | 99 | dist.all_gather(tensor_list, tensor) 100 | 101 | data_list = [] 102 | 103 | for size, tensor in zip(size_list, tensor_list): 104 | buffer = tensor.cpu().numpy().tobytes()[:size] 105 | data_list.append(pickle.loads(buffer)) 106 | 107 | return data_list 108 | 109 | 110 | def reduce_dict(input_dict, average=True): 111 | world_size = get_world_size() 112 | 113 | if world_size < 2: 114 | return input_dict 115 | 116 | with torch.no_grad(): 117 | keys = [] 118 | values = [] 119 | 120 | for k in sorted(input_dict.keys()): 121 | keys.append(k) 122 | values.append(input_dict[k]) 123 | 124 | values = torch.stack(values, 0) 125 | dist.reduce(values, dst=0) 126 | 127 | if dist.get_rank() == 0 and average: 128 | values /= world_size 129 | 130 | reduced_dict = {k: v for k, v in zip(keys, values)} 131 | 132 | return reduced_dict 133 | 134 | 135 | def data_sampler(dataset, shuffle, distributed): 136 | if distributed: 137 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 138 | 139 | if shuffle: 140 | return data.RandomSampler(dataset) 141 | 142 | else: 143 | return data.SequentialSampler(dataset) 144 | -------------------------------------------------------------------------------- /distributed/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import distributed as dist 5 | from torch import multiprocessing as mp 6 | 7 | import distributed as dist_fn 8 | 9 | 10 | def find_free_port(): 11 | import socket 12 | 13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | 15 | sock.bind(("", 0)) 16 | port = sock.getsockname()[1] 17 | sock.close() 18 | 19 | return port 20 | 21 | 22 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): 23 | world_size = n_machine * n_gpu_per_machine 24 | 25 | if world_size > 1: 26 | if "OMP_NUM_THREADS" not in os.environ: 27 | 
os.environ["OMP_NUM_THREADS"] = "1" 28 | 29 | if dist_url == "auto": 30 | if n_machine != 1: 31 | raise ValueError('dist_url="auto" not supported in multi-machine jobs') 32 | 33 | port = find_free_port() 34 | dist_url = f"tcp://127.0.0.1:{port}" 35 | 36 | if n_machine > 1 and dist_url.startswith("file://"): 37 | raise ValueError( 38 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" 39 | ) 40 | 41 | mp.spawn( 42 | distributed_worker, 43 | nprocs=n_gpu_per_machine, 44 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), 45 | daemon=False, 46 | ) 47 | 48 | else: 49 | fn(*args) 50 | 51 | 52 | def distributed_worker( 53 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args 54 | ): 55 | if not torch.cuda.is_available(): 56 | raise OSError("CUDA is not available. Please check your environments") 57 | 58 | global_rank = machine_rank * n_gpu_per_machine + local_rank 59 | 60 | try: 61 | dist.init_process_group( 62 | backend="NCCL", 63 | init_method=dist_url, 64 | world_size=world_size, 65 | rank=global_rank, 66 | ) 67 | 68 | except Exception: 69 | raise OSError("failed to initialize NCCL groups") 70 | 71 | dist_fn.synchronize() 72 | 73 | if n_gpu_per_machine > torch.cuda.device_count(): 74 | raise ValueError( 75 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" 76 | ) 77 | 78 | torch.cuda.set_device(local_rank) 79 | 80 | if dist_fn.LOCAL_PROCESS_GROUP is not None: 81 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") 82 | 83 | n_machine = world_size // n_gpu_per_machine 84 | 85 | for i in range(n_machine): 86 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) 87 | pg = dist.new_group(ranks_on_i) 88 | 89 | if i == machine_rank: 90 | dist_fn.distributed.LOCAL_PROCESS_GROUP = pg 91 | 92 | fn(*args) 93 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: faceoff 2 | channels: 3 | - anaconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - backcall=0.2.0=pyhd3eb1b0_0 10 | - ca-certificates=2022.07.19=h06a4308_0 11 | - debugpy=1.5.1=py37h295c915_0 12 | - decorator=5.1.1=pyhd3eb1b0_0 13 | - entrypoints=0.4=py37h06a4308_0 14 | - ipykernel=6.9.1=py37h06a4308_0 15 | - ipython=7.31.1=py37h06a4308_1 16 | - jedi=0.18.1=py37h06a4308_1 17 | - jupyter_client=7.2.2=py37h06a4308_0 18 | - jupyter_core=4.10.0=py37h06a4308_0 19 | - ld_impl_linux-64=2.39=hcc3a1bd_1 20 | - libffi=3.4.2=h7f98852_5 21 | - libgcc-ng=12.2.0=h65d4601_19 22 | - libgomp=12.2.0=h65d4601_19 23 | - libnsl=2.0.0=h7f98852_0 24 | - libsodium=1.0.18=h7b6447c_0 25 | - libsqlite=3.40.0=h753d276_0 26 | - libstdcxx-ng=12.2.0=h46fd767_19 27 | - libzlib=1.2.13=h166bdaf_4 28 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 29 | - ncurses=6.3=h27087fc_1 30 | - nest-asyncio=1.5.5=py37h06a4308_0 31 | - openssl=3.0.7=h0b41bf4_1 32 | - parso=0.8.3=pyhd3eb1b0_0 33 | - pexpect=4.8.0=pyhd3eb1b0_3 34 | - pickleshare=0.7.5=pyhd3eb1b0_1003 35 | - pip=22.3.1=pyhd8ed1ab_0 36 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 37 | - ptyprocess=0.7.0=pyhd3eb1b0_2 38 | - pygments=2.11.2=pyhd3eb1b0_0 39 | - python=3.7.12=hf930737_100_cpython 40 | - python-dateutil=2.8.2=pyhd3eb1b0_0 41 | - pyzmq=23.2.0=py37h6a678d5_0 42 | - readline=8.1.2=h0f457ee_0 43 | - setuptools=65.6.3=pyhd8ed1ab_0 44 | - six=1.16.0=pyhd3eb1b0_1 45 | - 
sqlite=3.40.0=h4ff8645_0 46 | - tk=8.6.12=h27826a3_0 47 | - tornado=6.1=py37h27cfd23_0 48 | - traitlets=5.1.1=pyhd3eb1b0_0 49 | - wcwidth=0.2.5=pyhd3eb1b0_0 50 | - wheel=0.38.4=pyhd8ed1ab_0 51 | - xz=5.2.6=h166bdaf_0 52 | - zeromq=4.3.4=h2531618_0 53 | - pip: 54 | - bchlib==0.14.0 55 | - certifi==2022.12.7 56 | - charset-normalizer==2.1.1 57 | - cycler==0.11.0 58 | - fonttools==4.38.0 59 | - idna==3.4 60 | - imageio==2.24.0 61 | - joblib==1.2.0 62 | - kiwisolver==1.4.4 63 | - matplotlib==3.5.3 64 | - networkx==2.6.3 65 | - numpy==1.21.6 66 | - nvidia-cublas-cu11==11.10.3.66 67 | - nvidia-cuda-nvrtc-cu11==11.7.99 68 | - nvidia-cuda-runtime-cu11==11.7.99 69 | - nvidia-cudnn-cu11==8.5.0.96 70 | - opencv-python==4.6.0.66 71 | - packaging==22.0 72 | - pillow==9.3.0 73 | - pyparsing==3.0.9 74 | - pywavelets==1.3.0 75 | - requests==2.28.1 76 | - scikit-image==0.19.3 77 | - scikit-learn==1.0.2 78 | - scipy==1.7.3 79 | - sklearn==0.0.post1 80 | - stn==1.0.1 81 | - threadpoolctl==3.1.0 82 | - tifffile==2021.11.2 83 | - torch==1.13.1 84 | - torchvision==0.14.1 85 | - tqdm==4.64.1 86 | - typing-extensions==4.4.0 87 | - urllib3==1.26.13 88 | - wand==0.6.11 89 | prefix: /home2/aditya1/miniconda3/envs/stegastamp 90 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.lpips import LPIPS 6 | from models.discriminator import NLayerDiscriminator, weights_init 7 | 8 | criterion = nn.L1Loss() 9 | 10 | def adopt_weight(weight, global_step, threshold=0, value=0.): 11 | if global_step < threshold: 12 | weight = value 13 | return weight 14 | 15 | def hinge_d_loss(logits_real, logits_fake): 16 | loss_real = torch.mean(F.relu(1. - logits_real)) 17 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 18 | d_loss = 0.5 * (loss_real + loss_fake) 19 | return d_loss 20 | 21 | def vanilla_d_loss(logits_real, logits_fake): 22 | d_loss = 0.5 * ( 23 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 24 | torch.mean(torch.nn.functional.softplus(logits_fake))) 25 | return d_loss 26 | 27 | class VQLPIPS(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.perceptual_loss = LPIPS().eval() 31 | 32 | def forward(self, targets, reconstructions): 33 | return self.perceptual_loss(targets.contiguous(), reconstructions.contiguous()).mean() 34 | 35 | class VQLPIPSWithDiscriminator(nn.Module): 36 | def __init__(self, disc_start, 37 | disc_num_layers=3, disc_in_channels=3, 38 | disc_factor=1.0, disc_weight=0.8, use_actnorm=False, 39 | disc_ndf=64, disc_loss="hinge"): 40 | super().__init__() 41 | assert disc_loss in ["hinge", "vanilla"] 42 | 43 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 44 | n_layers=disc_num_layers, 45 | use_actnorm=use_actnorm, 46 | ndf=disc_ndf 47 | ).apply(weights_init) 48 | 49 | self.discriminator_iter_start = disc_start 50 | 51 | self.perceptual_loss = LPIPS().eval() 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | 62 | self.perceptual_weight = 1.0 63 | 64 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 65 | if last_layer is not None: 66 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 67 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 68 | else: 69 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 70 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 71 | 72 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 73 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 74 | d_weight = d_weight * self.discriminator_weight 75 | return d_weight 76 | 77 | def forward(self, rec_loss, targets, reconstructions): 78 | p_loss = self.perceptual_loss(targets.contiguous(), reconstructions.contiguous()) 79 | rec_loss = rec_loss + self.perceptual_weight * p_loss 80 | 81 | nll_loss = torch.mean(rec_loss) 82 | 83 | return p_loss.mean(), nll_loss 84 | 85 | def second_forward(self, rec_loss, targets, reconstructions, optimizer_idx, 86 | global_step, perceptual_loss=False, last_layer=None): 87 | 88 | if perceptual_loss: 89 | p_loss = self.perceptual_loss(targets.contiguous(), reconstructions.contiguous()) 90 | rec_loss = rec_loss + self.perceptual_weight * p_loss 91 | else: 92 | p_loss = 0 93 | 94 | nll_loss = torch.mean(rec_loss) 95 | 96 | # now the GAN part 97 | if optimizer_idx == 0: 98 | # generator update 99 | logits_fake = self.discriminator(reconstructions.contiguous()) 100 | # g_loss = -torch.mean(logits_fake) 101 | g_loss = criterion(logits_fake, torch.ones_like(logits_fake, device=logits_fake.device)) 102 | 103 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 104 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 105 | return d_weight * disc_factor * g_loss, p_loss.mean(), nll_loss 106 | 107 | if optimizer_idx == 1: 108 | # second pass for discriminator update 109 | logits_real = 
self.discriminator(targets.contiguous().detach()) 110 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 111 | 112 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 113 | # d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 114 | loss_real = criterion(logits_real, torch.ones_like(logits_real, device=logits_real.device)) 115 | loss_fake = criterion(logits_fake, torch.zeros_like(logits_fake, device=logits_fake.device)) 116 | d_loss = disc_factor * (loss_real + loss_fake).mean() 117 | 118 | return d_loss, p_loss.mean(), nll_loss 119 | 120 | class ContrastiveLoss(torch.nn.Module): 121 | """ 122 | Contrastive loss function. 123 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 124 | """ 125 | 126 | def __init__(self, margin=2.0): 127 | super(ContrastiveLoss, self).__init__() 128 | self.margin = margin 129 | 130 | def forward(self, output1, output2, label): 131 | euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True) 132 | loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + 133 | (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) 134 | 135 | 136 | return loss_contrastive 137 | 138 | class SiameseNetworkFaceSimilarity(nn.Module): 139 | def __init__(self): 140 | super(SiameseNetworkFaceSimilarity, self).__init__() 141 | self.cnn1 = nn.Sequential( 142 | nn.ReflectionPad2d(1), 143 | nn.Conv2d(1, 4, kernel_size=3), 144 | nn.ReLU(inplace=True), 145 | nn.BatchNorm2d(4), 146 | 147 | nn.ReflectionPad2d(1), 148 | nn.Conv2d(4, 8, kernel_size=3), 149 | nn.ReLU(inplace=True), 150 | nn.BatchNorm2d(8), 151 | 152 | nn.ReflectionPad2d(1), 153 | nn.Conv2d(8, 8, kernel_size=3), 154 | nn.ReLU(inplace=True), 155 | nn.BatchNorm2d(8), 156 | ) 157 | 158 | self.fc1 = nn.Sequential( 159 | nn.Linear(8*256*256, 500), 160 | nn.ReLU(inplace=True), 161 | 162 | nn.Linear(500, 500), 163 | nn.ReLU(inplace=True), 164 | 165 | nn.Linear(500, 5)) 166 | 167 | def forward_once(self, x): 168 | x = x.unsqueeze(1) 169 | output = self.cnn1(x) 170 | output = output.view(output.size()[0], -1) 171 | output = self.fc1(output) 172 | return output 173 | 174 | def forward(self, input1, input2): 175 | output1 = self.forward_once(input1) 176 | output2 = self.forward_once(input2) 177 | return F.pairwise_distance(output1, output2).mean() 178 | -------------------------------------------------------------------------------- /models/actnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | 
.unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from models.actnorm import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input 
images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return nn.Sigmoid()(self.main(input)) 68 | -------------------------------------------------------------------------------- /models/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | import os, hashlib 9 | import requests 10 | from tqdm import tqdm 11 | 12 | URL_MAP = { 13 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 14 | } 15 | 16 | CKPT_MAP = { 17 | "vgg_lpips": "vgg.pth" 18 | } 19 | 20 | MD5_MAP = { 21 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 22 | } 23 | 24 | def download(url, local_path, chunk_size=1024): 25 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 26 | with requests.get(url, stream=True) as r: 27 | total_size = int(r.headers.get("content-length", 0)) 28 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 29 | with open(local_path, "wb") as f: 30 | for data in r.iter_content(chunk_size=chunk_size): 31 | if data: 32 | f.write(data) 33 | pbar.update(chunk_size) 34 | 35 | def md5_hash(path): 36 | with open(path, "rb") as f: 37 | content = f.read() 38 | return hashlib.md5(content).hexdigest() 39 | 40 | def get_ckpt_path(name, root, check=False): 41 | assert name in URL_MAP 42 | path = os.path.join(root, CKPT_MAP[name]) 43 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 44 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 45 | download(URL_MAP[name], path) 46 | md5 = md5_hash(path) 47 | assert md5 == MD5_MAP[name], md5 48 | return path 49 | 50 | class LPIPS(nn.Module): 51 | # Learned perceptual metric 52 | def __init__(self, use_dropout=True): 53 | super().__init__() 54 | self.scaling_layer = ScalingLayer() 55 | self.chns = [64, 128, 256, 512, 512] # vg16 
features 56 | self.net = vgg16(pretrained=True, requires_grad=False) 57 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 58 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 59 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 60 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 61 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 62 | self.load_from_pretrained() 63 | for param in self.parameters(): 64 | param.requires_grad = False 65 | 66 | def load_from_pretrained(self, name="vgg_lpips"): 67 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 68 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 69 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 70 | 71 | @classmethod 72 | def from_pretrained(cls, name="vgg_lpips"): 73 | if name != "vgg_lpips": 74 | raise NotImplementedError 75 | model = cls() 76 | ckpt = get_ckpt_path(name) 77 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 78 | return model 79 | 80 | def forward(self, input, target): 81 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 82 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 83 | feats0, feats1, diffs = {}, {}, {} 84 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 85 | for kk in range(len(self.chns)): 86 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 87 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 88 | 89 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 90 | val = res[0] 91 | for l in range(1, len(self.chns)): 92 | val += res[l] 93 | return val 94 | 95 | 96 | class ScalingLayer(nn.Module): 97 | def __init__(self): 98 | super(ScalingLayer, self).__init__() 99 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 100 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 101 | 102 | def forward(self, inp): 103 | return (inp - self.shift) / self.scale 104 | 105 | 106 | class NetLinLayer(nn.Module): 107 | """ A single linear layer which does a 1x1 conv """ 108 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 109 | super(NetLinLayer, self).__init__() 110 | layers = [nn.Dropout(), ] if (use_dropout) else [] 111 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 112 | self.model = nn.Sequential(*layers) 113 | 114 | 115 | class vgg16(torch.nn.Module): 116 | def __init__(self, requires_grad=False, pretrained=True): 117 | super(vgg16, self).__init__() 118 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 119 | self.slice1 = torch.nn.Sequential() 120 | self.slice2 = torch.nn.Sequential() 121 | self.slice3 = torch.nn.Sequential() 122 | self.slice4 = torch.nn.Sequential() 123 | self.slice5 = torch.nn.Sequential() 124 | self.N_slices = 5 125 | for x in range(4): 126 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 127 | for x in range(4, 9): 128 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 129 | for x in range(9, 16): 130 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 131 | for x in range(16, 23): 132 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 133 | for x in range(23, 30): 134 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 135 | if not requires_grad: 136 | for param in self.parameters(): 137 | param.requires_grad = 
False 138 | 139 | def forward(self, X): 140 | h = self.slice1(X) 141 | h_relu1_2 = h 142 | h = self.slice2(h) 143 | h_relu2_2 = h 144 | h = self.slice3(h) 145 | h_relu3_3 = h 146 | h = self.slice4(h) 147 | h_relu4_3 = h 148 | h = self.slice5(h) 149 | h_relu5_3 = h 150 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 151 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 152 | return out 153 | 154 | 155 | def normalize_tensor(x,eps=1e-10): 156 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 157 | return x/(norm_factor+eps) 158 | 159 | 160 | def spatial_average(x, keepdim=True): 161 | return x.mean([2,3],keepdim=keepdim) 162 | 163 | -------------------------------------------------------------------------------- /models/vqvae_conv3d_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | import distributed as dist_fn 8 | 9 | import warnings 10 | warnings.filterwarnings("ignore") 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | # Copyright 2018 The Sonnet Authors. All Rights Reserved. 15 | # 16 | # Licensed under the Apache License, Version 2.0 (the "License"); 17 | # you may not use this file except in compliance with the License. 18 | # You may obtain a copy of the License at 19 | # 20 | # http://www.apache.org/licenses/LICENSE-2.0 21 | # 22 | # Unless required by applicable law or agreed to in writing, software 23 | # distributed under the License is distributed on an "AS IS" BASIS, 24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 | # See the License for the specific language governing permissions and 26 | # limitations under the License. 
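# ---------------------------------------------------------------------------
# Hedged sketch (illustrative only, not part of the original module): the
# Quantize class defined below keeps its codebook up to date with an
# exponential-moving-average (EMA) of the encoder vectors assigned to each
# code, instead of updating the codebook by gradient descent.  The toy
# example below reproduces that update rule on random tensors; every name
# prefixed with an underscore is made up for this sketch.
import torch
import torch.nn.functional as F

_dim, _n_embed, _decay, _eps = 4, 8, 0.99, 1e-5
_embed = torch.randn(_dim, _n_embed)              # codebook: (dim, n_embed)
_cluster_size = torch.zeros(_n_embed)             # EMA of per-code counts
_embed_avg = _embed.clone()                       # EMA of per-code vector sums

_flatten = torch.randn(16, _dim)                  # a batch of encoder vectors
_dist = (_flatten.pow(2).sum(1, keepdim=True)
         - 2 * _flatten @ _embed
         + _embed.pow(2).sum(0, keepdim=True))    # squared distance to each code
_ind = (-_dist).max(1).indices                    # nearest code per vector
_onehot = F.one_hot(_ind, _n_embed).type(_flatten.dtype)

# EMA update of counts and summed vectors, then renormalise the codebook.
_cluster_size.mul_(_decay).add_(_onehot.sum(0), alpha=1 - _decay)
_embed_avg.mul_(_decay).add_(_flatten.t() @ _onehot, alpha=1 - _decay)
_n = _cluster_size.sum()
_smoothed = (_cluster_size + _eps) / (_n + _n_embed * _eps) * _n
_embed = _embed_avg / _smoothed.unsqueeze(0)      # updated codebook
# ---------------------------------------------------------------------------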
27 | # ============================================================================ 28 | 29 | 30 | # Borrowed from https://github.com/deepmind/sonnet and ported it to PyTorch 31 | 32 | 33 | class Quantize(nn.Module): 34 | def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): 35 | super().__init__() 36 | 37 | self.dim = dim 38 | self.n_embed = n_embed 39 | self.decay = decay 40 | self.eps = eps 41 | 42 | embed = torch.randn(dim, n_embed) 43 | self.register_buffer("embed", embed) 44 | self.register_buffer("cluster_size", torch.zeros(n_embed)) 45 | self.register_buffer("embed_avg", embed.clone()) 46 | 47 | def forward(self, input): 48 | flatten = input.reshape(-1, self.dim) 49 | dist = ( 50 | flatten.pow(2).sum(1, keepdim=True) 51 | - 2 * flatten @ self.embed 52 | + self.embed.pow(2).sum(0, keepdim=True) 53 | ) 54 | _, embed_ind = (-dist).max(1) 55 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) 56 | embed_ind = embed_ind.view(*input.shape[:-1]) 57 | quantize = self.embed_code(embed_ind) 58 | 59 | if self.training: 60 | embed_onehot_sum = embed_onehot.sum(0) 61 | embed_sum = flatten.transpose(0, 1) @ embed_onehot 62 | 63 | dist_fn.all_reduce(embed_onehot_sum) 64 | dist_fn.all_reduce(embed_sum) 65 | 66 | self.cluster_size.data.mul_(self.decay).add_( 67 | embed_onehot_sum, alpha=1 - self.decay 68 | ) 69 | self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) 70 | n = self.cluster_size.sum() 71 | cluster_size = ( 72 | (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n 73 | ) 74 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) 75 | self.embed.data.copy_(embed_normalized) 76 | 77 | diff = (quantize.detach() - input).pow(2).mean() 78 | quantize = input + (quantize - input).detach() 79 | 80 | return quantize, diff, embed_ind 81 | 82 | def embed_code(self, embed_id): 83 | return F.embedding(embed_id, self.embed.transpose(0, 1)) 84 | 85 | 86 | class ResBlock(nn.Module): 87 | def __init__(self, in_channel, channel): 88 | super().__init__() 89 | 90 | self.conv = nn.Sequential( 91 | nn.ReLU(), 92 | nn.Conv2d(in_channel, channel, 3, padding=1), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(channel, in_channel, 1), 95 | ) 96 | 97 | def forward(self, input): 98 | out = self.conv(input) 99 | out += input 100 | 101 | return out 102 | 103 | class Encoder(nn.Module): 104 | def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): 105 | super().__init__() 106 | 107 | if stride == 4: 108 | blocks = [ 109 | nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), 110 | nn.ReLU(inplace=True), 111 | nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), 112 | nn.ReLU(inplace=True), 113 | nn.Conv2d(channel, channel, 3, padding=1), 114 | ] 115 | 116 | elif stride == 2: 117 | blocks = [ 118 | nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), 119 | nn.ReLU(inplace=True), 120 | nn.Conv2d(channel // 2, channel, 3, padding=1), 121 | ] 122 | 123 | for i in range(n_res_block): 124 | blocks.append(ResBlock(channel, n_res_channel)) 125 | 126 | blocks.append(nn.ReLU(inplace=True)) 127 | 128 | self.blocks = nn.Sequential(*blocks) 129 | 130 | def forward(self, input): 131 | return self.blocks(input) 132 | 133 | 134 | class Decoder(nn.Module): 135 | def __init__( 136 | self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride 137 | ): 138 | super().__init__() 139 | 140 | blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] 141 | 142 | for i in range(n_res_block): 143 | blocks.append(ResBlock(channel, 
n_res_channel)) 144 | 145 | blocks.append(nn.ReLU(inplace=True)) 146 | 147 | if stride == 4: 148 | blocks.extend( 149 | [ 150 | nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), 151 | nn.ReLU(inplace=True), 152 | nn.ConvTranspose2d( 153 | channel // 2, out_channel, 4, stride=2, padding=1 154 | ), 155 | ] 156 | ) 157 | 158 | elif stride == 2: 159 | blocks.append( 160 | nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1) 161 | ) 162 | 163 | self.blocks = nn.Sequential(*blocks) 164 | 165 | def forward(self, input): 166 | return self.blocks(input) 167 | 168 | # applies a sequence of conv3d layers with relu activation for preserving temporal coherence 169 | class Conv3dLatentPostnet(nn.Module): 170 | def __init__(self, channels): 171 | super().__init__() 172 | self.conv3d = nn.Sequential( 173 | self.conv3d_layer(channels=channels), 174 | self.conv3d_layer(channels=channels), 175 | self.conv3d_layer(channels=channels, is_final=True) 176 | ) 177 | 178 | def conv3d_layer(self, channels=128, kernel_size=3, padding=1, is_final=False): 179 | if is_final: 180 | return nn.Sequential( 181 | nn.Conv3d(channels, channels, kernel_size, padding=padding) 182 | ) 183 | else: 184 | return nn.Sequential( 185 | nn.Conv3d(channels, channels, kernel_size, padding=padding), 186 | nn.ReLU() 187 | ) 188 | 189 | def forward(self, input): 190 | return self.conv3d(input) 191 | 192 | class VQVAE(nn.Module): 193 | def __init__( 194 | self, 195 | in_channel=3, 196 | channel=128, 197 | n_res_block=2, 198 | n_res_channel=32, 199 | embed_dim=64, 200 | n_embed=512, 201 | decay=0.99, 202 | residual=False, 203 | ): 204 | super().__init__() 205 | 206 | self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4) 207 | self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) 208 | self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1) 209 | self.quantize_t = Quantize(embed_dim, n_embed) 210 | self.dec_t = Decoder( 211 | embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2 212 | ) 213 | self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1) 214 | self.quantize_b = Quantize(embed_dim, n_embed) 215 | self.upsample_t = nn.ConvTranspose2d( 216 | embed_dim, embed_dim, 4, stride=2, padding=1 217 | ) 218 | self.dec = Decoder( 219 | embed_dim + embed_dim, 220 | in_channel, 221 | channel, 222 | n_res_block, 223 | n_res_channel, 224 | stride=4, 225 | ) 226 | 227 | self.residual = residual 228 | 229 | # bottom encoded dimension - batch_size x 128 x 64 x 64 230 | self.conv3d_encoded_b = Conv3dLatentPostnet(128) 231 | self.conv3d_encoded_t = Conv3dLatentPostnet(128) 232 | 233 | # self.conv3d_encoded_b = nn.Conv3d(128, 128, 3, padding=1) 234 | # self.conv3d_encoded_t = nn.Conv3d(128, 128, 3, padding=1) 235 | 236 | 237 | def only_encode(self, input): 238 | enc_b = self.enc_b(input) 239 | enc_t = self.enc_t(enc_b) 240 | 241 | return enc_b, enc_t 242 | 243 | def forward(self, input): 244 | enc_b, enc_t = self.only_encode(input) 245 | 246 | # enc_b dimension -> batch_size x 128 x 64 x 64 247 | enc_b, enc_t = enc_b.unsqueeze(0).permute(0, 2, 1, 3, 4), enc_t.unsqueeze(0).permute(0, 2, 1, 3, 4) 248 | 249 | # apply the 3d conv on the encoded representations 250 | enc_b_conv, enc_t_conv = self.conv3d_encoded_b(enc_b), self.conv3d_encoded_t(enc_t) 251 | enc_b_conv, enc_t_conv = enc_b_conv.squeeze(0).permute(1, 0, 2, 3), enc_t_conv.squeeze(0).permute(1, 0, 2, 3) 252 | 253 | # generate the quantized representation 254 | quant_t, quant_b, diff, _, _ = 
self.encode_quantized(enc_b_conv, enc_t_conv) 255 | 256 | # generate the decoded representation 257 | dec = self.decode(quant_t, quant_b) 258 | 259 | return dec, diff 260 | 261 | def encode_quantized(self, enc_b, enc_t): 262 | # enc_b = self.enc_b(input) 263 | # enc_t = self.enc_t(enc_b) 264 | 265 | quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) 266 | quant_t, diff_t, id_t = self.quantize_t(quant_t) 267 | quant_t = quant_t.permute(0, 3, 1, 2) 268 | diff_t = diff_t.unsqueeze(0) 269 | 270 | dec_t = self.dec_t(quant_t) 271 | enc_b = torch.cat([dec_t, enc_b], 1) 272 | 273 | quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1) 274 | quant_b, diff_b, id_b = self.quantize_b(quant_b) 275 | quant_b = quant_b.permute(0, 3, 1, 2) 276 | diff_b = diff_b.unsqueeze(0) 277 | 278 | return quant_t, quant_b, diff_t + diff_b, id_t, id_b 279 | 280 | def decode(self, quant_t, quant_b): 281 | upsample_t = self.upsample_t(quant_t) 282 | quant = torch.cat([upsample_t, quant_b], 1) 283 | dec = self.dec(quant) 284 | 285 | return dec 286 | 287 | def decode_code(self, code_t, code_b): 288 | quant_t = self.quantize_t.embed_code(code_t) 289 | quant_t = quant_t.permute(0, 3, 1, 2) 290 | quant_b = self.quantize_b.embed_code(code_b) 291 | quant_b = quant_b.permute(0, 3, 1, 2) 292 | 293 | dec = self.decode(quant_t, quant_b) 294 | 295 | return dec 296 | 297 | # ============================================================================ 298 | # ============================= VQVAE BLOB2FULL ============================== 299 | # ============================================================================ 300 | 301 | class Encode(nn.Module): 302 | def __init__( 303 | self, 304 | in_channel, 305 | channel, 306 | n_res_block, 307 | n_res_channel, 308 | embed_dim, 309 | n_embed, 310 | decay, 311 | ): 312 | super().__init__() 313 | 314 | self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4) 315 | self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) 316 | self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1) 317 | self.quantize_t = Quantize(embed_dim, n_embed) 318 | self.dec_t = Decoder( 319 | embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2 320 | ) 321 | self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1) 322 | self.quantize_b = Quantize(embed_dim, n_embed) 323 | 324 | def forward(self, input): 325 | enc_b = self.enc_b(input) 326 | enc_t = self.enc_t(enc_b) 327 | 328 | quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) 329 | quant_t, diff_t, id_t = self.quantize_t(quant_t) 330 | quant_t = quant_t.permute(0, 3, 1, 2) 331 | diff_t = diff_t.unsqueeze(0) 332 | 333 | dec_t = self.dec_t(quant_t) 334 | enc_b = torch.cat([dec_t, enc_b], 1) 335 | 336 | quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1) 337 | quant_b, diff_b, id_b = self.quantize_b(quant_b) 338 | quant_b = quant_b.permute(0, 3, 1, 2) 339 | diff_b = diff_b.unsqueeze(0) 340 | 341 | return quant_t, quant_b, diff_t + diff_b, id_t, id_b 342 | 343 | class VQVAE_B2F(nn.Module): 344 | def __init__( 345 | self, 346 | in_channel=3, 347 | channel=128, 348 | n_res_block=2, 349 | n_res_channel=32, 350 | embed_dim=64, 351 | n_embed=512, 352 | decay=0.99, 353 | ): 354 | super().__init__() 355 | 356 | self.encode_face = Encode(in_channel, 357 | channel, 358 | n_res_block, 359 | n_res_channel, 360 | embed_dim, 361 | n_embed, 362 | decay) 363 | 364 | self.encode_rhand = Encode(in_channel, 365 | channel, 366 | n_res_block, 367 | n_res_channel, 368 | embed_dim, 369 | n_embed, 
370 | decay) 371 | 372 | self.encode_lhand = Encode(in_channel, 373 | channel, 374 | n_res_block, 375 | n_res_channel, 376 | embed_dim, 377 | n_embed, 378 | decay) 379 | 380 | self.upsample_t = nn.ConvTranspose2d( 381 | embed_dim, embed_dim, 4, stride=2, padding=1 382 | ) 383 | self.dec = Decoder( 384 | embed_dim + embed_dim, 385 | in_channel, 386 | channel, 387 | n_res_block, 388 | n_res_channel, 389 | stride=4, 390 | ) 391 | 392 | def forward(self, input, save_idx=None, visual_folder=None): 393 | face, rhand, lhand = input 394 | 395 | quant_t_face, quant_b_face, diff_face, _, _ = self.encode_face(face) 396 | quant_t_rhand, quant_b_rhand, diff_rhand, _, _ = self.encode_rhand(rhand) 397 | quant_t_lhand, quant_b_lhand, diff_lhand, _, _ = self.encode_lhand(lhand) 398 | 399 | quant_t = quant_t_face + quant_t_rhand + quant_t_lhand 400 | quant_b = quant_b_face + quant_b_rhand + quant_b_lhand 401 | diff = diff_face + diff_rhand + diff_lhand 402 | 403 | dec = self.decode(quant_t, quant_b) 404 | 405 | if save_idx is not None: 406 | def save_img(img, save_idx, i, dtype): 407 | img = (img.detach().cpu() + 0.5).numpy() 408 | img = np.transpose(img, (1,2,0)) 409 | fig = plt.imshow(img, interpolation='nearest') 410 | fig.axes.get_xaxis().set_visible(False) 411 | fig.axes.get_yaxis().set_visible(False) 412 | plt.savefig('{}/{}_{}_{}.jpg'.\ 413 | format(visual_folder, save_idx, i, dtype), bbox_inches='tight') 414 | 415 | for i in tqdm(range(min(face.shape[0], 8))): 416 | save_img(face[i], save_idx, i, 'face') 417 | save_img(rhand[i], save_idx, i, 'rhand') 418 | save_img(lhand[i], save_idx, i, 'lhand') 419 | save_img(x_hat[i], save_idx, i, 'reconstructed') 420 | 421 | return dec, diff 422 | 423 | def decode(self, quant_t, quant_b): 424 | upsample_t = self.upsample_t(quant_t) 425 | quant = torch.cat([upsample_t, quant_b], 1) 426 | dec = self.dec(quant) 427 | 428 | return dec 429 | 430 | def decode_code(self, code_t, code_b): 431 | quant_t = self.quantize_t.embed_code(code_t) 432 | quant_t = quant_t.permute(0, 3, 1, 2) 433 | quant_b = self.quantize_b.embed_code(code_b) 434 | quant_b = quant_b.permute(0, 3, 1, 2) 435 | 436 | dec = self.decode(quant_t, quant_b) 437 | 438 | return dec -------------------------------------------------------------------------------- /preprocessing/preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code for preprocessing the video dataset 3 | Processes a video, generates face crop, and creates a constant crop for all face crops generated 4 | Writes the generated face crops (as frames) to a video 5 | ''' 6 | 7 | import math 8 | import os 9 | import os.path as osp 10 | from tqdm import tqdm 11 | import gc 12 | from glob import glob 13 | 14 | import cv2 15 | import matplotlib.pyplot as plt 16 | import mediapipe as mp 17 | 18 | def display_image(image, requires_colorization=True): 19 | plt.figure() 20 | if requires_colorization: 21 | plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 22 | else: 23 | plt.imshow(image) 24 | 25 | def get_iou(bb1, bb2): 26 | """ 27 | Calculate the Intersection over Union (IoU) of two bounding boxes. 
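    Worked example (illustrative values): with bb1 = {'x1': 0, 'y1': 0,
    'x2': 10, 'y2': 10} and bb2 = {'x1': 5, 'y1': 5, 'x2': 15, 'y2': 15},
    the intersection rectangle is 5 * 5 = 25, the union is
    100 + 100 - 25 = 175, and the IoU is 25 / 175 ~= 0.143.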
28 | 29 | Parameters 30 | ---------- 31 | bb1 : dict 32 | Keys: {'x1', 'x2', 'y1', 'y2'} 33 | The (x1, y1) position is at the top left corner, 34 | the (x2, y2) position is at the bottom right corner 35 | bb2 : dict 36 | Keys: {'x1', 'x2', 'y1', 'y2'} 37 | The (x, y) position is at the top left corner, 38 | the (x2, y2) position is at the bottom right corner 39 | 40 | Returns 41 | ------- 42 | float 43 | in [0, 1] 44 | """ 45 | assert bb1['x1'] < bb1['x2'] 46 | assert bb1['y1'] < bb1['y2'] 47 | assert bb2['x1'] < bb2['x2'] 48 | assert bb2['y1'] < bb2['y2'] 49 | 50 | # determine the coordinates of the intersection rectangle 51 | x_left = max(bb1['x1'], bb2['x1']) 52 | y_top = max(bb1['y1'], bb2['y1']) 53 | x_right = min(bb1['x2'], bb2['x2']) 54 | y_bottom = min(bb1['y2'], bb2['y2']) 55 | 56 | if x_right < x_left or y_bottom < y_top: 57 | return 0.0 58 | 59 | # The intersection of two axis-aligned bounding boxes is always an 60 | # axis-aligned bounding box 61 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 62 | 63 | # compute the area of both AABBs 64 | bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1']) 65 | bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1']) 66 | 67 | ''' 68 | comptute the iou by taking the intersecting area over the sum of bounding boxes 69 | ''' 70 | 71 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 72 | assert iou >= 0.0 73 | assert iou <= 1.0 74 | return iou 75 | 76 | 77 | ''' 78 | Function for writing frames to video 79 | ''' 80 | def write_frames_to_video(frames, current_frames, video_order_id, VIDEO_DIR, fps=30): 81 | height, width, _ = frames[0].shape 82 | # the frame coordinates have to be taken from the bounding box 83 | video_path = osp.join(VIDEO_DIR, str(video_order_id).zfill(5) + '.mp4') 84 | video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) 85 | for frame_index in current_frames: 86 | video.write(frames[frame_index]) 87 | 88 | cv2.destroyAllWindows() 89 | video.release() 90 | 91 | print(f'Video {video_path} written successfully') 92 | 93 | 94 | def crop_get_video(frames, current_indexes, bounding_box, VIDEO_DIR, video_order_id, fps=30): 95 | left, top, right, down = bounding_box['x1'], bounding_box['y1'], bounding_box['x2'], bounding_box['y2'] 96 | width, height = right - left, down - top 97 | video_path = osp.join(VIDEO_DIR, str(video_order_id).zfill(5) + '.mp4') 98 | 99 | video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) 100 | 101 | for index in current_indexes: 102 | current_frame = frames[index] 103 | cropped = current_frame[top : down, left : right] 104 | 105 | video.write(cropped) 106 | 107 | print(f'Writing to file : {video_order_id}') 108 | cv2.destroyAllWindows() 109 | 110 | 111 | ''' 112 | Function to return the bb coordinates of the cropped face 113 | ''' 114 | def crop_face_coordinates(image, x_px, y_px, width_px, height_px): 115 | # using different thresholds/bounds for the upper and lower faces 116 | image_height, image_width, _ = image.shape 117 | lower_face_buffer, upper_face_buffer = 0.25, 0.65 118 | min_x, min_y, max_x, max_y = x_px, y_px, x_px + width_px, y_px + height_px 119 | 120 | x_left = max(0, int(min_x - (max_x - min_x) * lower_face_buffer)) 121 | x_right = min(image_width, int(max_x + (max_x - min_x) * lower_face_buffer)) 122 | y_top = max(0, int(min_y - (max_y - min_y) * upper_face_buffer)) 123 | y_down = min(image_height, int(max_y + (max_y - min_y) * lower_face_buffer)) 124 | 125 | size = 
max(x_right - x_left, y_down - y_top) 126 | sw = int((x_left + x_right)/2 - size // 2) 127 | 128 | return sw, y_top, sw+size, y_down 129 | 130 | 131 | ''' 132 | function to return the bb coordinates given the image 133 | ''' 134 | def bb_coordinates(image): 135 | mp_face_detection = mp.solutions.face_detection 136 | face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5) 137 | 138 | results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 139 | 140 | result = results.detections 141 | image_height, image_width, _ = image.shape 142 | 143 | if result is None: 144 | return -1, -1, -1, -1 145 | 146 | bb_values = result[0].location_data.relative_bounding_box 147 | 148 | normalized_x, normalized_y, normalized_width, normalized_height = \ 149 | bb_values.xmin, bb_values.ymin, bb_values.width, bb_values.height 150 | 151 | # the bounding box coordinates are given as normalized values, unnormalize them by multiplying by height and width 152 | x_px = min(math.floor(normalized_x * image_width), image_width - 1) 153 | y_px = min(math.floor(normalized_y * image_height), image_height - 1) 154 | width_px = min(math.floor(normalized_width * image_width), image_width - 1) 155 | height_px = min(math.floor(normalized_height * image_height), image_height - 1) 156 | 157 | return x_px, y_px, width_px, height_px 158 | 159 | ''' 160 | the formula for using the intersection over union works differently for different resolutions 161 | this is because it is dependent on the pixel information and not necessarily what's inside 162 | the frames are read till either i) end of the frames are reached, or ii) 4000 frames (cpu limit) are read 163 | ''' 164 | 165 | def process_frames(frames, VIDEO_DIR): 166 | # declaring global scopre for the video_order_id variable 167 | global video_order_id 168 | 169 | current_frames = list() 170 | mean_bounding_box = dict() 171 | iou_threshold = 0.7 172 | frame_count = 0 173 | frame_writing_threshold = 30 # minimum number of frames to write 174 | bb_prev_mean = dict() 175 | 176 | for index, frame in tqdm(enumerate(frames)): 177 | image_height, image_width, _ = frame.shape 178 | x_px, y_px, width_px, height_px = bb_coordinates(frame) 179 | 180 | if x_px == -1: 181 | if len(current_frames) > frame_writing_threshold: 182 | crop_get_video(frames, current_frames, mean_bounding_box, VIDEO_DIR, video_order_id) 183 | video_order_id += 1 184 | 185 | # reset 186 | current_frames = list() 187 | frame_count = 0 188 | mean_bounding_box = dict() 189 | bb_prev_mean = dict() 190 | 191 | else: 192 | left, top, right, bottom = crop_face_coordinates(frame, x_px, y_px, width_px, height_px) 193 | current_bounding_box = {'x1' : left, 'x2' : right, 'y1' : top, 'y2' : bottom} 194 | 195 | if len(mean_bounding_box) == 0: 196 | 197 | mean_bounding_box = current_bounding_box.copy() 198 | bb_prev_mean = current_bounding_box.copy() 199 | 200 | frame_count += 1 201 | current_frames.append(index) 202 | 203 | else: 204 | # UPDATE - compute the iou between the current bounding box and the mean bounding box 205 | iou = get_iou(bb_prev_mean, current_bounding_box) 206 | 207 | if iou < iou_threshold: 208 | mean_left, mean_right, mean_top, mean_down = mean_bounding_box['x1'], mean_bounding_box['x2'], mean_bounding_box['y1'], mean_bounding_box['y2'] 209 | 210 | if len(current_frames) > frame_writing_threshold: 211 | crop_get_video(frames, current_frames, mean_bounding_box, VIDEO_DIR, video_order_id) 212 | video_order_id += 1 213 | 214 | current_frames = list() 215 | frame_count = 0 216 | 
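                # the detected face box has drifted away from the tracked segment (IoU below the
                # threshold); the segment was written out above if it was long enough, so reset
                # the running bounding box before accumulating a new segment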
mean_bounding_box = dict() 217 | 218 | # Add the current bounding box to the list of bounding boxes and compute the mean 219 | else: 220 | mean_bounding_box['x1'] = min(mean_bounding_box['x1'], current_bounding_box['x1']) 221 | mean_bounding_box['y1'] = min(mean_bounding_box['y1'], current_bounding_box['y1']) 222 | mean_bounding_box['x2'] = max(mean_bounding_box['x2'], current_bounding_box['x2']) 223 | mean_bounding_box['y2'] = max(mean_bounding_box['y2'], current_bounding_box['y2']) 224 | 225 | # update the coordinates of the mean bounding box 226 | for item in bb_prev_mean.keys(): 227 | bb_prev_mean[item] = int((bb_prev_mean[item] * frame_count + current_bounding_box[item])/(frame_count + 1)) 228 | 229 | frame_count += 1 230 | current_frames.append(index) 231 | 232 | if len(current_frames) > frame_writing_threshold: 233 | crop_get_video(frames, current_frames, mean_bounding_box, VIDEO_DIR, video_order_id) 234 | video_order_id += 1 235 | 236 | 237 | ''' 238 | method for processing a single video 239 | reads video frames and calls function to process the frames 240 | ''' 241 | def process_video(video_file, processed_videos_dir): 242 | video_stream = cv2.VideoCapture(video_file) 243 | print(f'Total number of frames in the current video : {video_stream.get(cv2.CAP_PROP_FRAME_COUNT)}') 244 | print(f'Processing video file {video_file}') 245 | frames = list() 246 | 247 | # keep reading the frames if the processing of the current frames is over 248 | frame_reading_threshold = 8000 249 | frames_processed = 0 250 | frames_processed_threshold = 2000 251 | 252 | video_file_name = osp.basename(video_file).split('.')[0] 253 | VIDEO_DIR = osp.join(processed_videos_dir, video_file_name) 254 | 255 | os.makedirs(VIDEO_DIR, exist_ok=True) 256 | 257 | ret, frame = video_stream.read() 258 | while ret: 259 | frames.append(frame) 260 | frames_processed += 1 261 | 262 | if frames_processed % frames_processed_threshold == 0: 263 | print(f'{frames_processed} frames read') 264 | 265 | if frames_processed%frame_reading_threshold == 0: 266 | # perform processing and generate frame videos 267 | print(f'Processing the frames for the current batch') 268 | process_frames(frames, VIDEO_DIR) 269 | print(f'Done with processing of the current batch') 270 | 271 | del frames 272 | gc.collect() 273 | frames = list() # clear out the frames 274 | 275 | ret, frame = video_stream.read() 276 | 277 | # check if more frames that were read need to be processed 278 | if len(frames) != 0: 279 | print(f'Processing remaining frames') 280 | process_frames(frames, VIDEO_DIR) 281 | 282 | video_stream.release() 283 | 284 | 285 | ''' 286 | method to process multiple videos 287 | ''' 288 | def process_videos(video_dir): 289 | video_files = glob(video_dir + '/*.mp4') 290 | 291 | for video_file in tqdm(video_files): 292 | process_video(video_file) 293 | 294 | 295 | if __name__ == '__main__': 296 | video_dir = 'videos_dir' 297 | process_videos(video_dir) -------------------------------------------------------------------------------- /results/inference_pipeline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/inference_pipeline.gif -------------------------------------------------------------------------------- /results/training_pipeline.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/training_pipeline.gif -------------------------------------------------------------------------------- /results/v2v_comparisons1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_comparisons1.gif -------------------------------------------------------------------------------- /results/v2v_comparisons31.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_comparisons31.gif -------------------------------------------------------------------------------- /results/v2v_face_swapping.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_face_swapping.gif -------------------------------------------------------------------------------- /results/v2v_faceswapping_looped2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_faceswapping_looped2.gif -------------------------------------------------------------------------------- /results/v2v_more_result.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_more_result.gif -------------------------------------------------------------------------------- /results/v2v_results1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results1.gif -------------------------------------------------------------------------------- /results/v2v_results2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results2.gif -------------------------------------------------------------------------------- /results/v2v_results3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results3.gif -------------------------------------------------------------------------------- /results/v2v_results4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results4.gif -------------------------------------------------------------------------------- /results/v2v_same_identity1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_same_identity1.gif -------------------------------------------------------------------------------- /results/v2v_same_identity2.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_same_identity2.gif -------------------------------------------------------------------------------- /results/v2v_same_identity3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_same_identity3.gif -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi, floor, sin 2 | 3 | from torch.optim import lr_scheduler 4 | 5 | 6 | class CosineLR(lr_scheduler._LRScheduler): 7 | def __init__(self, optimizer, lr_min, lr_max, step_size): 8 | self.lr_min = lr_min 9 | self.lr_max = lr_max 10 | self.step_size = step_size 11 | self.iteration = 0 12 | 13 | super().__init__(optimizer, -1) 14 | 15 | def get_lr(self): 16 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 17 | 1 + cos(self.iteration / self.step_size * pi) 18 | ) 19 | self.iteration += 1 20 | 21 | if self.iteration == self.step_size: 22 | self.iteration = 0 23 | 24 | return [lr for base_lr in self.base_lrs] 25 | 26 | 27 | class PowerLR(lr_scheduler._LRScheduler): 28 | def __init__(self, optimizer, lr_min, lr_max, warmup): 29 | self.lr_min = lr_min 30 | self.lr_max = lr_max 31 | self.warmup = warmup 32 | self.iteration = 0 33 | 34 | super().__init__(optimizer, -1) 35 | 36 | def get_lr(self): 37 | if self.iteration < self.warmup: 38 | lr = ( 39 | self.lr_min + (self.lr_max - self.lr_min) / self.warmup * self.iteration 40 | ) 41 | 42 | else: 43 | lr = self.lr_max * (self.iteration - self.warmup + 1) ** -0.5 44 | 45 | self.iteration += 1 46 | 47 | return [lr for base_lr in self.base_lrs] 48 | 49 | 50 | class SineLR(lr_scheduler._LRScheduler): 51 | def __init__(self, optimizer, lr_min, lr_max, step_size): 52 | self.lr_min = lr_min 53 | self.lr_max = lr_max 54 | self.step_size = step_size 55 | self.iteration = 0 56 | 57 | super().__init__(optimizer, -1) 58 | 59 | def get_lr(self): 60 | lr = self.lr_min + (self.lr_max - self.lr_min) * sin( 61 | self.iteration / self.step_size * pi 62 | ) 63 | self.iteration += 1 64 | 65 | if self.iteration == self.step_size: 66 | self.iteration = 0 67 | 68 | return [lr for base_lr in self.base_lrs] 69 | 70 | 71 | class LinearLR(lr_scheduler._LRScheduler): 72 | def __init__(self, optimizer, lr_min, lr_max, warmup, step_size): 73 | self.lr_min = lr_min 74 | self.lr_max = lr_max 75 | self.step_size = step_size 76 | self.warmup = warmup 77 | self.iteration = 0 78 | 79 | super().__init__(optimizer, -1) 80 | 81 | def get_lr(self): 82 | if self.iteration < self.warmup: 83 | lr = self.lr_max 84 | 85 | else: 86 | lr = self.lr_max + (self.iteration - self.warmup) * ( 87 | self.lr_min - self.lr_max 88 | ) / (self.step_size - self.warmup) 89 | self.iteration += 1 90 | 91 | if self.iteration == self.step_size: 92 | self.iteration = 0 93 | 94 | return [lr for base_lr in self.base_lrs] 95 | 96 | 97 | class CLR(lr_scheduler._LRScheduler): 98 | def __init__(self, optimizer, lr_min, lr_max, step_size): 99 | self.epoch = 0 100 | self.lr_min = lr_min 101 | self.lr_max = lr_max 102 | self.current_lr = lr_min 103 | self.step_size = step_size 104 
| 105 | super().__init__(optimizer, -1) 106 | 107 | def get_lr(self): 108 | cycle = floor(1 + self.epoch / (2 * self.step_size)) 109 | x = abs(self.epoch / self.step_size - 2 * cycle + 1) 110 | lr = self.lr_min + (self.lr_max - self.lr_min) * max(0, 1 - x) 111 | self.current_lr = lr 112 | 113 | self.epoch += 1 114 | 115 | return [lr for base_lr in self.base_lrs] 116 | 117 | 118 | class Warmup(lr_scheduler._LRScheduler): 119 | def __init__(self, optimizer, model_dim, factor=1, warmup=16000): 120 | self.optimizer = optimizer 121 | self.model_dim = model_dim 122 | self.factor = factor 123 | self.warmup = warmup 124 | self.iteration = 0 125 | 126 | super().__init__(optimizer, -1) 127 | 128 | def get_lr(self): 129 | self.iteration += 1 130 | lr = ( 131 | self.factor 132 | * self.model_dim ** (-0.5) 133 | * min(self.iteration ** (-0.5), self.iteration * self.warmup ** (-1.5)) 134 | ) 135 | 136 | return [lr for base_lr in self.base_lrs] 137 | 138 | 139 | # Copyright 2019 fastai 140 | 141 | # Licensed under the Apache License, Version 2.0 (the "License"); 142 | # you may not use this file except in compliance with the License. 143 | # You may obtain a copy of the License at 144 | 145 | # http://www.apache.org/licenses/LICENSE-2.0 146 | 147 | # Unless required by applicable law or agreed to in writing, software 148 | # distributed under the License is distributed on an "AS IS" BASIS, 149 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 150 | # See the License for the specific language governing permissions and 151 | # limitations under the License. 152 | 153 | 154 | # Borrowed from https://github.com/fastai/fastai and changed to make it runs like PyTorch lr scheduler 155 | 156 | 157 | class CycleAnnealScheduler: 158 | def __init__( 159 | self, optimizer, lr_max, lr_divider, cut_point, step_size, momentum=None 160 | ): 161 | self.lr_max = lr_max 162 | self.lr_divider = lr_divider 163 | self.cut_point = step_size // cut_point 164 | self.step_size = step_size 165 | self.iteration = 0 166 | self.cycle_step = int(step_size * (1 - cut_point / 100) / 2) 167 | self.momentum = momentum 168 | self.optimizer = optimizer 169 | 170 | def get_lr(self): 171 | if self.iteration > 2 * self.cycle_step: 172 | cut = (self.iteration - 2 * self.cycle_step) / ( 173 | self.step_size - 2 * self.cycle_step 174 | ) 175 | lr = self.lr_max * (1 + (cut * (1 - 100) / 100)) / self.lr_divider 176 | 177 | elif self.iteration > self.cycle_step: 178 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step 179 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider 180 | 181 | else: 182 | cut = self.iteration / self.cycle_step 183 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider 184 | 185 | return lr 186 | 187 | def get_momentum(self): 188 | if self.iteration > 2 * self.cycle_step: 189 | momentum = self.momentum[0] 190 | 191 | elif self.iteration > self.cycle_step: 192 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step 193 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0]) 194 | 195 | else: 196 | cut = self.iteration / self.cycle_step 197 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0]) 198 | 199 | return momentum 200 | 201 | def step(self): 202 | lr = self.get_lr() 203 | 204 | if self.momentum is not None: 205 | momentum = self.get_momentum() 206 | 207 | self.iteration += 1 208 | 209 | if self.iteration == self.step_size: 210 | self.iteration = 0 211 | 212 | for group in 
self.optimizer.param_groups: 213 | group['lr'] = lr 214 | 215 | if self.momentum is not None: 216 | group['betas'] = (momentum, group['betas'][1]) 217 | 218 | return lr 219 | 220 | 221 | def anneal_linear(start, end, proportion): 222 | return start + proportion * (end - start) 223 | 224 | 225 | def anneal_cos(start, end, proportion): 226 | cos_val = cos(pi * proportion) + 1 227 | 228 | return end + (start - end) / 2 * cos_val 229 | 230 | 231 | class Phase: 232 | def __init__(self, start, end, n_iter, anneal_fn): 233 | self.start, self.end = start, end 234 | self.n_iter = n_iter 235 | self.anneal_fn = anneal_fn 236 | self.n = 0 237 | 238 | def step(self): 239 | self.n += 1 240 | 241 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter) 242 | 243 | def reset(self): 244 | self.n = 0 245 | 246 | @property 247 | def is_done(self): 248 | return self.n >= self.n_iter 249 | 250 | 251 | class CycleScheduler: 252 | def __init__( 253 | self, 254 | optimizer, 255 | lr_max, 256 | n_iter, 257 | momentum=(0.95, 0.85), 258 | divider=25, 259 | warmup_proportion=0.3, 260 | phase=('linear', 'cos'), 261 | ): 262 | self.optimizer = optimizer 263 | 264 | phase1 = int(n_iter * warmup_proportion) 265 | phase2 = n_iter - phase1 266 | lr_min = lr_max / divider 267 | 268 | phase_map = {'linear': anneal_linear, 'cos': anneal_cos} 269 | 270 | self.lr_phase = [ 271 | Phase(lr_min, lr_max, phase1, phase_map[phase[0]]), 272 | Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]), 273 | ] 274 | 275 | self.momentum = momentum 276 | 277 | if momentum is not None: 278 | mom1, mom2 = momentum 279 | self.momentum_phase = [ 280 | Phase(mom1, mom2, phase1, phase_map[phase[0]]), 281 | Phase(mom2, mom1, phase2, phase_map[phase[1]]), 282 | ] 283 | 284 | else: 285 | self.momentum_phase = [] 286 | 287 | self.phase = 0 288 | 289 | def step(self): 290 | lr = self.lr_phase[self.phase].step() 291 | 292 | if self.momentum is not None: 293 | momentum = self.momentum_phase[self.phase].step() 294 | 295 | else: 296 | momentum = None 297 | 298 | for group in self.optimizer.param_groups: 299 | group['lr'] = lr 300 | 301 | if self.momentum is not None: 302 | if 'betas' in group: 303 | group['betas'] = (momentum, group['betas'][1]) 304 | 305 | else: 306 | group['momentum'] = momentum 307 | 308 | if self.lr_phase[self.phase].is_done: 309 | self.phase += 1 310 | 311 | if self.phase >= len(self.lr_phase): 312 | for phase in self.lr_phase: 313 | phase.reset() 314 | 315 | for phase in self.momentum_phase: 316 | phase.reset() 317 | 318 | self.phase = 0 319 | 320 | return lr, momentum 321 | 322 | 323 | class LRFinder(lr_scheduler._LRScheduler): 324 | def __init__(self, optimizer, lr_min, lr_max, step_size, linear=False): 325 | ratio = lr_max / lr_min 326 | self.linear = linear 327 | self.lr_min = lr_min 328 | self.lr_mult = (ratio / step_size) if linear else ratio ** (1 / step_size) 329 | self.iteration = 0 330 | self.lrs = [] 331 | self.losses = [] 332 | 333 | super().__init__(optimizer, -1) 334 | 335 | def get_lr(self): 336 | lr = ( 337 | self.lr_mult * self.iteration 338 | if self.linear 339 | else self.lr_mult ** self.iteration 340 | ) 341 | lr = self.lr_min + lr if self.linear else self.lr_min * lr 342 | 343 | self.iteration += 1 344 | self.lrs.append(lr) 345 | 346 | return [lr for base_lr in self.base_lrs] 347 | 348 | def record(self, loss): 349 | self.losses.append(loss) 350 | 351 | def save(self, filename): 352 | with open(filename, 'w') as f: 353 | for lr, loss in zip(self.lrs, self.losses): 354 | f.write('{},{}\n'.format(lr, 
loss)) 355 | -------------------------------------------------------------------------------- /train_faceoff.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import random 5 | import os.path as osp 6 | 7 | import torch 8 | from torch import nn, optim 9 | from torch.utils.data import DataLoader 10 | 11 | from torchvision import datasets, transforms, utils 12 | 13 | from tqdm import tqdm 14 | 15 | from scheduler import CycleScheduler 16 | import distributed as dist 17 | 18 | from utils import * 19 | from config import DATASET, LATENT_LOSS_WEIGHT, SAMPLE_SIZE_FOR_VISUALIZATION 20 | 21 | criterion = nn.MSELoss() 22 | 23 | global_step = 0 24 | 25 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION 26 | 27 | dataset = DATASET 28 | 29 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results' 30 | 31 | def run_step(model, data, device, run='train'): 32 | img, S, ground_truth, source_images_original = process_data(data, device, dataset) 33 | 34 | out, latent_loss = model(img) 35 | 36 | out = out[:, :3] 37 | 38 | recon_loss = criterion(out, ground_truth) 39 | latent_loss = latent_loss.mean() 40 | 41 | if run == 'train': 42 | return recon_loss, latent_loss, S 43 | else: 44 | return ground_truth, img, out 45 | 46 | def blob2full_validation(model, img): 47 | face, rhand, lhand = img 48 | 49 | face = face[:sample_size] 50 | rhand = rhand[:sample_size] 51 | lhand = lhand[:sample_size] 52 | sample = face, rhand, lhand 53 | 54 | gt = gt[:sample_size] 55 | 56 | with torch.no_grad(): 57 | out, _ = model(sample) 58 | 59 | save_image(torch.cat([face, rhand, lhand, out, gt], 0), 60 | f"sample/{epoch + 1}_{i}.png") 61 | 62 | def jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder): 63 | for i, data in enumerate(tqdm(val_loader)): 64 | with torch.no_grad(): 65 | source_images, input, prediction = run_step(model, data, device, run='val') 66 | 67 | source_hulls = input[:, :3] 68 | background = input[:, 3:] 69 | 70 | saves = { 71 | 'source': source_hulls, 72 | 'background': background, 73 | 'prediction': prediction, 74 | 'source_images': source_images 75 | } 76 | 77 | if i % (len(val_loader) // 10) == 0 or run_type != 'train': 78 | def denormalize(x): 79 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 80 | 81 | for name in saves: 82 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4" 83 | frames = saves[name].detach().cpu() 84 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames] 85 | 86 | # os.makedirs(sample_folder, exist_ok=True) 87 | save_frames_as_video(frames, saveas, fps=25) 88 | 89 | def base_validation(model, val_loader, device, epoch, i, run_type, sample_folder): 90 | def get_proper_shape(x): 91 | shape = x.shape 92 | return x.view(shape[0], -1, 3, shape[2], shape[3]).view(-1, 3, shape[2], shape[3]) 93 | 94 | for val_i, data in enumerate(tqdm(val_loader)): 95 | with torch.no_grad(): 96 | sample, _, out = run_step(model, data, device, 'val') 97 | 98 | if val_i % (len(val_loader)//10) == 0: # save 10 results 99 | 100 | def denormalize(x): 101 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 102 | 103 | # save_image( 104 | # torch.cat([sample[:3*3], out[:3*3]], 0), 105 | # f"{sample_folder}/{epoch + 1}_{i}_{val_i}.png") 106 | 107 | save_as_sample = f"{sample_folder}/{epoch+1}_{global_step}_{i}_sample.mp4" 108 | save_as_out = f"{sample_folder}/{epoch+1}_{global_step}_{i}_out.mp4" 109 | 110 | # save_as_sample = f"{sample_folder}/{val_i}_sample.mp4" 111 | # save_as_out = 
f"{sample_folder}/{val_i}_out.mp4" 112 | 113 | sample = sample.detach().cpu() 114 | sample = [denormalize(x).permute(1, 2, 0).numpy() for x in sample] 115 | 116 | out = out.detach().cpu() 117 | out = [denormalize(x).permute(1, 2, 0).numpy() for x in out] 118 | 119 | save_frames_as_video(sample, save_as_sample, fps=25) 120 | save_frames_as_video(out, save_as_out, fps=25) 121 | 122 | def validation(model, val_loader, device, epoch, i, sample_folder, run_type='train'): 123 | if dataset >= 6: 124 | jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder) 125 | else: 126 | base_validation(model, val_loader, device, epoch, i, run_type, sample_folder) 127 | 128 | def train(model, loader, val_loader, optimizer, scheduler, device, epoch, validate_at, checkpoint_dir, sample_folder): 129 | if dist.is_primary(): 130 | loader = tqdm(loader, file=sys.stdout) 131 | 132 | mse_sum = 0 133 | mse_n = 0 134 | 135 | for i, data in enumerate(loader): 136 | model.zero_grad() 137 | 138 | recon_loss, latent_loss, S = run_step(model, data, device) 139 | 140 | loss = recon_loss + LATENT_LOSS_WEIGHT * latent_loss 141 | 142 | loss.backward() 143 | 144 | if scheduler is not None: 145 | scheduler.step() 146 | 147 | optimizer.step() 148 | 149 | global global_step 150 | 151 | global_step += 1 152 | 153 | part_mse_sum = recon_loss.item() * S 154 | part_mse_n = S 155 | 156 | comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n} 157 | comm = dist.all_gather(comm) 158 | 159 | for part in comm: 160 | mse_sum += part["mse_sum"] 161 | mse_n += part["mse_n"] 162 | 163 | if dist.is_primary(): 164 | lr = optimizer.param_groups[0]["lr"] 165 | 166 | loader.set_description( 167 | ( 168 | f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; " 169 | f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; " 170 | f"lr: {lr:.5f}" 171 | ) 172 | ) 173 | 174 | if i % validate_at == 0: 175 | model.eval() 176 | 177 | # going inside the validation 178 | print(f'Going inside the validation') 179 | 180 | validation(model, val_loader, device, epoch, i, sample_folder) 181 | 182 | if dist.is_primary(): 183 | os.makedirs(checkpoint_dir, exist_ok=True) 184 | 185 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(i + 1).zfill(4)}.pt") 186 | 187 | model.train() 188 | 189 | def main(args): 190 | device = "cuda" 191 | 192 | args.distributed = dist.get_world_size() > 1 193 | 194 | default_transform = transforms.Compose( 195 | [ 196 | transforms.ToPILImage(), 197 | transforms.Resize(args.size), 198 | transforms.ToTensor(), 199 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 200 | ] 201 | ) 202 | 203 | loader, val_loader, model = get_loaders_and_models( 204 | args, dataset, default_transform, device, test=args.test) 205 | 206 | if args.distributed: 207 | model = nn.parallel.DistributedDataParallel( 208 | model, 209 | device_ids=[dist.get_local_rank()], 210 | output_device=dist.get_local_rank(), 211 | ) 212 | 213 | if args.ckpt: 214 | print(f'Loading pretrained checkpoint - {args.ckpt}') 215 | state_dict = torch.load(args.ckpt) 216 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() } 217 | try: 218 | model.module.load_state_dict(state_dict) 219 | except: 220 | model.load_state_dict(state_dict) 221 | 222 | if args.test: 223 | # test(loader, model, device) 224 | validation(model, val_loader, device, 0, 0, args.sample_folder, 'val') 225 | else: 226 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 227 | 228 | scheduler = None 229 | 230 | if args.sched == "cycle": 
231 | scheduler = CycleScheduler( 232 | optimizer, 233 | args.lr, 234 | n_iter=len(loader) * args.epoch, 235 | momentum=None, 236 | warmup_proportion=0.05, 237 | ) 238 | 239 | for i in range(args.epoch): 240 | train(model, loader, val_loader, optimizer, scheduler, device, i, args.validate_at, args.checkpoint_dir, args.sample_folder) 241 | 242 | def get_random_name(cipher_length=5): 243 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 244 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)]) 245 | 246 | if __name__ == "__main__": 247 | parser = argparse.ArgumentParser() 248 | parser.add_argument("--n_gpu", type=int, default=1) 249 | 250 | # port = ( 251 | # 2 ** 15 252 | # + 2 ** 14 253 | # + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 254 | # ) 255 | 256 | port = random.randint(51000, 52000) 257 | 258 | parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}") 259 | parser.add_argument("--batch_size", type=int, default=32) 260 | parser.add_argument("--size", type=int, default=256) 261 | parser.add_argument("--epoch", type=int, default=560) 262 | parser.add_argument("--lr", type=float, default=3e-4) 263 | parser.add_argument("--sched", type=str) 264 | parser.add_argument("--checkpoint_suffix", type=str, default='') 265 | parser.add_argument("--validate_at", type=int, default=1024) 266 | parser.add_argument("--ckpt", required=False) 267 | parser.add_argument("--test", action='store_true', required=False) 268 | parser.add_argument("--gray", action='store_true', required=False) 269 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty') 270 | parser.add_argument("--crossid", action='store_true', required=False) 271 | parser.add_argument("--sample_folder", type=str, default='samples') 272 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint') 273 | 274 | args = parser.parse_args() 275 | 276 | # args.n_gpu = torch.cuda.device_count() 277 | current_run = get_random_name() 278 | 279 | # sample_folder = sample_folder.format(args.checkpoint_suffix) 280 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run) 281 | os.makedirs(args.sample_folder, exist_ok=True) 282 | 283 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run) 284 | # os.makedirs(args.checkpoint_dir, exist_ok=True) 285 | 286 | # checkpoint_dir = checkpoint_dir.format(args.checkpoint_suffix) 287 | 288 | print(args, flush=True) 289 | 290 | dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,)) 291 | -------------------------------------------------------------------------------- /train_faceoff_perceptual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import random 5 | import os.path as osp 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, optim 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets, transforms, utils 13 | 14 | from scheduler import CycleScheduler 15 | import distributed as dist 16 | 17 | from utils import * 18 | from config import DATASET, LATENT_LOSS_WEIGHT, PERCEPTUAL_LOSS_WEIGHT, SAMPLE_SIZE_FOR_VISUALIZATION 19 | 20 | 21 | criterion = nn.MSELoss() 22 | 23 | global_step = 0 24 | 25 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION 26 | 27 | dataset = DATASET 28 | 29 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results' 30 | 31 | 32 | def run_step(model, vqlpips, data, device, 
run='train'): 33 | img, S, ground_truth, source_image = process_data(data, device, dataset) 34 | 35 | out, latent_loss = model(img) 36 | 37 | out = out[:, :3] 38 | 39 | recon_loss = criterion(out, ground_truth) 40 | latent_loss = latent_loss.mean() 41 | 42 | perceptual_loss = vqlpips(ground_truth, out) 43 | 44 | if run == 'train': 45 | return recon_loss, latent_loss, perceptual_loss, S 46 | else: 47 | return ground_truth, img, out, source_image 48 | 49 | 50 | ''' 51 | runs the validation step, saves the blended output video 52 | ''' 53 | def validation(model, vqlpips, val_loader, device, epoch, i, sample_folder, run_type='train'): 54 | print(f'Inside jitter validation') 55 | for i, data in enumerate(tqdm(val_loader)): 56 | with torch.no_grad(): 57 | source_images, input, prediction, source_images_original = run_step(model, vqlpips, data, device, run='val') 58 | 59 | source_hulls = input[:, :3] 60 | background = input[:, 3:] 61 | 62 | saves = { 63 | 'source': source_hulls, 64 | 'background': background, 65 | 'prediction': prediction, 66 | 'source_images': source_images, 67 | 'source_original': source_images_original, 68 | } 69 | 70 | if True: 71 | def denormalize(x): 72 | return (x.clamp(min=-1.0, max=1.0) + 1)/2 73 | 74 | for name in saves: 75 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4" 76 | frames = saves[name].detach().cpu() 77 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames] 78 | 79 | save_frames_as_video(frames, saveas, fps=25) 80 | 81 | ''' 82 | trainer code -- loss is computed as the weighted sum of reconstruction loss, latent loss, and perceptual loss 83 | ''' 84 | def train(model, vqlpips, loader, val_loader, optimizer, scheduler, device, epoch, validate_at, checkpoint_dir, sample_folder): 85 | if dist.is_primary(): 86 | loader = tqdm(loader, file=sys.stdout) 87 | 88 | mse_sum = 0 89 | mse_n = 0 90 | perceptual_losses = [] 91 | 92 | for i, data in enumerate(loader): 93 | model.zero_grad() 94 | 95 | recon_loss, latent_loss, perceptual_loss, S = run_step(model, vqlpips, data, device) 96 | 97 | # loss is a weighted sum of i) reconstruction loss, ii) latent loss, and iii) perceptual loss 98 | loss = recon_loss + LATENT_LOSS_WEIGHT * latent_loss + PERCEPTUAL_LOSS_WEIGHT * perceptual_loss 99 | 100 | loss.backward() 101 | 102 | perceptual_losses.append(perceptual_loss.item()) 103 | 104 | if scheduler is not None: 105 | scheduler.step() 106 | 107 | optimizer.step() 108 | 109 | global global_step 110 | 111 | global_step += 1 112 | 113 | part_mse_sum = recon_loss.item() * S 114 | part_mse_n = S 115 | 116 | comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n} 117 | comm = dist.all_gather(comm) 118 | 119 | for part in comm: 120 | mse_sum += part["mse_sum"] 121 | mse_n += part["mse_n"] 122 | 123 | if dist.is_primary(): 124 | lr = optimizer.param_groups[0]["lr"] 125 | 126 | loader.set_description( 127 | ( 128 | f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; " 129 | f"perceptual: {np.array(perceptual_losses).mean():.3f} " 130 | f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; " 131 | f"lr: {lr:.5f}" 132 | ) 133 | ) 134 | 135 | if i % validate_at == 0: 136 | model.eval() 137 | 138 | validation(model, vqlpips, val_loader, device, epoch, i, sample_folder) 139 | 140 | if dist.is_primary(): 141 | os.makedirs(checkpoint_dir, exist_ok=True) 142 | 143 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(i + 1).zfill(4)}.pt") 144 | 145 | model.train() 146 | 147 | def main(args): 148 | device = "cuda" 149 | 150 | 
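    # Annotation added by the editor (not part of the original source): below, both the VQ-VAE and
    # the LPIPS network (vqlpips) are wrapped in DistributedDataParallel whenever more than one
    # process is launched, and a --ckpt checkpoint is loaded after stripping any leading 'module.'
    # prefixes so that single-GPU and DDP checkpoints remain interchangeable.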
args.distributed = dist.get_world_size() > 1 151 | 152 | default_transform = transforms.Compose( 153 | [ 154 | transforms.ToPILImage(), 155 | transforms.Resize(args.size), 156 | transforms.ToTensor(), 157 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 158 | ] 159 | ) 160 | 161 | loader, val_loader, model, vqlpips = get_loaders_and_models( 162 | args, device) 163 | 164 | if args.distributed: 165 | model = nn.parallel.DistributedDataParallel( 166 | model, 167 | device_ids=[dist.get_local_rank()], 168 | output_device=dist.get_local_rank(), 169 | ) 170 | 171 | vqlpips = nn.parallel.DistributedDataParallel( 172 | vqlpips, 173 | device_ids=[dist.get_local_rank()], 174 | output_device=dist.get_local_rank(), 175 | ) 176 | 177 | # load the pretrained checkpoints if available 178 | if args.ckpt: 179 | print(f'Loading pretrained checkpoint - {args.ckpt}') 180 | state_dict = torch.load(args.ckpt) 181 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() } 182 | try: 183 | model.module.load_state_dict(state_dict) 184 | except: 185 | model.load_state_dict(state_dict) 186 | 187 | if args.test: 188 | validation(model, vqlpips, val_loader, device, 0, 0, args.sample_folder, 'val') 189 | else: 190 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 191 | 192 | scheduler = None 193 | 194 | if args.sched == "cycle": 195 | scheduler = CycleScheduler( 196 | optimizer, 197 | args.lr, 198 | n_iter=len(loader) * args.epoch, 199 | momentum=None, 200 | warmup_proportion=0.05, 201 | ) 202 | 203 | for i in range(args.epoch): 204 | train(model, vqlpips, loader, val_loader, optimizer, scheduler, device, i, args.validate_at, args.checkpoint_dir, args.sample_folder) 205 | 206 | def get_random_name(cipher_length=5): 207 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 208 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)]) 209 | 210 | if __name__ == "__main__": 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument("--n_gpu", type=int, default=1) 213 | 214 | # port = ( 215 | # 2 ** 15 216 | # + 2 ** 14 217 | # + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 218 | # ) 219 | 220 | port = random.randint(51000, 52000) 221 | 222 | parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}") 223 | parser.add_argument("--batch_size", type=int, default=32) 224 | parser.add_argument("--size", type=int, default=256, help='image resolution') 225 | parser.add_argument("--epoch", type=int, default=560, help='total epochs') 226 | parser.add_argument("--lr", type=float, default=3e-4, help='learning rate') 227 | parser.add_argument("--sched", type=str) 228 | parser.add_argument("--checkpoint_suffix", type=str, default='') 229 | parser.add_argument("--validate_at", type=int, default=1024, help='validate after (number of steps)') 230 | parser.add_argument("--ckpt", required=False) 231 | parser.add_argument("--test", action='store_true', required=False) 232 | parser.add_argument("--gray", action='store_true', required=False) 233 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty') 234 | parser.add_argument("--crossid", action='store_true', required=False) 235 | parser.add_argument("--custom_validation", action='store_true', required=False, help='perform custom validation') 236 | parser.add_argument("--sample_folder", type=str, default='samples') 237 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint', help='dir path for saving the checkpoints') 238 | 
parser.add_argument("--validation_folder", type=str, default=None, help='dir path for saving the validated samples') 239 | 240 | args = parser.parse_args() 241 | 242 | # args.n_gpu = torch.cuda.device_count() 243 | current_run = get_random_name() 244 | 245 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run) 246 | os.makedirs(args.sample_folder, exist_ok=True) 247 | 248 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run) 249 | 250 | print(args, flush=True) 251 | 252 | print(f'Weight configuration used - perceptual loss weight : {PERCEPTUAL_LOSS_WEIGHT}, latent loss weight : {LATENT_LOSS_WEIGHT}') 253 | 254 | dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,)) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torchvision import utils 7 | 8 | 9 | def save_frames_as_video(frames, video_path, fps=30): 10 | height, width, layers = frames[0].shape 11 | 12 | video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) 13 | for frame in frames: 14 | video.write(cv2.cvtColor((frame*255).astype(np.uint8), cv2.COLOR_RGB2BGR)) 15 | 16 | cv2.destroyAllWindows() 17 | video.release() 18 | 19 | def save_image(data, saveas, video=False): 20 | utils.save_image( 21 | data, 22 | saveas, 23 | nrow=data.shape[0]//2, 24 | normalize=True, 25 | range=(-1, 1), 26 | ) 27 | 28 | 29 | def process_data(data, device, dataset): 30 | source, target, background, source_images, source_images_original = data 31 | 32 | img = torch.cat([source, background], axis=2).squeeze(0).to(device) 33 | source = source.squeeze(0) 34 | ground_truth = source_images.squeeze(0).to(device) 35 | 36 | S = source.shape[0] 37 | 38 | return img, S, ground_truth, source_images_original.squeeze(0).to(device) 39 | 40 | 41 | ''' 42 | Conv3D based temporal module is added before the quantization step 43 | LPIPS based perceptual loss is added 44 | ''' 45 | def get_facetranslation_latent_conv_perceptual(args, device): 46 | from TemporalAlignment.dataset import TemporalAlignmentDataset 47 | from models.vqvae_conv3d_latent import VQVAE 48 | from loss import VQLPIPS 49 | 50 | print(f'Inside conv3d applied before quantization along with perceptual loss') 51 | 52 | model = VQVAE(in_channel=3*2).to(device) 53 | vqlpips = VQLPIPS().to(device) 54 | 55 | train_dataset = TemporalAlignmentDataset( 56 | 'train', 30, 57 | color_jitter_type=args.colorjit, 58 | grayscale_required=args.gray) 59 | 60 | val_dataset = TemporalAlignmentDataset( 61 | 'val', 50, 62 | color_jitter_type=args.colorjit, 63 | cross_identity_required=args.crossid, 64 | grayscale_required=args.gray, 65 | custom_validation_required=args.custom_validation, 66 | validation_datapoints=args.validation_folder) 67 | 68 | try: 69 | train_loader = DataLoader( 70 | train_dataset, 71 | batch_size=1, 72 | shuffle=True, 73 | num_workers=2) 74 | except: 75 | train_loader = None 76 | 77 | val_loader = DataLoader( 78 | val_dataset, 79 | shuffle=False, 80 | batch_size=1, 81 | num_workers=2) 82 | 83 | return train_loader, val_loader, model, vqlpips 84 | 85 | 86 | ''' 87 | get the loaders and the models 88 | ''' 89 | def get_loaders_and_models(args, device): 90 | return get_facetranslation_latent_conv_perceptual(args, device) 
-------------------------------------------------------------------------------- /valid_videos.txt: -------------------------------------------------------------------------------- 1 | c3tTE3hzON8 2 | XGZMe5DsnRU 3 | iOqgDmaYdqA 4 | C93trWdL3Rg 5 | _J0XtWtloWE 6 | MTgyYqY3XTE 7 | TZpP05GHWnw 8 | y9lSgYW2ObQ 9 | ozEAkq3qSi0 10 | rYVGOvyAs-I 11 | 0uNEEfQKfwk 12 | WOEcMOEmbC8 13 | IGl4w4a_pcw 14 | wuS8Hq1Yhro 15 | --------------------------------------------------------------------------------