├── README.md
├── TemporalAlignment
│   ├── copy.sh
│   ├── dataset.py
│   ├── models
│   │   ├── mocogan_discriminator.py
│   │   ├── mocoganhd_content_disc.py
│   │   ├── mocoganhd_losses.py
│   │   ├── mocoganhd_models.py
│   │   ├── mocoganhd_video_disc.py
│   │   └── video_discriminator.py
│   ├── perturbations.py
│   └── ranges.py
├── bad_mp4s.json
├── bash_scripts
│   ├── train_videovqvae.sh
│   ├── train_videovqvae_perturbations.sh
│   └── train_videovqvae_resume_ckpt.sh
├── command.txt
├── config.py
├── copy.sh
├── datasets
│   ├── face_translation_videos3_utils.py
│   └── face_translation_videos3_utils_bb.py
├── disc_trainers
│   ├── train_vqvae_mocogan_disc.py
│   ├── train_vqvae_mocogan_disc_perceptual.py
│   ├── train_vqvae_mocoganhd_disc.py
│   ├── train_vqvae_mocoganhd_disc_single.py
│   └── train_vqvae_perceptual_mocoganhd_disc.py
├── distributed
│   ├── __init__.py
│   ├── distributed.py
│   └── launch.py
├── environment.yml
├── loss.py
├── models
│   ├── actnorm.py
│   ├── discriminator.py
│   ├── lpips.py
│   └── vqvae_conv3d_latent.py
├── preprocessing
│   ├── landmark_generation.py
│   └── preprocess_dataset.py
├── results
│   ├── inference_pipeline.gif
│   ├── training_pipeline.gif
│   ├── v2v_comparisons1.gif
│   ├── v2v_comparisons31.gif
│   ├── v2v_face_swapping.gif
│   ├── v2v_faceswapping_looped2.gif
│   ├── v2v_more_result.gif
│   ├── v2v_results1.gif
│   ├── v2v_results2.gif
│   ├── v2v_results3.gif
│   ├── v2v_results4.gif
│   ├── v2v_same_identity1.gif
│   ├── v2v_same_identity2.gif
│   └── v2v_same_identity3.gif
├── sample
│   └── .gitignore
├── scheduler.py
├── train_faceoff.py
├── train_faceoff_perceptual.py
├── utils.py
├── valid_folders_ft.json
└── valid_videos.txt
/README.md:
--------------------------------------------------------------------------------
1 | # FaceOff: A Video-to-Video Face Swapping System
2 |
3 | [Aditya Agarwal](http://skymanaditya1.github.io/)\*¹,
4 | [Bipasha Sen](https://bipashasen.github.io/)\*¹,
5 | [Rudrabha Mukhopadhyay](https://rudrabha.github.io/)¹,
6 | [Vinay Namboodiri](https://vinaypn.github.io/)²,
7 | [C V Jawahar](https://faculty.iiit.ac.in/~jawahar/)¹
8 | ¹International Institute of Information Technology, Hyderabad, ²University of Bath
9 |
10 | \*denotes equal contribution
11 |
12 | This is the official implementation of the paper "FaceOff: A Video-to-Video Face Swapping System" **published** at WACV 2023.
13 |
14 |
15 |
16 |
17 |
18 | For more results, information, and details visit our [**project page**](http://cvit.iiit.ac.in/research/projects/cvit-projects/faceoff) and read our [**paper**](https://openaccess.thecvf.com/content/WACV2023/papers/Agarwal_FaceOff_A_Video-to-Video_Face_Swapping_System_WACV_2023_paper.pdf). Following are some outputs from our network on the V2V Face Swapping Task.
19 |
20 |
21 | ## Results on same identity
22 |
23 |
24 |
30 |
31 | ## Getting started
32 |
33 | 1. Set up a conda environment with all dependencies using the following commands:
34 |
35 | ```
36 | conda env create -f environment.yml
37 | conda activate faceoff
38 | ```
39 |
40 | ## Training FaceOff
41 |
42 |
43 |
44 | The following command trains the V2V Face Swapping network. At set intervals, it also generates outputs on the validation dataset.
45 |
46 | ```
47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_faceoff_perceptual.py
48 | ```
49 | **Parameters**
50 | Below is the full list of parameters:
51 | 
52 | - ```--dist_url``` - port on which the experiment is run
53 | - ```--batch_size``` - batch size, default is ```32```
54 | - ```--size``` - image size, default is ```256```
55 | - ```--epoch``` - number of epochs to train for
56 | - ```--lr``` - learning rate, default is ```3e-4```
57 | - ```--sched``` - scheduler to use
58 | - ```--checkpoint_suffix``` - folder where checkpoints are saved; by default a random folder name is created
59 | - ```--validate_at``` - number of steps after which validation is performed, default is ```1024```
60 | - ```--ckpt``` - path to a pretrained checkpoint to resume from, default is ```None```
61 | - ```--test``` - whether to run the model in test mode
62 | - ```--gray``` - whether to test on grayscale inputs
63 | - ```--colorjit``` - type of color jitter to add; ```const```, ```random```, or ```empty``` are the possible options
64 | - ```--crossid``` - whether cross-identity swapping is performed during validation, default is ```True```
65 | - ```--custom_validation``` - used to test FaceOff on two custom videos, default is ```False```
66 | - ```--sample_folder``` - path where the validation videos are stored
67 | - ```--checkpoint_dir``` - directory where checkpoints are saved
68 | - ```--validation_folder``` - directory where validated samples are saved
69 |
70 | All the values can be left at their default values to train FaceOff in the vanilla setting. An example is given below:
71 |
72 | ```
73 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_faceoff_perceptual.py
74 | ```
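75 | 
76 | For reference, a hypothetical invocation that overrides a few of the defaults listed above might look like the following (the checkpoint and sample paths are placeholders):
77 | 
78 | ```
79 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_faceoff_perceptual.py \
80 |     --batch_size 8 --epoch 1000 --validate_at 1024 --colorjit const \
81 |     --checkpoint_dir /path/to/checkpoints \
82 |     --sample_folder /path/to/validation_samples
83 | ```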
75 |
76 | ## Checkpoints
77 | Pretrained checkpoints will be released soon!
78 |
79 | ## We would love your contributions to improve FaceOff
80 |
81 | FaceOff introduces the novel task of Video-to-Video Face Swapping that tackles a pressing challenge in the moviemaking industry: swapping the actor's face and expressions onto the face of their body double. Existing face-swapping methods swap only the identity of the source face without swapping the source (actor) expressions, which is undesirable as the starring actor's expressions are paramount. In video-to-video face swapping, we swap the source's facial expressions along with the identity onto the target's background and pose. Our method retains the face and expressions of the source actor and the pose and background information of the target actor. Currently, our model has a few limitations, listed below. ***We strongly encourage contributions and further research into these limitations.***
82 |
83 | 1. **Video Quality**: FaceOff is based on combining the temporal motion of the source and the target face videos in the reduced space of a vector-quantized variational autoencoder. Consequently, it suffers from a few quality issues, and the output resolution is limited to 256x256. Generating very high-quality samples is typically required in the movie-making industry.
84 |
85 | 2. **Temporal Jitter**: Although the 3D conv modules remove most of the temporal jitter in the blended output, a few noticeable jitters remain and require attention. The jitter occurs as we try to photo-realistically blend two different motions (source and target) in a temporally coherent manner.
86 |
87 | 3. **Extreme Poses**: FaceOff was designed to swap the actor's face and expressions onto the double's pose and background information. Consequently, the pose difference between the source and the target actors is expected not to be extreme. FaceOff can handle **roll**-related rotations in the 2D space; it would be worth investigating handling rotations due to **yaw** and **pitch** in the 3D space and rendering the output back to the 2D space.
88 |
89 |
90 | ## Thanks
91 |
92 | ### VQVAE2
93 |
94 | We would like to thank the authors of [VQVAE2](https://arxiv.org/pdf/1906.00446.pdf) ([https://github.com/rosinality/vq-vae-2-pytorch](https://github.com/rosinality/vq-vae-2-pytorch)) for releasing their code. We build on top of their codebase to perform V2V Face Swapping.
95 |
96 | ## Citation
97 | If you find our work useful in your research, please cite:
98 | ```
99 | @InProceedings{Agarwal_2023_WACV,
100 | author = {Agarwal, Aditya and Sen, Bipasha and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C. V.},
101 | title = {FaceOff: A Video-to-Video Face Swapping System},
102 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
103 | month = {January},
104 | year = {2023},
105 | pages = {3495-3504}
106 | }
107 | ```
108 |
109 | ## Contact
110 | If you have any questions, please feel free to email the authors.
111 |
112 | Aditya Agarwal: aditya.ag@research.iiit.ac.in
113 | Bipasha Sen: bipasha.sen@research.iiit.ac.in
114 | Rudrabha Mukhopadhyay: radrabha.m@research.iiit.ac.in
115 | Vinay Namboodiri: vpn22@bath.ac.uk
116 | C V Jawahar: jawahar@iiit.ac.in
--------------------------------------------------------------------------------
/TemporalAlignment/copy.sh:
--------------------------------------------------------------------------------
1 | cp -r /home2/bipasha31/python_scripts/CurrentWork/SLP/VQVAE2-Refact/TemporalAlignment/* .
--------------------------------------------------------------------------------
/TemporalAlignment/models/mocogan_discriminator.py:
--------------------------------------------------------------------------------
1 | # modified versions of mocogan's image and video discriminators
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.parallel
6 | import torch.utils.data
7 | from torch.autograd import Variable
8 |
9 | import numpy as np
10 |
11 | if torch.cuda.is_available():
12 | T = torch.cuda
13 | else:
14 | T = torch
15 |
16 | class Noise(nn.Module):
17 | def __init__(self, use_noise, sigma=0.2):
18 | super(Noise, self).__init__()
19 | self.use_noise = use_noise
20 | self.sigma = sigma
21 |
22 | def forward(self, x):
23 | if self.use_noise:
24 | return x + self.sigma * Variable(T.FloatTensor(x.size()).normal_(), requires_grad=False)
25 | return x
26 |
27 | # modified version of mocogan's image discriminator
28 | # input dimension -> batch_size x channels x 256 x 256, output dimension -> batch_size
29 | class ImageDiscriminator(nn.Module):
30 | def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None):
31 | super(ImageDiscriminator, self).__init__()
32 |
33 | self.use_noise = use_noise
34 |
35 | self.main = nn.Sequential(
36 | Noise(use_noise, sigma=noise_sigma),
37 | nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False),
38 | nn.LeakyReLU(0.2, inplace=True),
39 |
40 | Noise(use_noise, sigma=noise_sigma),
41 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
42 | nn.BatchNorm2d(ndf * 2),
43 | nn.LeakyReLU(0.2, inplace=True),
44 |
45 | Noise(use_noise, sigma=noise_sigma),
46 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
47 | nn.BatchNorm2d(ndf * 4),
48 | nn.LeakyReLU(0.2, inplace=True),
49 |
50 | Noise(use_noise, sigma=noise_sigma),
51 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
52 | nn.BatchNorm2d(ndf * 8),
53 | nn.LeakyReLU(0.2, inplace=True),
54 |
55 | Noise(use_noise, sigma=noise_sigma),
56 | nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
57 | nn.BatchNorm2d(ndf * 16),
58 | nn.LeakyReLU(0.2, inplace=True),
59 |
60 | Noise(use_noise, sigma=noise_sigma),
61 | nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
62 | nn.BatchNorm2d(ndf * 32),
63 | nn.LeakyReLU(0.2, inplace=True),
64 |
65 | nn.Conv2d(ndf * 32, 1, 4, 1, 0, bias=False),
66 | )
67 |
68 | def forward(self, input):
69 | h = self.main(input).squeeze()
70 | return h, None
71 |
72 | # modified version of mocogan's patch image discriminator
73 | # input dimension -> batch_size x channels x 256 x 256, output dimension -> batch_size x 4 x 4
74 | class PatchImageDiscriminator(nn.Module):
75 | def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None):
76 | super(PatchImageDiscriminator, self).__init__()
77 |
78 | self.use_noise = use_noise
79 |
80 | self.main = nn.Sequential(
81 | Noise(use_noise, sigma=noise_sigma),
82 | nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False),
83 | nn.LeakyReLU(0.2, inplace=True),
84 |
85 | Noise(use_noise, sigma=noise_sigma),
86 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
87 | nn.BatchNorm2d(ndf * 2),
88 | nn.LeakyReLU(0.2, inplace=True),
89 |
90 | Noise(use_noise, sigma=noise_sigma),
91 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
92 | nn.BatchNorm2d(ndf * 4),
93 | nn.LeakyReLU(0.2, inplace=True),
94 |
95 | Noise(use_noise, sigma=noise_sigma),
96 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
97 | nn.BatchNorm2d(ndf * 8),
98 | nn.LeakyReLU(0.2, inplace=True),
99 |
100 | Noise(use_noise, sigma=noise_sigma),
101 | nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
102 | nn.BatchNorm2d(ndf * 16),
103 | nn.LeakyReLU(0.2, inplace=True),
104 |
105 | Noise(use_noise, sigma=noise_sigma),
106 | nn.Conv2d(ndf * 16, 1, 4, 2, 1, bias=False),
107 | )
108 |
109 | def forward(self, input):
110 | h = self.main(input).squeeze()
111 | return h, None
112 |
113 | # modified version of mocogan's video discriminator
114 | # input dimension -> batch_size x channels x 16 x 256 x 256
115 | # output dimension -> batch_size
116 | class VideoDiscriminator(nn.Module):
117 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64):
118 | super(VideoDiscriminator, self).__init__()
119 |
120 | self.n_channels = n_channels
121 | self.n_output_neurons = n_output_neurons
122 | self.use_noise = use_noise
123 | self.bn_use_gamma = bn_use_gamma
124 |
125 | self.main = nn.Sequential(
126 | Noise(use_noise, sigma=noise_sigma),
127 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
128 | nn.LeakyReLU(0.2, inplace=True),
129 |
130 | Noise(use_noise, sigma=noise_sigma),
131 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
132 | nn.BatchNorm3d(ndf * 2),
133 | nn.LeakyReLU(0.2, inplace=True),
134 |
135 | Noise(use_noise, sigma=noise_sigma),
136 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
137 | nn.BatchNorm3d(ndf * 4),
138 | nn.LeakyReLU(0.2, inplace=True),
139 |
140 | Noise(use_noise, sigma=noise_sigma),
141 | nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
142 | nn.BatchNorm3d(ndf * 8),
143 | nn.LeakyReLU(0.2, inplace=True),
144 |
145 | Noise(use_noise, sigma=noise_sigma),
146 | nn.Conv3d(ndf * 8, ndf * 16, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
147 | nn.BatchNorm3d(ndf * 16),
148 | nn.LeakyReLU(0.2, inplace=True),
149 |
150 | Noise(use_noise, sigma=noise_sigma),
151 | nn.Conv3d(ndf * 16, ndf * 32, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
152 | nn.BatchNorm3d(ndf * 32),
153 | nn.LeakyReLU(0.2, inplace=True),
154 |
155 | nn.Conv3d(ndf * 32, n_output_neurons, (1, 4, 4), 1, 0, bias=False), # cannot apply kernel size of 4 on input image of size 1
156 | )
157 |
158 | def forward(self, input):
159 | h = self.main(input).squeeze()
160 |
161 | return h, None
162 |
163 | # modified version of mocogan's patch video discriminator
164 | # input dimension -> batch_size x channels x 16 x 256 x 256
165 | # output dimension -> batch_size x 4 x 4 x 4
166 | class PatchVideoDiscriminator(nn.Module):
167 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64):
168 | super(PatchVideoDiscriminator, self).__init__()
169 |
170 | self.n_channels = n_channels
171 | self.n_output_neurons = n_output_neurons
172 | self.use_noise = use_noise
173 | self.bn_use_gamma = bn_use_gamma
174 |
175 | self.main = nn.Sequential(
176 | Noise(use_noise, sigma=noise_sigma),
177 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
178 | nn.LeakyReLU(0.2, inplace=True),
179 |
180 | Noise(use_noise, sigma=noise_sigma),
181 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
182 | nn.BatchNorm3d(ndf * 2),
183 | nn.LeakyReLU(0.2, inplace=True),
184 |
185 | Noise(use_noise, sigma=noise_sigma),
186 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
187 | nn.BatchNorm3d(ndf * 4),
188 | nn.LeakyReLU(0.2, inplace=True),
189 |
190 | Noise(use_noise, sigma=noise_sigma),
191 | nn.Conv3d(ndf * 4, ndf * 8, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
192 | nn.BatchNorm3d(ndf * 8),
193 | nn.LeakyReLU(0.2, inplace=True),
194 |
195 | Noise(use_noise, sigma=noise_sigma),
196 | nn.Conv3d(ndf * 8, ndf * 16, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
197 | nn.BatchNorm3d(ndf * 16),
198 | nn.LeakyReLU(0.2, inplace=True),
199 |
200 | nn.Conv3d(ndf * 16, 1, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
201 | )
202 |
203 | def forward(self, input):
204 | h = self.main(input).squeeze()
205 |
206 | return h, None
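207 | 
208 | # Minimal shape check (a sketch, not part of the original training pipeline): instantiate each
209 | # discriminator with assumed 3-channel, 256x256 (and 16-frame) inputs and print the output shapes
210 | # documented in the comments above.
211 | if __name__ == '__main__':
212 |     with torch.no_grad():
213 |         image = torch.randn(2, 3, 256, 256)
214 |         video = torch.randn(2, 3, 16, 256, 256)
215 | 
216 |         print(ImageDiscriminator(3)(image)[0].shape)       # per-image score, (batch,)
217 |         print(PatchImageDiscriminator(3)(image)[0].shape)  # patch scores, (batch, 4, 4)
218 |         print(VideoDiscriminator(3)(video)[0].shape)       # per-video score, (batch,)
219 |         print(PatchVideoDiscriminator(3)(video)[0].shape)  # spatio-temporal patch scores, (batch, 4, 4, 4)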
--------------------------------------------------------------------------------
/TemporalAlignment/models/mocoganhd_content_disc.py:
--------------------------------------------------------------------------------
1 | # This is the mocoganhd content discriminator code
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import functools
6 |
7 | # this is the image discriminator - for classifying the content as real/fake
8 | class ModelD_img(nn.Module):
9 | def __init__(self, nc, norm_D_3d, num_D, lr):
10 | super(ModelD_img, self).__init__()
11 | nc = nc * 2
12 |
13 | self.netD = MultiscaleDiscriminator(input_nc=nc,
14 | norm_layer=get_norm_layer(
15 | norm_D_3d),
16 | num_D=num_D)
17 | self.netD.apply(weights_init)
18 |
19 | self.optim = torch.optim.Adam(self.netD.parameters(),
20 | lr=lr,
21 | betas=(0.5, 0.999))
22 |
23 | def forward(self, x):
24 | return self.netD.forward(x)
25 |
26 |
27 | def weights_init(m):
28 | classname = m.__class__.__name__
29 | if classname.find('Conv') != -1 and hasattr(m, 'weight'):
30 | m.weight.data.normal_(0.0, 0.02)
31 | elif classname.find('BatchNorm2d') != -1:
32 | m.weight.data.normal_(1.0, 0.02)
33 | m.bias.data.fill_(0)
34 |
35 |
36 | def get_norm_layer(norm_type='instance'):
37 | if norm_type == 'batch':
38 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
39 | elif norm_type == 'instance':
40 | norm_layer = functools.partial(nn.InstanceNorm2d,
41 | affine=False,
42 | track_running_stats=True)
43 | else:
44 | raise NotImplementedError('normalization layer [%s] is not found' %
45 | norm_type)
46 | return norm_layer
47 |
48 |
49 | class MultiscaleDiscriminator(nn.Module):
50 | def __init__(self,
51 | input_nc,
52 | ndf=64,
53 | n_layers=3,
54 | n_frames=16,
55 | norm_layer=nn.InstanceNorm2d,
56 | num_D=2,
57 | getIntermFeat=True):
58 | super(MultiscaleDiscriminator, self).__init__()
59 | self.num_D = num_D
60 | self.n_layers = n_layers
61 | self.getIntermFeat = getIntermFeat
62 | ndf_max = 64
63 |
64 | for i in range(num_D):
65 | netD = NLayerDiscriminator(
66 | input_nc, min(ndf_max, ndf * (2**(num_D - 1 - i))), n_layers,
67 | norm_layer, getIntermFeat)
68 | if getIntermFeat:
69 | for j in range(n_layers + 2):
70 | setattr(self, 'scale' + str(i) + '_layer' + str(j),
71 | getattr(netD, 'model' + str(j)))
72 | else:
73 | setattr(self, 'layer' + str(i), netD.model)
74 | self.downsample = nn.AvgPool2d(3,
75 | stride=2,
76 | padding=[1, 1],
77 | count_include_pad=False)
78 |
79 | def singleD_forward(self, model, input):
80 | if self.getIntermFeat:
81 | result = [input]
82 | for i in range(len(model)):
83 | result.append(model[i](result[-1]))
84 | return result[1:]
85 | else:
86 | return [model(input)]
87 |
88 | def forward(self, input):
89 | num_D = self.num_D
90 | result = []
91 | input_downsampled = input
92 | for i in range(num_D):
93 | if self.getIntermFeat:
94 | model = [
95 | getattr(self,
96 | 'scale' + str(num_D - 1 - i) + '_layer' + str(j))
97 | for j in range(self.n_layers + 2)
98 | ]
99 | else:
100 | model = getattr(self, 'layer' + str(num_D - 1 - i))
101 |             # each per-scale sub-discriminator (and the input) must be on the same device before calling forward
102 |             result.append(self.singleD_forward(model, input_downsampled))
103 | if i != (num_D - 1):
104 | input_downsampled = self.downsample(input_downsampled)
105 | return result
106 |
107 |
108 | class NLayerDiscriminator(nn.Module):
109 | def __init__(self,
110 | input_nc,
111 | ndf=64,
112 | n_layers=3,
113 | norm_layer=nn.InstanceNorm2d,
114 | getIntermFeat=True):
115 | super(NLayerDiscriminator, self).__init__()
116 | self.getIntermFeat = getIntermFeat
117 | self.n_layers = n_layers
118 |
119 | kw = 4
120 | padw = int(np.ceil((kw - 1.0) / 2))
121 | sequence = [[
122 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
123 | nn.LeakyReLU(0.2, True)
124 | ]]
125 |
126 | nf = ndf
127 | for n in range(1, n_layers):
128 | nf_prev = nf
129 | nf = min(nf * 2, 512)
130 | sequence += [[
131 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
132 | norm_layer(nf),
133 | nn.LeakyReLU(0.2, True)
134 | ]]
135 |
136 | nf_prev = nf
137 | nf = min(nf * 2, 512)
138 | sequence += [[
139 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
140 | norm_layer(nf),
141 | nn.LeakyReLU(0.2, True)
142 | ]]
143 |
144 | sequence += [[
145 | nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)
146 | ]]
147 |
148 | if getIntermFeat:
149 | for n in range(len(sequence)):
150 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
151 | else:
152 | sequence_stream = []
153 | for n in range(len(sequence)):
154 | sequence_stream += sequence[n]
155 | self.model = nn.Sequential(*sequence_stream)
156 |
157 | def forward(self, input):
158 | if self.getIntermFeat:
159 | res = [input]
160 | for n in range(self.n_layers + 2):
161 | model = getattr(self, 'model' + str(n))
162 | res.append(model(res[-1]))
163 | return res[1:]
164 | else:
165 | return self.model(input)
166 |
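167 | # Minimal usage sketch (not part of the original training pipeline): nc, num_D and the learning
168 | # rate below are assumed values. ModelD_img doubles nc internally, so a 3-channel setup expects a
169 | # 6-channel input (e.g. two RGB frames concatenated channel-wise).
170 | if __name__ == '__main__':
171 |     with torch.no_grad():
172 |         disc = ModelD_img(nc=3, norm_D_3d='instance', num_D=2, lr=2e-4)
173 |         scores = disc(torch.randn(2, 6, 256, 256))
174 |         # one list of intermediate features per scale; the last entry of each is the patch score map
175 |         print(len(scores), scores[0][-1].shape, scores[1][-1].shape)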
--------------------------------------------------------------------------------
/TemporalAlignment/models/mocoganhd_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.
3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,
4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.
5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,
6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.
7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.
8 | """
9 | import sys
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.nn as nn
14 |
15 |
16 | def loss_hinge_dis(dis_fake, dis_real):
17 | loss_real = torch.mean(F.relu(1. - dis_real))
18 | loss_fake = torch.mean(F.relu(1. + dis_fake))
19 | return loss_real, loss_fake
20 |
21 |
22 | def loss_hinge_gen(dis_fake):
23 | loss = -torch.mean(dis_fake)
24 | return loss
25 |
26 |
27 | def compute_gradient_penalty_T(real_B, fake_B, modelD, opt):
28 | alpha = torch.rand(list(real_B.size())[0], 1, 1, 1, 1)
29 | alpha = alpha.expand(real_B.size()).cuda(real_B.get_device())
30 |
31 | interpolates = alpha * real_B.data + (1 - alpha) * fake_B.data
32 | interpolates = torch.tensor(interpolates, requires_grad=True)
33 |
34 | pred_interpolates = modelD(interpolates)
35 |
36 | gradient_penalty = 0
37 | if isinstance(pred_interpolates, list):
38 | for cur_pred in pred_interpolates:
39 | gradients = torch.autograd.grad(outputs=cur_pred[-1],
40 | inputs=interpolates,
41 | grad_outputs=torch.ones(
42 | cur_pred[-1].size()).cuda(
43 | real_B.get_device()),
44 | create_graph=True,
45 | retain_graph=True,
46 | only_inputs=True)[0]
47 |
48 | gradient_penalty += ((gradients.norm(2, dim=1) - 1)**2).mean()
49 | else:
50 | sys.exit('output is not list!')
51 |
52 | gradient_penalty = (gradient_penalty / opt.num_D) * 10
53 | return gradient_penalty
54 |
55 |
56 | class GANLoss(nn.Module):
57 | def __init__(self,
58 | use_lsgan=True,
59 | target_real_label=1.0,
60 | target_fake_label=0.0,
61 | tensor=torch.FloatTensor):
62 | super(GANLoss, self).__init__()
63 | self.real_label = target_real_label
64 | self.fake_label = target_fake_label
65 | self.real_label_var = None
66 | self.fake_label_var = None
67 | self.Tensor = tensor
68 | if use_lsgan:
69 | self.loss = nn.MSELoss()
70 | else:
71 | self.loss = nn.BCELoss()
72 |
73 | def get_target_tensor(self, input, target_is_real):
74 | target_tensor = None
75 | if target_is_real:
76 | create_label = ((self.real_label_var is None)
77 | or (self.real_label_var.numel() != input.numel()))
78 | if create_label:
79 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
80 | self.real_label_var = torch.tensor(real_tensor,
81 | requires_grad=False)
82 | target_tensor = self.real_label_var
83 | else:
84 | create_label = ((self.fake_label_var is None)
85 | or (self.fake_label_var.numel() != input.numel()))
86 | if create_label:
87 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
88 | self.fake_label_var = torch.tensor(fake_tensor,
89 | requires_grad=False)
90 | target_tensor = self.fake_label_var
91 |
92 | if input.is_cuda:
93 | target_tensor = target_tensor.cuda()
94 | return target_tensor
95 |
96 | def __call__(self, input, target_is_real):
97 | if isinstance(input[0], list):
98 | loss = 0
99 | for input_i in input:
100 | pred = input_i[-1]
101 | target_tensor = self.get_target_tensor(pred, target_is_real)
102 | loss += self.loss(pred, target_tensor)
103 | return loss
104 | else:
105 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
106 | return self.loss(input[-1], target_tensor)
107 |
108 |
109 | class Relativistic_Average_LSGAN(GANLoss):
110 | '''
111 | Relativistic average LSGAN
112 | '''
113 |
114 | def __call__(self, input_1, input_2, target_is_real):
115 | if isinstance(input_1[0], list):
116 | loss = 0
117 | for input_i, _input_i in zip(input_1, input_2):
118 | pred = input_i[-1]
119 | _pred = _input_i[-1]
120 | target_tensor = self.get_target_tensor(pred, target_is_real)
121 | loss += self.loss(pred - torch.mean(_pred), target_tensor)
122 | return loss
123 | else:
124 | target_tensor = self.get_target_tensor(input_1[-1], target_is_real)
125 | return self.loss(input_1[-1] - torch.mean(input_2[-1]),
126 | target_tensor)
127 |
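128 | # Tiny worked example (illustrative only): with discriminator outputs of +2 on real samples and -2
129 | # on fake samples, both hinge terms are zero, i.e. a confidently correct discriminator is not penalized.
130 | if __name__ == '__main__':
131 |     dis_real = torch.tensor([2.0, 2.0])
132 |     dis_fake = torch.tensor([-2.0, -2.0])
133 |     loss_real, loss_fake = loss_hinge_dis(dis_fake, dis_real)
134 |     print(loss_real.item(), loss_fake.item())  # 0.0 0.0
135 |     print(loss_hinge_gen(dis_fake).item())     # 2.0 -- the generator is penalized for these fakes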
--------------------------------------------------------------------------------
/TemporalAlignment/models/mocoganhd_models.py:
--------------------------------------------------------------------------------
1 | # class that has additional functionalities provided in the mocogan_hd repository
2 | """
3 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.
4 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,
5 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.
6 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,
7 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.
8 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.
9 | """
10 | import os
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 |
16 |
17 | def load_checkpoints(path, gpu):
18 | if gpu is None:
19 | ckpt = torch.load(path)
20 | else:
21 | loc = 'cuda:{}'.format(gpu)
22 | ckpt = torch.load(path, map_location=loc)
23 | return ckpt
24 |
25 |
26 | def model_to_gpu(model, isTrain, gpu):
27 | if isTrain:
28 | if gpu is not None:
29 | model.cuda(gpu)
30 | model = DDP(model,
31 | device_ids=[gpu],
32 | find_unused_parameters=True)
33 | else:
34 | model.cuda()
35 | model = DDP(model, find_unused_parameters=True)
36 | else:
37 | model.cuda()
38 | model = nn.DataParallel(model)
39 |
40 | return model
--------------------------------------------------------------------------------
/TemporalAlignment/models/mocoganhd_video_disc.py:
--------------------------------------------------------------------------------
1 | # This is the mocoganhd motion (video) discriminator code
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import functools
6 |
7 | # this is the discriminator for classifying the real/fake video motion (video code)
8 | class ModelD_3d(nn.Module):
9 | def __init__(self, nc, norm_D_3d, num_D, lr, cross_domain, n_frames_G):
10 | super(ModelD_3d, self).__init__()
11 | if cross_domain:
12 | nc = nc
13 | n_frames_G = n_frames_G
14 | else:
15 | nc = nc * 2
16 | n_frames_G = n_frames_G - 1
17 |
18 | self.netD = MultiscaleDiscriminator(input_nc=nc,
19 | n_frames=n_frames_G,
20 | norm_layer=get_norm_layer(
21 | norm_D_3d),
22 | num_D=num_D)
23 | self.netD.apply(weights_init)
24 |
25 | self.optim = torch.optim.Adam(self.netD.parameters(),
26 | lr=lr,
27 | betas=(0.5, 0.999))
28 |
29 | def forward(self, x):
30 | return self.netD.forward(x)
31 |
32 |
33 | def weights_init(m):
34 | classname = m.__class__.__name__
35 | if classname.find('Conv') != -1 and hasattr(m, 'weight'):
36 | m.weight.data.normal_(0.0, 0.02)
37 | elif classname.find('BatchNorm3d') != -1:
38 | m.weight.data.normal_(1.0, 0.02)
39 | m.bias.data.fill_(0)
40 |
41 |
42 | def get_norm_layer(norm_type='instance'):
43 | if norm_type == 'batch':
44 | norm_layer = functools.partial(nn.BatchNorm3d, affine=True)
45 | elif norm_type == 'instance':
46 | norm_layer = functools.partial(nn.InstanceNorm3d,
47 | affine=False,
48 | track_running_stats=True)
49 | else:
50 | raise NotImplementedError('normalization layer [%s] is not found' %
51 | norm_type)
52 | return norm_layer
53 |
54 |
55 | class MultiscaleDiscriminator(nn.Module):
56 | def __init__(self,
57 | input_nc,
58 | ndf=64,
59 | n_layers=3,
60 | n_frames=16,
61 | norm_layer=nn.InstanceNorm3d,
62 | num_D=2,
63 | getIntermFeat=True):
64 | super(MultiscaleDiscriminator, self).__init__()
65 | self.num_D = num_D
66 | self.n_layers = n_layers
67 | self.getIntermFeat = getIntermFeat
68 | ndf_max = 64
69 |
70 | for i in range(num_D):
71 | netD = NLayerDiscriminator(
72 | input_nc, min(ndf_max, ndf * (2**(num_D - 1 - i))), n_layers,
73 | norm_layer, getIntermFeat)
74 | if getIntermFeat:
75 | for j in range(n_layers + 2):
76 | setattr(self, 'scale' + str(i) + '_layer' + str(j),
77 | getattr(netD, 'model' + str(j)))
78 | else:
79 | setattr(self, 'layer' + str(i), netD.model)
80 | if n_frames > 16:
81 | self.downsample = nn.AvgPool3d(3,
82 | stride=2,
83 | padding=[1, 1, 1],
84 | count_include_pad=False)
85 | else:
86 | self.downsample = nn.AvgPool3d(3,
87 | stride=[1, 2, 2],
88 | padding=[1, 1, 1],
89 | count_include_pad=False)
90 |
91 | def singleD_forward(self, model, input):
92 | if self.getIntermFeat:
93 | result = [input]
94 | for i in range(len(model)):
95 | result.append(model[i](result[-1]))
96 | return result[1:]
97 | else:
98 | return [model(input)]
99 |
100 | def forward(self, input):
101 | num_D = self.num_D
102 | result = []
103 | input_downsampled = input
104 | for i in range(num_D):
105 | if self.getIntermFeat:
106 | model = [
107 | getattr(self,
108 | 'scale' + str(num_D - 1 - i) + '_layer' + str(j))
109 | for j in range(self.n_layers + 2)
110 | ]
111 | else:
112 | model = getattr(self, 'layer' + str(num_D - 1 - i))
113 | result.append(self.singleD_forward(model, input_downsampled))
114 | if i != (num_D - 1):
115 | input_downsampled = self.downsample(input_downsampled)
116 | return result
117 |
118 |
119 | class NLayerDiscriminator(nn.Module):
120 | def __init__(self,
121 | input_nc,
122 | ndf=64,
123 | n_layers=3,
124 | norm_layer=nn.InstanceNorm3d,
125 | getIntermFeat=True):
126 | super(NLayerDiscriminator, self).__init__()
127 | self.getIntermFeat = getIntermFeat
128 | self.n_layers = n_layers
129 |
130 | kw = 4
131 | padw = int(np.ceil((kw - 1.0) / 2))
132 | sequence = [[
133 | nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
134 | nn.LeakyReLU(0.2, True)
135 | ]]
136 |
137 | nf = ndf
138 | for n in range(1, n_layers):
139 | nf_prev = nf
140 | nf = min(nf * 2, 512)
141 | sequence += [[
142 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
143 | norm_layer(nf),
144 | nn.LeakyReLU(0.2, True)
145 | ]]
146 |
147 | nf_prev = nf
148 | nf = min(nf * 2, 512)
149 | sequence += [[
150 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
151 | norm_layer(nf),
152 | nn.LeakyReLU(0.2, True)
153 | ]]
154 |
155 | sequence += [[
156 | nn.Conv3d(nf, 1, kernel_size=kw, stride=1, padding=padw)
157 | ]]
158 |
159 | if getIntermFeat:
160 | for n in range(len(sequence)):
161 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
162 | else:
163 | sequence_stream = []
164 | for n in range(len(sequence)):
165 | sequence_stream += sequence[n]
166 | self.model = nn.Sequential(*sequence_stream)
167 |
168 | def forward(self, input):
169 | if self.getIntermFeat:
170 | res = [input]
171 | for n in range(self.n_layers + 2):
172 | model = getattr(self, 'model' + str(n))
173 | res.append(model(res[-1]))
174 | return res[1:]
175 | else:
176 | return self.model(input)
177 |
--------------------------------------------------------------------------------
/TemporalAlignment/models/video_discriminator.py:
--------------------------------------------------------------------------------
1 | # This class implements the Video Discriminator
2 | # Discriminator based on MocoGAN's discriminator
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.parallel
6 | import torch.utils.data
7 | from torch.autograd import Variable
8 |
9 | import numpy as np
10 |
11 | class Noise(nn.Module):
12 | def __init__(self, use_noise, sigma=0.2):
13 | super(Noise, self).__init__()
14 | self.use_noise = use_noise
15 | self.sigma = sigma
16 |
17 | def forward(self, x):
18 | if self.use_noise:
19 |             return x + self.sigma * torch.randn_like(x)  # randn_like keeps the noise on x's device/dtype; the cuda/cpu alias T is not defined in this file
20 | return x
21 |
22 | class VideoDiscriminator(nn.Module):
23 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64):
24 | super(VideoDiscriminator, self).__init__()
25 |
26 | self.n_channels = n_channels
27 | self.n_output_neurons = n_output_neurons
28 | self.use_noise = use_noise
29 | self.bn_use_gamma = bn_use_gamma
30 |
31 | self.main = nn.Sequential(
32 | Noise(use_noise, sigma=noise_sigma),
33 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
34 | nn.LeakyReLU(0.2, inplace=True),
35 |
36 | Noise(use_noise, sigma=noise_sigma),
37 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
38 | nn.BatchNorm3d(ndf * 2),
39 | nn.LeakyReLU(0.2, inplace=True),
40 |
41 | Noise(use_noise, sigma=noise_sigma),
42 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
43 | nn.BatchNorm3d(ndf * 4),
44 | nn.LeakyReLU(0.2, inplace=True),
45 |
46 | Noise(use_noise, sigma=noise_sigma),
47 | nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
48 | nn.BatchNorm3d(ndf * 8),
49 | nn.LeakyReLU(0.2, inplace=True),
50 |
51 | nn.Conv3d(ndf * 8, n_output_neurons, 4, 1, 0, bias=False),
52 | )
53 |
54 | self.linear = nn.Linear(13*13, 1) # add linear layer to deal with different resolution
55 |
56 | def forward(self, input):
57 | h = self.main(input).view(-1)
58 | h = self.linear(h)
59 | return h
--------------------------------------------------------------------------------
/TemporalAlignment/perturbations.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file generates the paired training data by introducing motion errors (aka perturbations) in the source face.
3 | The imperfectly blended face is a 6-channel input that the model learns to denoise in a temporally coherent fashion.
4 | '''
5 |
6 | import sys
7 | import os
8 | import os.path as osp
9 | import random
10 | from glob import glob
11 | from tqdm import tqdm
12 | from enum import Enum
13 |
14 | from TemporalAlignment.ranges import *
15 |
16 | import torch
17 | import torchvision.transforms as transforms
18 |
19 | import matplotlib.pyplot as plt
20 | from PIL import Image
21 | from wand.image import Image as WandImage
22 |
23 | import cv2
24 | import numpy as np
25 |
26 | '''
27 | The following perturbation functions are added --
28 | 1. Affine transformations - translation, rotation, resize
29 | 2. Color transformations - change hue, saturation etc
30 | 3. Non-linear transformations aka distortions - arc, barrel, barrel_inverse
31 | '''
32 |
33 | '''
34 | types of distortions
35 | '''
36 | class Distortion(Enum):
37 |     ARC = 1  # plain ints (a trailing comma would make the value a tuple and break the comparisons in distort_image)
38 |     BARREL = 2
39 |     BARREL_INVERSE = 3
40 |
41 |
42 | '''
43 | horizontal image translation -- affine transformation
44 | '''
45 | def translate_horizontal(x, image):
46 | M = np.float32([
47 | [1, 0, x],
48 | [0, 1, 0]
49 | ])
50 |
51 | shifted = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]))
52 | return shifted
53 |
54 | '''
55 | vertical image translation -- affine transformation
56 | '''
57 | def translate_vertical(y, image):
58 | M = np.float32([
59 | [1, 0, 0],
60 | [0, 1, y]
61 | ])
62 |
63 | shifted = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]))
64 |
65 | return shifted
66 |
67 | '''
68 | image rotation (clockwise or anti-clockwise) -- affine transformation
69 | '''
70 | def rotate_image(rotation, image, center=None):
71 | # rotates the image by the center of the image
72 | h, w = image.shape[:2]
73 |
74 | if center is None:
75 | cX, cY = (w//2, h//2)
76 | M = cv2.getRotationMatrix2D((cX, cY), rotation, 1.0)
77 | else:
78 | M = cv2.getRotationMatrix2D(center, rotation, 1.0)
79 |
80 | rotated = cv2.warpAffine(image, M, (w, h))
81 |
82 | return rotated
83 |
84 | '''
85 | image resize operations (zoom in and zoom out) -- not used currently
86 | '''
87 | def resize_image(magnification, image):
88 | res = cv2.resize(image, None, fx=magnification, fy=magnification, interpolation=cv2.INTER_CUBIC)
89 | h, w = image.shape[:2]
90 |
91 | if magnification >= 1:
92 | cX, cY = res.shape[1] // 2, res.shape[0] // 2
93 | left_index = cX - w // 2
94 | upper_index = cY - h // 2
95 | modified_image = res[upper_index : upper_index + h, left_index : left_index + w]
96 | else:
97 | modified_image = np.zeros((image.shape), dtype=np.uint8)
98 | hs, ws = res.shape[:2]
99 | difference_h = h - hs
100 | difference_w = w - ws
101 | left_index = difference_w // 2
102 | upper_index = difference_h // 2
103 | modified_image[upper_index : upper_index + hs, left_index : left_index + ws] = res
104 |
105 | return modified_image
106 |
107 | '''
108 | applies shear transformation along both the axes -- not used currently
109 | '''
110 | def shear_image(shear, image):
111 | shear_x, shear_y = shear, shear
112 | M = np.float32([
113 | [1, shear_x, 0],
114 | [shear_y, 1, 0]
115 | ])
116 |
117 | sheared = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]))
118 |
119 | return sheared
120 |
121 | '''
122 | flips the image in the horizontal direction -- not used currently
123 | '''
124 | def image_flip(flip_code, image):
125 | flipped_image = cv2.flip(image, int(flip_code))
126 | return flipped_image
127 |
128 | '''
129 | applies non-linear transformation (distortion) to the image
130 | '''
131 | def distort_image(distortion_type, image):
132 | img_file = np.array(image)
133 | with WandImage.from_array(img_file) as img:
134 | '''
135 | three type of image distortions are supported - i) arc, ii) barrel, and iii) inverse-barrel
136 | '''
137 | if distortion_type == Distortion.ARC.value:
138 | distortion_angle = random.randint(0, 30)
139 | img.distort('arc', (distortion_angle,))
140 | img.resize(image.shape[0], image.shape[1])
141 |
142 | opencv_image = np.array(img)
143 |
144 | elif distortion_type == Distortion.BARREL.value:
145 | a = random.randint(0, 10)/10
146 | b = random.randint(2, 7)/10
147 | c = random.randint(0, 5)/10
148 | d = random.randint(10, 10)/10
149 | args = (a, b, c, d)
150 | img.distort('barrel', (args))
151 | img.resize(image.shape[0], image.shape[1])
152 |
153 | opencv_image = np.array(img)
154 |
155 | else:
156 | b = random.randint(0, 2)/10
157 | c = random.randint(-5, 0)/10
158 | d = random.randint(10, 10)/10
159 | args = (0.0, b, c, d)
160 | img.distort('barrel_inverse', (args))
161 | img.resize(image.shape[0], image.shape[1])
162 |
163 | opencv_image = np.array(img)
164 |
165 | return opencv_image
166 |
167 | '''
168 | blend the perturbed image and the masked image
169 | '''
170 | def combine_images(face_mask, perturbed_image, generate_mask=True):
171 | image_masked = face_mask.copy()
172 | if generate_mask:
173 | mask = perturbed_image[..., 0] != 0
174 | image_masked[mask] = 0
175 |
176 | combined_image = image_masked + perturbed_image
177 |
178 | return combined_image
179 |
180 | '''
181 | finds the center of the eyes from the face landmarks
182 | '''
183 | def find_eye_center(shape):
184 | # landmark coordinates corresponding to the eyes
185 | lStart, lEnd = 36, 41
186 | rStart, rEnd = 42, 47
187 |
188 | # landmarks for the left and right eyes
189 | leftEyePoints = shape[lStart:lEnd]
190 | rightEyePoints = shape[rStart:rEnd]
191 |
192 | # compute the center of mass for each of the eyes
193 | leftEyeCenter = leftEyePoints.mean(axis=0).astype("int")
194 | rightEyeCenter = rightEyePoints.mean(axis=0).astype("int")
195 |
196 | # compute the angle between the eye centroids
197 | dY = rightEyeCenter[1] - leftEyeCenter[1]
198 | dX = rightEyeCenter[0] - leftEyeCenter[0]
199 | angle = np.degrees(np.arctan2(dY, dX))
200 |
201 |     return ((leftEyeCenter[0] + rightEyeCenter[0]) / 2,
202 |             (leftEyeCenter[1] + rightEyeCenter[1]) / 2)  # used as the rotation pivot in perturb_image_composite
203 |
204 | '''
205 | applies composite perturbations to a single image
206 | the perturbations to apply are selected randomly
207 | '''
208 | def perturb_image_composite(face_image, landmark):
209 | perturbation_functions = [
210 | translate_horizontal,
211 | translate_vertical,
212 | rotate_image,
213 | resize_image,
214 | # shear_image,
215 | # image_flip,
216 | distort_image,
217 | ]
218 |
219 | perturbation_function_map = {
220 | translate_horizontal : [-translation_range, translation_range, 1],
221 | translate_vertical : [-translation_range, translation_range, 1],
222 | rotate_image : [-rotation_range, rotation_range, 1],
223 | resize_image : [scale_ranges[0], scale_ranges[1], 100],
224 | # shear_image : [-10, 10, 100],
225 | # image_flip : [1, 1, 1],
226 | distort_image : [0, len(Distortion), 1],
227 | }
228 |
229 | gt_transformations = {
230 | 'translate_horizontal': 0,
231 | 'translate_vertical': 0,
232 | 'rotate_image': 0
233 | }
234 |
235 | eyes_center = find_eye_center(landmark)
236 |
237 | # maintains the perturbations to apply to the image
238 | composite_perturbations = list()
239 |     # ensures at least one perturbation is produced
240 | while len(composite_perturbations) == 0:
241 | for i, perturbation_function in enumerate(perturbation_functions):
242 | if random.randint(0, 1):
243 | composite_perturbations.append(perturbation_function)
244 |
245 | # print(f'Perturbations applied : {composite_perturbations}', flush=True)
246 |
247 | for perturbation_function in composite_perturbations:
248 | perturbation_map = perturbation_function_map[perturbation_function]
249 | perturbation_value = random.randint(perturbation_map[0], perturbation_map[1])/perturbation_map[2]
250 | normalized_value = perturbation_value/perturbation_map[1]
251 |
252 |         if perturbation_function == translate_horizontal:
253 |             gt_transformations['translate_horizontal'] = perturbation_value
254 |         elif perturbation_function == translate_vertical:
255 |             gt_transformations['translate_vertical'] = perturbation_value
256 |         elif perturbation_function == rotate_image:  # resize/distort have no ground-truth entry
257 |             gt_transformations['rotate_image'] = perturbation_value
258 |
259 | if perturbation_function == rotate_image:
260 | face_image = perturbation_function(perturbation_value, face_image, center=eyes_center)
261 | else:
262 | face_image = perturbation_function(perturbation_value, face_image)
263 |
264 | return face_image, gt_transformations
265 |
266 | '''
267 | the function specifies the perturbation functions and the amounts by which corresponding perturbations are applied
268 | the perturbed image (foreground) is combined with the face_mask (background) to reproduce blending seen due to two different flows
269 | applies multiple perturbations (composite perturbations) to generate complex perturbations
270 | '''
271 | def perturb_image(face_image):
272 | perturbation_functions = [
273 | translate_horizontal,
274 | translate_vertical,
275 | rotate_image,
276 | resize_image
277 | ]
278 |
279 | perturbation_function_map = {
280 | translate_horizontal : [-20, 20, 1],
281 | translate_vertical : [-20, 20, 1],
282 | rotate_image : [-25, 25, 1],
283 | resize_image : [90, 110, 100]
284 | }
285 |
286 | random_perturbation_index = random.randint(0, len(perturbation_functions)-1)
287 | # random_perturbation_index = 0 # used for debugging
288 | perturbation_function = perturbation_functions[random_perturbation_index]
289 | perturbation_map = perturbation_function_map[perturbation_function]
290 | perturbation_value = random.randint(perturbation_map[0], perturbation_map[1])/perturbation_map[2]
291 | # print(f'Using perturbation : {random_perturbation_index}, with value : {perturbation_value}', flush=True)
292 | intermediate_perturbed_image = perturbation_function(perturbation_value, face_image)
293 | # perturbed_image = combine_images(face_mask, intermediate_perturbed_image)
294 |
295 | return intermediate_perturbed_image
296 |
297 | '''
298 | def test_sample():
299 | file = '/ssd_scratch/cvit/aditya1/CelebA-HQ-img/13842.jpg'
300 | gpu_id = 0
301 | parsing, image = generate_segmentation(file, gpu_id)
302 | for i in range(PERTURBATIONS_PER_IDENTITY):
303 | face_image, background_image = generate_segmented_face(parsing, image)
304 | perturbed_image = perturb_image_composite(face_image, background_image)
305 | perturbed_filename, extension = osp.basename(file).split('.')
306 | perturbed_image_path = osp.join(perturbed_image_dir, perturbed_filename + '_' + str(i) + '.' + extension)
307 | save_image(perturbed_image_path, perturbed_image)
308 | '''
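309 | 
310 | # Minimal usage sketch (not part of the original pipeline): apply a single random perturbation to a
311 | # dummy 256x256 face crop and report the resulting shape. Run from the repository root, e.g. with
312 | # `python -m TemporalAlignment.perturbations`, so that the `TemporalAlignment.ranges` import resolves.
313 | if __name__ == '__main__':
314 |     dummy_face = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
315 |     perturbed = perturb_image(dummy_face)
316 |     print(perturbed.shape, perturbed.dtype)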
--------------------------------------------------------------------------------
/TemporalAlignment/ranges.py:
--------------------------------------------------------------------------------
1 | translation_range = 3
2 | rotation_range = 3
3 | scale_ranges = 90,110
--------------------------------------------------------------------------------
/bad_mp4s.json:
--------------------------------------------------------------------------------
1 | ["iOqgDmaYdqA/00028.mp4", "iOqgDmaYdqA/00190.mp4", "iOqgDmaYdqA/00027.mp4", "iOqgDmaYdqA/00000.mp4", "y9lSgYW2ObQ/00036.mp4", "y9lSgYW2ObQ/00041.mp4", "XGZMe5DsnRU/00000.mp4", "yYmSVsMXve0/00017.mp4", "yYmSVsMXve0/00028.mp4", "yYmSVsMXve0/00013.mp4", "yYmSVsMXve0/00047.mp4", "yYmSVsMXve0/00008.mp4", "yYmSVsMXve0/00007.mp4", "yYmSVsMXve0/00015.mp4", "yYmSVsMXve0/00009.mp4", "yYmSVsMXve0/00010.mp4", "TZpP05GHWnw/00191.mp4", "TZpP05GHWnw/00175.mp4", "TZpP05GHWnw/00196.mp4", "TZpP05GHWnw/00192.mp4", "TZpP05GHWnw/00170.mp4", "TZpP05GHWnw/00246.mp4", "TZpP05GHWnw/00206.mp4", "TZpP05GHWnw/00189.mp4", "TZpP05GHWnw/00211.mp4", "TZpP05GHWnw/00188.mp4", "TZpP05GHWnw/00187.mp4", "TZpP05GHWnw/00247.mp4", "TZpP05GHWnw/00186.mp4", "TZpP05GHWnw/00245.mp4", "TZpP05GHWnw/00207.mp4", "TZpP05GHWnw/00143.mp4", "TZpP05GHWnw/00171.mp4", "TZpP05GHWnw/00214.mp4", "TZpP05GHWnw/00227.mp4", "TZpP05GHWnw/00212.mp4", "TZpP05GHWnw/00205.mp4", "TZpP05GHWnw/00152.mp4", "TZpP05GHWnw/00193.mp4", "IGl4w4a_pcw/00073.mp4", "IGl4w4a_pcw/00108.mp4", "IGl4w4a_pcw/00055.mp4", "IGl4w4a_pcw/00047.mp4", "IGl4w4a_pcw/00066.mp4", "IGl4w4a_pcw/00107.mp4", "IGl4w4a_pcw/00057.mp4", "IGl4w4a_pcw/00048.mp4", "IGl4w4a_pcw/00051.mp4", "IGl4w4a_pcw/00049.mp4", "IGl4w4a_pcw/00056.mp4", "IGl4w4a_pcw/00000.mp4", "IGl4w4a_pcw/00050.mp4", "IGl4w4a_pcw/00212.mp4", "UjX-4h6UahQ/00005.mp4", "UjX-4h6UahQ/00002.mp4", "UjX-4h6UahQ/00001.mp4", "UjX-4h6UahQ/00000.mp4", "WOEcMOEmbC8/00037.mp4", "WOEcMOEmbC8/00031.mp4", "ozEAkq3qSi0/00063.mp4", "ozEAkq3qSi0/00273.mp4", "ozEAkq3qSi0/00175.mp4", "ozEAkq3qSi0/00122.mp4", "ozEAkq3qSi0/00275.mp4", "ozEAkq3qSi0/00095.mp4", "ozEAkq3qSi0/00100.mp4", "ozEAkq3qSi0/00153.mp4", "ozEAkq3qSi0/00277.mp4", "ozEAkq3qSi0/00204.mp4", "ozEAkq3qSi0/00278.mp4", "ozEAkq3qSi0/00272.mp4", "ozEAkq3qSi0/00174.mp4", "ozEAkq3qSi0/00276.mp4", "ozEAkq3qSi0/00205.mp4", "ozEAkq3qSi0/00152.mp4", "ozEAkq3qSi0/00202.mp4", "0uNEEfQKfwk/00112.mp4", "0uNEEfQKfwk/00298.mp4", "0uNEEfQKfwk/00106.mp4", "0uNEEfQKfwk/00318.mp4", "0uNEEfQKfwk/00174.mp4", "0uNEEfQKfwk/00276.mp4", "0uNEEfQKfwk/00297.mp4", "0uNEEfQKfwk/00000.mp4", "c3tTE3hzON8/00174.mp4", "dUrBaT_ChGc/00007.mp4", "dUrBaT_ChGc/00029.mp4", "C93trWdL3Rg/00222.mp4", "C93trWdL3Rg/00191.mp4", "C93trWdL3Rg/00200.mp4", "C93trWdL3Rg/00017.mp4", "C93trWdL3Rg/00135.mp4", "C93trWdL3Rg/00072.mp4", "C93trWdL3Rg/00185.mp4", "C93trWdL3Rg/00131.mp4", "C93trWdL3Rg/00067.mp4", "C93trWdL3Rg/00016.mp4", "C93trWdL3Rg/00073.mp4", "C93trWdL3Rg/00175.mp4", "C93trWdL3Rg/00108.mp4", "C93trWdL3Rg/00215.mp4", "C93trWdL3Rg/00055.mp4", "C93trWdL3Rg/00104.mp4", "C93trWdL3Rg/00047.mp4", "C93trWdL3Rg/00181.mp4", "C93trWdL3Rg/00094.mp4", "C93trWdL3Rg/00122.mp4", "C93trWdL3Rg/00074.mp4", "C93trWdL3Rg/00161.mp4", "C93trWdL3Rg/00225.mp4", "C93trWdL3Rg/00142.mp4", "C93trWdL3Rg/00224.mp4", "C93trWdL3Rg/00066.mp4", "C93trWdL3Rg/00130.mp4", "C93trWdL3Rg/00071.mp4", "C93trWdL3Rg/00125.mp4", "C93trWdL3Rg/00095.mp4", "C93trWdL3Rg/00190.mp4", "C93trWdL3Rg/00037.mp4", "C93trWdL3Rg/00206.mp4", "C93trWdL3Rg/00110.mp4", "C93trWdL3Rg/00134.mp4", "C93trWdL3Rg/00132.mp4", "C93trWdL3Rg/00042.mp4", "C93trWdL3Rg/00179.mp4", "C93trWdL3Rg/00219.mp4", "C93trWdL3Rg/00189.mp4", "C93trWdL3Rg/00029.mp4", "C93trWdL3Rg/00209.mp4", "C93trWdL3Rg/00048.mp4", "C93trWdL3Rg/00041.mp4", "C93trWdL3Rg/00211.mp4", "C93trWdL3Rg/00051.mp4", "C93trWdL3Rg/00201.mp4", "C93trWdL3Rg/00176.mp4", "C93trWdL3Rg/00147.mp4", "C93trWdL3Rg/00081.mp4", "C93trWdL3Rg/00083.mp4", "C93trWdL3Rg/00006.mp4", "C93trWdL3Rg/00154.mp4", "C93trWdL3Rg/00140.mp4", "C93trWdL3Rg/00045.mp4", 
"C93trWdL3Rg/00024.mp4", "C93trWdL3Rg/00168.mp4", "C93trWdL3Rg/00049.mp4", "C93trWdL3Rg/00156.mp4", "C93trWdL3Rg/00136.mp4", "C93trWdL3Rg/00188.mp4", "C93trWdL3Rg/00054.mp4", "C93trWdL3Rg/00009.mp4", "C93trWdL3Rg/00230.mp4", "C93trWdL3Rg/00160.mp4", "C93trWdL3Rg/00027.mp4", "C93trWdL3Rg/00164.mp4", "C93trWdL3Rg/00084.mp4", "C93trWdL3Rg/00187.mp4", "C93trWdL3Rg/00150.mp4", "C93trWdL3Rg/00091.mp4", "C93trWdL3Rg/00121.mp4", "C93trWdL3Rg/00204.mp4", "C93trWdL3Rg/00207.mp4", "C93trWdL3Rg/00031.mp4", "C93trWdL3Rg/00022.mp4", "C93trWdL3Rg/00059.mp4", "C93trWdL3Rg/00174.mp4", "C93trWdL3Rg/00010.mp4", "C93trWdL3Rg/00109.mp4", "C93trWdL3Rg/00098.mp4", "C93trWdL3Rg/00220.mp4", "C93trWdL3Rg/00143.mp4", "C93trWdL3Rg/00118.mp4", "C93trWdL3Rg/00228.mp4", "C93trWdL3Rg/00018.mp4", "C93trWdL3Rg/00119.mp4", "C93trWdL3Rg/00033.mp4", "C93trWdL3Rg/00056.mp4", "C93trWdL3Rg/00167.mp4", "C93trWdL3Rg/00138.mp4", "C93trWdL3Rg/00169.mp4", "C93trWdL3Rg/00025.mp4", "C93trWdL3Rg/00208.mp4", "C93trWdL3Rg/00030.mp4", "C93trWdL3Rg/00082.mp4", "C93trWdL3Rg/00014.mp4", "C93trWdL3Rg/00068.mp4", "C93trWdL3Rg/00114.mp4", "C93trWdL3Rg/00182.mp4", "C93trWdL3Rg/00226.mp4", "C93trWdL3Rg/00078.mp4", "C93trWdL3Rg/00145.mp4", "C93trWdL3Rg/00227.mp4", "C93trWdL3Rg/00127.mp4", "C93trWdL3Rg/00050.mp4", "C93trWdL3Rg/00178.mp4", "C93trWdL3Rg/00103.mp4", "C93trWdL3Rg/00039.mp4", "C93trWdL3Rg/00026.mp4", "C93trWdL3Rg/00124.mp4", "C93trWdL3Rg/00123.mp4", "C93trWdL3Rg/00229.mp4", "C93trWdL3Rg/00035.mp4", "C93trWdL3Rg/00205.mp4", "C93trWdL3Rg/00180.mp4", "JGCg3ARv14U/00058.mp4", "JGCg3ARv14U/00029.mp4", "JGCg3ARv14U/00059.mp4", "JGCg3ARv14U/00000.mp4", "MTgyYqY3XTE/00181.mp4", "MTgyYqY3XTE/00119.mp4", "MTgyYqY3XTE/00184.mp4", "MTgyYqY3XTE/00001.mp4", "MTgyYqY3XTE/00000.mp4", "wuS8Hq1Yhro/00017.mp4", "wuS8Hq1Yhro/00023.mp4", "wuS8Hq1Yhro/00046.mp4", "wuS8Hq1Yhro/00016.mp4", "wuS8Hq1Yhro/00044.mp4", "wuS8Hq1Yhro/00028.mp4", "wuS8Hq1Yhro/00013.mp4", "wuS8Hq1Yhro/00047.mp4", "wuS8Hq1Yhro/00012.mp4", "wuS8Hq1Yhro/00008.mp4", "wuS8Hq1Yhro/00032.mp4", "wuS8Hq1Yhro/00037.mp4", "wuS8Hq1Yhro/00036.mp4", "wuS8Hq1Yhro/00042.mp4", "wuS8Hq1Yhro/00007.mp4", "wuS8Hq1Yhro/00034.mp4", "wuS8Hq1Yhro/00029.mp4", "wuS8Hq1Yhro/00015.mp4", "wuS8Hq1Yhro/00048.mp4", "wuS8Hq1Yhro/00020.mp4", "wuS8Hq1Yhro/00004.mp4", "wuS8Hq1Yhro/00041.mp4", "wuS8Hq1Yhro/00006.mp4", "wuS8Hq1Yhro/00005.mp4", "wuS8Hq1Yhro/00045.mp4", "wuS8Hq1Yhro/00024.mp4", "wuS8Hq1Yhro/00019.mp4", "wuS8Hq1Yhro/00009.mp4", "wuS8Hq1Yhro/00027.mp4", "wuS8Hq1Yhro/00043.mp4", "wuS8Hq1Yhro/00031.mp4", "wuS8Hq1Yhro/00022.mp4", "wuS8Hq1Yhro/00040.mp4", "wuS8Hq1Yhro/00010.mp4", "wuS8Hq1Yhro/00018.mp4", "wuS8Hq1Yhro/00002.mp4", "wuS8Hq1Yhro/00033.mp4", "wuS8Hq1Yhro/00011.mp4", "wuS8Hq1Yhro/00025.mp4", "wuS8Hq1Yhro/00030.mp4", "wuS8Hq1Yhro/00014.mp4", "wuS8Hq1Yhro/00001.mp4", "wuS8Hq1Yhro/00038.mp4", "wuS8Hq1Yhro/00000.mp4", "wuS8Hq1Yhro/00039.mp4", "wuS8Hq1Yhro/00026.mp4", "wuS8Hq1Yhro/00035.mp4", "wuS8Hq1Yhro/00021.mp4", "wuS8Hq1Yhro/00003.mp4"]
--------------------------------------------------------------------------------
/bash_scripts/train_videovqvae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=acm_abl_wo_temporalmodule
4 | #SBATCH --mem-per-cpu=2048
5 | #SBATCH --partition long
6 | #SBATCH --account research
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --mincpus=38
9 | #SBATCH --nodes=1
10 | #SBATCH --time 4-00:00:00
11 | #SBATCH --signal=B:HUP@600
12 | #SBATCH -w gnode031
13 |
14 | cd /ssd_scratch/cvit/aditya1/acm_rebuttal/video_vqvae/VQVAE2-Refact
15 |
16 | source /home2/aditya1/miniconda3/bin/activate base
17 |
18 | CUDA_VISIBLE_DEVICES=0 python train_vqvae_perceptual.py --epoch 1000 --colorjit const --batch_size 1
--------------------------------------------------------------------------------
/bash_scripts/train_videovqvae_perturbations.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=acmabl_wo_translation
4 | #SBATCH --mem-per-cpu=2048
5 | #SBATCH --partition long
6 | #SBATCH --account cvit_bhaasha
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --mincpus=38
9 | #SBATCH --nodes=1
10 | #SBATCH --time 10-00:00:00
11 | #SBATCH --signal=B:HUP@600
12 | #SBATCH -w gnode061
13 |
14 | cd /ssd_scratch/cvit/aditya1/acm_rebuttal/video_vqvae/VQVAE2-Refact
15 |
16 | source /home2/aditya1/miniconda3/bin/activate base
17 |
18 | CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 \
19 | python train_vqvae_perceptual.py --epoch 1000 \
20 | --colorjit const --batch_size 1 \
21 | --sample_folder /ssd_scratch/cvit/aditya1/video_vqvae2_results/ablation_translation_disabled_samples \
22 | --checkpoint_dir /ssd_scratch/cvit/aditya1/video_vqvae2_results/ablation_translation_disabled_checkpoints
--------------------------------------------------------------------------------
/bash_scripts/train_videovqvae_resume_ckpt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=acmwotemporal_abl
4 | #SBATCH --mem-per-cpu=2048
5 | #SBATCH --partition long
6 | #SBATCH --account research
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --mincpus=38
9 | #SBATCH --nodes=1
10 | #SBATCH --time 4-00:00:00
11 | #SBATCH --signal=B:HUP@600
12 | #SBATCH -w gnode031
13 |
14 | cd /ssd_scratch/cvit/aditya1/acm_rebuttal/video_vqvae/VQVAE2-Refact
15 |
16 | source /home2/aditya1/miniconda3/bin/activate base
17 |
18 | CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 python train_vqvae_perceptual.py \
19 | --epoch 1000 --colorjit const --batch_size 1 \
20 | --ckpt /ssd_scratch/cvit/aditya1/video_vqvae2_results/checkpoint_iquff/vqvae_7_5121.pt \
21 | --sample_folder /ssd_scratch/cvit/aditya1/video_vqvae2_results/without_temporal_resume_samples \
22 | --checkpoint_dir /ssd_scratch/cvit/aditya1/video_vqvae2_results/without_temporal_resume_checkpoints
--------------------------------------------------------------------------------
/command.txt:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python train_video_vqvae.py --epoch 20000 --ckpt vqvae_068.pt
2 |
3 | CUDA_VISIBLE_DEVICES=1 python train_video_vqvae.py --epoch 20000 --max_frame_len 8 --ckpt checkpoint_target_all5losses/vqvae_501.pt,checkpoint_target_all5losses/adversarial_501.pt --validate_at 2048 --epoch 200000
4 |
5 | CUDA_VISIBLE_DEVICES=0 python train_video_vqvae.py --epoch 20000 --max_frame_len 8 --ckpt checkpoint_target_all5losses/vqvae_5301.pt,checkpoint_target_all5losses/adversarial_5301.pt --validate_at 2048 --epoch 200000
6 |
7 | CUDA_VISIBLE_DEVICES=2 python train_video_vqvae.py --max_frame_len 8 --ckpt checkpoint_vlog_all5losses/vqvae_5401.pt,checkpoint_vlog_all5losses/adversarial_5401.pt
8 |
9 |
10 | CUDA_VISIBLE_DEVICES=2 python train_video_vqvae_nta.py --max_frame_len 8 --ckpt multi_005.pt --validate_at 512 --epoch 200000
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Config file containing the hyperparameters
3 | '''
4 | DATASET = 11
5 | LATENT_LOSS_WEIGHT = 1 # hyperparameter weight for the latent loss
6 | PERCEPTUAL_LOSS_WEIGHT = 1 # hyperparameter weight for the perceptual loss
7 |
8 | # weights for the mocoganhd discriminator
9 | G_LOSS_2D_WEIGHT = 0.25
10 | G_LOSS_3D_WEIGHT = 0.25
11 |
12 | image_disc_weight = 0.5
13 | video_disc_weight = 0.5
14 |
15 | D_LOSS_WEIGHT = 0.1
16 |
17 | SAMPLE_SIZE_FOR_VISUALIZATION = 8
18 | DISC_LOSS_WEIGHT = 0.25 # TODO - modify
--------------------------------------------------------------------------------
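
For context, the sketch below shows how these weights are combined into the generator objective in the `disc_trainers/` scripts further down. The weighting mirrors `train_generator` in the perceptual trainer; the helper function itself is illustrative and not part of the repository.

```
from config import (LATENT_LOSS_WEIGHT, PERCEPTUAL_LOSS_WEIGHT,
                    image_disc_weight, video_disc_weight)

def combined_generator_loss(recon_loss, latent_loss, perceptual_loss,
                            fake_image_disc_loss, fake_video_disc_loss):
    # same weighting as train_generator() in train_vqvae_mocogan_disc_perceptual.py
    return (recon_loss
            + LATENT_LOSS_WEIGHT * latent_loss
            + PERCEPTUAL_LOSS_WEIGHT * perceptual_loss
            + image_disc_weight * fake_image_disc_loss
            + video_disc_weight * fake_video_disc_loss)
```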
/copy.sh:
--------------------------------------------------------------------------------
1 | cp /home2/aditya1/cvit/content_sync/FaceOff/*.py .
2 | cp /home2/aditya1/cvit/content_sync/FaceOff/datasets/*.py datasets/.
3 | cp /home2/aditya1/cvit/content_sync/FaceOff/disc_trainers/*.py disc_trainers/.
4 | cp /home2/aditya1/cvit/content_sync/FaceOff/models/*.py models/.
5 | cp -r /home2/aditya1/cvit/content_sync/FaceOff/TemporalAlignment/* TemporalAlignment/.
--------------------------------------------------------------------------------
/datasets/face_translation_videos3_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import random
3 | from glob import glob
4 | import os
5 | import json
6 | import os.path as osp
7 |
8 | from skimage import io
9 | from skimage.transform import resize
10 | from scipy.ndimage import laplace
11 | import numpy as np
12 | import cv2
13 | from skimage import transform as tf
14 |
15 | from scipy.ndimage import laplace
16 |
17 | target_without_face_apply = True
18 |
19 | def resize_frame(frame, resize_dim=256):
20 | h, w, _ = frame.shape
21 |
22 | if h > w:
23 | padw, padh = (h-w)//2, 0
24 | else:
25 | padw, padh = 0, (w-h)//2
26 |
27 | padded = cv2.copyMakeBorder(frame, padh, padh, padw, padw, cv2.BORDER_CONSTANT, value=0)
28 | padded = cv2.resize(padded, (resize_dim, resize_dim), interpolation=cv2.INTER_LINEAR)
29 |
30 | return padded
31 |
32 | def readPoints(nparray) :
33 | points = []
34 |
35 | for row in nparray:
36 | x, y = row[0], row[1]
37 | points.append((int(x), int(y)))
38 |
39 | return points
40 |
41 | def generate_convex_hull(img, points):
42 | # points = np.load(landmark_path, allow_pickle=True)['landmark'].astype(np.uint8)
43 | points = readPoints(points)
44 |
45 | hull = []
46 | hullIndex = cv2.convexHull(np.array(points), returnPoints = False)
47 |
48 | for i in range(0, len(hullIndex)):
49 | hull.append(points[int(hullIndex[i])])
50 |
51 | sizeImg = img.shape
52 | rect = (0, 0, sizeImg[1], sizeImg[0])
53 |
54 | hull8U = []
55 | for i in range(0, len(hull)):
56 | hull8U.append((hull[i][0], hull[i][1]))
57 |
58 | mask = np.zeros(img.shape, dtype = img.dtype)
59 |
60 | cv2.fillConvexPoly(mask, np.int32(hull8U), (255, 255, 255))
61 |
62 | # convex_face = ((mask/255.) * img).astype(np.uint8)
63 |
64 | return mask
65 |
66 | def enlarge_mask(img_mask, enlargement=5):
67 | img1 = img_mask.copy()
68 | img = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
69 |
70 | ret, thresh = cv2.threshold(img,50,255,0)
71 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
72 |
73 | for i in range(len(contours)):
74 | hull = cv2.convexHull(contours[i])
75 | cv2.drawContours(img1, [hull], -1, (255, 255, 255), enlargement)
76 |
77 | return img1
78 |
79 | def poisson_blend(target_img, src_img, mask_img, iter: int = 1024):
80 | for _ in range(iter):
81 | target_img = target_img + 0.25 * mask_img * laplace(target_img - src_img)
82 | return target_img.clip(0, 1)
83 |
84 | # -- Face Transformation
85 | def warp_img(src, dst, img, std_size):
86 | tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix
87 |     warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) # warp the frame image
88 |     warped = warped * 255 # note: output from tf.warp is a double image (value range [0, 1])
89 | warped = warped.astype('uint8')
90 | return warped, tform
91 |
92 | def apply_transform(transform, img, std_size):
93 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
94 |     warped = warped * 255 # note: output from tf.warp is a double image (value range [0, 1])
95 | warped = warped.astype('uint8')
96 | return warped
97 |
98 | # Method to combine the face segmented with the face segmentation mask
99 | def combine_images(face_mask, face_image, generate_mask=True):
100 | image_masked = face_mask.copy()
101 | if generate_mask:
102 | mask = face_image[..., 0] != 0
103 | image_masked[mask] = 0
104 |
105 | combined_image = image_masked + face_image
106 |
107 | return combined_image
108 |
109 | # computes the rotation of the face using the angle of the line connecting the eye centroids
110 | def compute_rotation(shape):
111 | # landmark coordinates corresponding to the eyes
112 | lStart, lEnd = 36, 41
113 | rStart, rEnd = 42, 47
114 |
115 | # landmarks for the left and right eyes
116 | leftEyePoints = shape[lStart:lEnd]
117 | rightEyePoints = shape[rStart:rEnd]
118 |
119 | # compute the center of mass for each of the eyes
120 | leftEyeCenter = leftEyePoints.mean(axis=0).astype("int")
121 | rightEyeCenter = rightEyePoints.mean(axis=0).astype("int")
122 |
123 | # compute the angle between the eye centroids
124 | dY = rightEyeCenter[1] - leftEyeCenter[1]
125 | dX = rightEyeCenter[0] - leftEyeCenter[0]
126 | angle = np.degrees(np.arctan2(dY, dX))
127 |
128 | eyesCenter = ((leftEyeCenter[0] + rightEyeCenter[0]) / 2,
129 | (leftEyeCenter[1] + rightEyeCenter[1]) / 2)
130 |
131 | dist = np.sqrt((dX ** 2) + (dY ** 2)) # this indicates the distance between the two eyes
132 |
133 | return angle, eyesCenter, dist
134 |
135 | def apply_mask(mask, image):
136 | return ((mask / 255.) * image).astype(np.uint8)
137 |
138 | # generates the source face warped onto the target frame via a similarity transform on the stable landmarks
139 | def generate_warped_image(source_landmark_npz, target_landmark_npz,
140 | source_image_path, target_image_path,
141 | poisson_blend_required = False,
142 | require_full_mask = False):
143 |
144 | stablePoints = [33, 36, 39, 42, 45]
145 | std_size = (256, 256)
146 |
147 | source_image = resize_frame(io.imread(source_image_path))
148 | target_image = resize_frame(io.imread(target_image_path))
149 |
150 | source_landmarks = np.load(source_landmark_npz)['landmark']
151 |
152 | target_landmarks = np.load(target_landmark_npz)['landmark']
153 |
154 | if require_full_mask:
155 | source_convex_mask = generate_convex_hull(source_image, source_landmarks)
156 | source_convex_mask_no_enlargement = source_convex_mask.copy()
157 | else:
158 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[17:])
159 | # enlarge the convex mask
160 | source_convex_mask_no_enlargement = source_convex_mask.copy()
161 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=10)
162 |
163 | # apply the convex mask to the face
164 | source_face_segmented = apply_mask(source_convex_mask, source_image)
165 | source_face_transformed, transformation = warp_img(source_landmarks[stablePoints, :],
166 | target_landmarks[stablePoints, :],
167 | source_face_segmented,
168 | std_size)
169 |
170 | source_convex_mask_transformed = apply_transform(transformation, source_convex_mask, std_size)
171 | source_convex_mask_no_enlargement_transformed = apply_transform(transformation, source_convex_mask_no_enlargement, std_size)
172 | source_image_transformed = apply_transform(transformation, source_image, std_size)
173 |
174 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks))
175 | # target_background = apply_mask(target_convex_mask, target_image)
176 |
177 | target_convex_mask_without_jaw = generate_convex_hull(target_image, target_landmarks[17:])
178 | target_convex_mask_without_jaw = enlarge_mask(target_convex_mask_without_jaw, enlargement=10)
179 |
180 | target_convex_mask_without_jaw = np.invert(target_convex_mask_without_jaw)
181 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image)
182 | target_without_face = apply_mask(target_convex_mask, target_image)
183 |
184 | if poisson_blend_required:
185 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.)
186 | else:
187 | if target_without_face_apply:
188 | combined_image = combine_images(target_without_face, source_face_transformed)
189 | else:
190 | combined_image = combine_images(target_image, source_face_transformed)
191 | # apply the transformed convex mask to the target face for sanity
192 | # target_masked = apply_mask(source_convex_mask_transformed, target_image)
193 |
194 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, source_convex_mask_no_enlargement, target_image, target_convex_mask, combined_image, target_without_face_features, source_image
195 |
196 | # generates the source face aligned to the target frame using rotation, scale, and translation estimated from the eye landmarks
197 | def generate_aligned_image(source_landmark_npz, target_landmark_npz,
198 | source_image_path, target_image_path,
199 | poisson_blend_required = False,
200 | require_full_mask = False):
201 |
202 | source_image = resize_frame(io.imread(source_image_path))
203 | target_image = resize_frame(io.imread(target_image_path))
204 |
205 | source_landmarks = np.load(source_landmark_npz)['landmark']
206 | source_rotation, source_center, source_distance = compute_rotation(source_landmarks)
207 |
208 | target_landmarks = np.load(target_landmark_npz)['landmark']
209 | target_rotation, target_center, target_distance = compute_rotation(target_landmarks)
210 |
211 | # rotation of the target conditioned on the source orientation
212 | target_conditioned_source_rotation = source_rotation - target_rotation
213 |
214 | # calculate the scaling that needs to be applied on the source image
215 | scaling = target_distance / source_distance
216 |
217 | # apply the rotation on the source image
218 | height, width = 256, 256
219 | # print(f'Angle of rotation is : {target_conditioned_source_rotation}')
220 | rotate_matrix = cv2.getRotationMatrix2D(center=source_center, angle=target_conditioned_source_rotation, scale=scaling)
221 |
222 | # calculate the translation component of the matrix M
223 | rotate_matrix[0, 2] += (target_center[0] - source_center[0])
224 | rotate_matrix[1, 2] += (target_center[1] - source_center[1])
225 |
226 | if require_full_mask:
227 | source_convex_mask = generate_convex_hull(source_image, source_landmarks)
228 | else:
229 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[17:])
230 | # enlarge the convex mask using the enlargement
231 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=5)
232 |
233 | # apply the convex mask to the face
234 | source_face_segmented = apply_mask(source_convex_mask, source_image)
235 | source_face_transformed = cv2.warpAffine(source_face_segmented, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC)
236 | source_convex_mask_transformed = cv2.warpAffine(source_convex_mask, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC)
237 | source_image_transformed = cv2.warpAffine(source_image, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC)
238 |
239 | # used for computing the target background
240 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks))
241 | # target_background = ((target_convex_mask/255.)*target_image).astype(np.uint8)
242 | target_convex_mask_without_jaw = np.invert(generate_convex_hull(target_image, target_landmarks[17:]))
243 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image)
244 | target_without_face = apply_mask(target_convex_mask, target_image)
245 |
246 | if poisson_blend_required:
247 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.)
248 | else:
249 | if target_without_face_apply:
250 | combined_image = combine_images(target_without_face, source_face_transformed)
251 | else:
252 | combined_image = combine_images(target_image, source_face_transformed)
253 |
254 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, target_image, target_convex_mask, combined_image
255 |
--------------------------------------------------------------------------------
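
A minimal usage sketch of the warping utilities above, assuming the repository root is on `PYTHONPATH`; the frame and landmark paths are hypothetical placeholders for the per-frame files produced by the preprocessing scripts.

```
from skimage import io
from datasets.face_translation_videos3_utils import generate_warped_image

# hypothetical per-frame inputs; real paths come from the dataset pipeline
src_lmk, tgt_lmk = 'source_0001.npz', 'target_0001.npz'
src_img, tgt_img = 'source_0001.jpg', 'target_0001.jpg'

(source_face, source_mask, source_img_warped, source_mask_raw,
 target_img, target_mask, combined, target_wo_features, source_img_orig) = \
    generate_warped_image(src_lmk, tgt_lmk, src_img, tgt_img,
                          poisson_blend_required=False, require_full_mask=False)

# 'combined' is the source face pasted onto the target background
io.imsave('combined_0001.png', combined)
```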
/datasets/face_translation_videos3_utils_bb.py:
--------------------------------------------------------------------------------
1 | '''
2 | For generating the masked region, use the bounding box around the landmarks
3 | '''
4 |
5 | import sys
6 | import random
7 | from glob import glob
8 | import os
9 | import json
10 | import os.path as osp
11 |
12 | import cv2
13 | import numpy as np
14 | from skimage import io
15 | from skimage.transform import resize
16 | from scipy.ndimage import laplace
17 | from skimage import transform as tf
18 |
19 | from scipy.ndimage import laplace
20 |
21 | target_without_face_apply = True
22 |
23 | requires_bb = False
24 | extract_lip_region = False
25 |
26 | '''
27 | landmark index range: the lip region when extract_lip_region is set, otherwise the full set of facial landmarks
28 | '''
29 | if extract_lip_region:
30 | start_idx = 49
31 | end_idx = 61
32 | else:
33 | start_idx = 0
34 | end_idx = 67
35 |
36 | '''
37 | resize the frame by padding it to a square and scaling it to resize_dim x resize_dim
38 | (the overall learned transformation is a combination of affine, non-linear, and color transformations)
39 | '''
40 | def resize_frame(frame, resize_dim=256):
41 | h, w, _ = frame.shape
42 |
43 | if h > w:
44 | padw, padh = (h-w)//2, 0
45 | else:
46 | padw, padh = 0, (w-h)//2
47 |
48 | padded = cv2.copyMakeBorder(frame, padh, padh, padw, padw, cv2.BORDER_CONSTANT, value=0)
49 | padded = cv2.resize(padded, (resize_dim, resize_dim), interpolation=cv2.INTER_LINEAR)
50 |
51 | return padded
52 |
53 | def readPoints(nparray) :
54 | points = []
55 |
56 | for row in nparray:
57 | x, y = row[0], row[1]
58 | points.append((int(x), int(y)))
59 |
60 | return points
61 |
62 | '''
63 | generate the convex hull based on the coordinates of the bounding box
64 | '''
65 | def generate_convex_hull_bb(img, points):
66 | mask = np.zeros(img.shape, dtype=img.dtype)
67 | min_x, max_x, min_y, max_y = estimate_bb_coordinates(points)
68 | mask[min_y:max_y, min_x:max_x] = 255
69 |
70 | return mask
71 |
72 | '''
73 | generates the convex hull mask of the face
74 | '''
75 | def generate_convex_hull(img, points):
76 | # points = np.load(landmark_path, allow_pickle=True)['landmark'].astype(np.uint8)
77 | points = readPoints(points)
78 |
79 | hull = []
80 | hullIndex = cv2.convexHull(np.array(points), returnPoints = False)
81 |
82 | for i in range(0, len(hullIndex)):
83 | hull.append(points[int(hullIndex[i])])
84 |
85 | sizeImg = img.shape
86 | rect = (0, 0, sizeImg[1], sizeImg[0])
87 |
88 | hull8U = []
89 | for i in range(0, len(hull)):
90 | hull8U.append((hull[i][0], hull[i][1]))
91 |
92 | mask = np.zeros(img.shape, dtype = img.dtype)
93 |
94 | cv2.fillConvexPoly(mask, np.int32(hull8U), (255, 255, 255))
95 |
96 | return mask
97 |
98 | '''
99 | Enlarge the face mask to accommodate the region close to the face boundary
100 | '''
101 | def enlarge_mask(img_mask, enlargement=5):
102 | img1 = img_mask.copy()
103 | img = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
104 |
105 | ret, thresh = cv2.threshold(img,50,255,0)
106 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
107 |
108 | for i in range(len(contours)):
109 | hull = cv2.convexHull(contours[i])
110 | cv2.drawContours(img1, [hull], -1, (255, 255, 255), enlargement)
111 |
112 | return img1
113 |
114 | '''
115 | Heuristic blending of the face images using poisson blend
116 | '''
117 | def poisson_blend(target_img, src_img, mask_img, iter: int = 1024):
118 | for _ in range(iter):
119 | target_img = target_img + 0.25 * mask_img * laplace(target_img - src_img)
120 | return target_img.clip(0, 1)
121 |
122 | # -- Face Transformation
123 | def warp_img(src, dst, img, std_size):
124 | tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix
125 |     warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) # warp the frame image
126 |     warped = warped * 255 # note: output from tf.warp is a double image (value range [0, 1])
127 | warped = warped.astype('uint8')
128 | return warped, tform
129 |
130 |
131 | def apply_transform(transform, img, std_size):
132 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
133 |     warped = warped * 255 # note: output from tf.warp is a double image (value range [0, 1])
134 | warped = warped.astype('uint8')
135 | return warped
136 |
137 | '''
138 | Combines the segmented face image (foreground) with the segmented face mask (background)
139 | '''
140 | # Method to combine the face segmented with the face segmentation mask
141 | def combine_images(face_mask, face_image, generate_mask=True):
142 | image_masked = face_mask.copy()
143 | if generate_mask:
144 | mask = face_image[..., 0] != 0
145 | image_masked[mask] = 0
146 |
147 | combined_image = image_masked + face_image
148 |
149 | return combined_image
150 |
151 |
152 | '''
153 | Computes the rotation of the face w.r.t the horizontal
154 | The angle of the line connecting the eye centroids is estimated
155 | '''
156 | def compute_rotation(shape):
157 | # landmark coordinates corresponding to the eyes
158 | lStart, lEnd = 36, 41
159 | rStart, rEnd = 42, 47
160 |
161 | # landmarks for the left and right eyes
162 | leftEyePoints = shape[lStart:lEnd]
163 | rightEyePoints = shape[rStart:rEnd]
164 |
165 | # compute the center of mass for each of the eyes
166 | leftEyeCenter = leftEyePoints.mean(axis=0).astype("int")
167 | rightEyeCenter = rightEyePoints.mean(axis=0).astype("int")
168 |
169 | # compute the angle between the eye centroids
170 | dY = rightEyeCenter[1] - leftEyeCenter[1]
171 | dX = rightEyeCenter[0] - leftEyeCenter[0]
172 | angle = np.degrees(np.arctan2(dY, dX))
173 |
174 | eyesCenter = ((leftEyeCenter[0] + rightEyeCenter[0]) / 2,
175 | (leftEyeCenter[1] + rightEyeCenter[1]) / 2)
176 |
177 | dist = np.sqrt((dX ** 2) + (dY ** 2)) # this indicates the distance between the two eyes
178 |
179 | return angle, eyesCenter, dist
180 |
181 | '''
182 | Apply the mask on the input face image
183 | '''
184 | def apply_mask(mask, image):
185 | return ((mask / 255.) * image).astype(np.uint8)
186 |
187 | '''
188 | Estimate the coordinates of the bounding box from the face landmarks
189 | '''
190 | def estimate_bb_coordinates(landmarks, epsilon=10):
191 | min_x, max_x, min_y, max_y = int(np.min(landmarks[:, 0])), int(np.max(landmarks[:, 0])), \
192 | int(np.min(landmarks[:, 1])), int(np.max(landmarks[:, 1]))
193 |
194 | return min_x - epsilon, max_x + epsilon, min_y - epsilon, max_y + epsilon
195 |
196 | # generates the source face warped onto the target frame via a similarity transform on the stable landmarks
197 | def generate_warped_image(source_landmark_npz, target_landmark_npz,
198 | source_image_path, target_image_path,
199 | poisson_blend_required = False,
200 | require_full_mask = False):
201 |
202 | stablePoints = [33, 36, 39, 42, 45]
203 | std_size = (256, 256)
204 |
205 | source_image = resize_frame(io.imread(source_image_path))
206 | target_image = resize_frame(io.imread(target_image_path))
207 |
208 | source_landmarks = np.load(source_landmark_npz)['landmark']
209 |
210 | target_landmarks = np.load(target_landmark_npz)['landmark']
211 |
212 | if require_full_mask:
213 | source_convex_mask = generate_convex_hull(source_image, source_landmarks)
214 | source_convex_mask_no_enlargement = source_convex_mask.copy()
215 | else:
216 | if requires_bb:
217 | source_convex_mask = generate_convex_hull_bb(source_image, source_landmarks[start_idx:end_idx])
218 | else:
219 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[start_idx:end_idx])
220 |
221 | # enlarge the convex mask
222 | source_convex_mask_no_enlargement = source_convex_mask.copy()
223 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=10)
224 |
225 | # apply the convex mask to the face
226 | source_face_segmented = apply_mask(source_convex_mask, source_image)
227 | source_face_transformed, transformation = warp_img(source_landmarks[stablePoints, :],
228 | target_landmarks[stablePoints, :],
229 | source_face_segmented,
230 | std_size)
231 |
232 | source_convex_mask_transformed = apply_transform(transformation, source_convex_mask, std_size)
233 | source_convex_mask_no_enlargement_transformed = apply_transform(transformation, source_convex_mask_no_enlargement, std_size)
234 | source_image_transformed = apply_transform(transformation, source_image, std_size)
235 |
236 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks))
237 | # target_background = apply_mask(target_convex_mask, target_image)
238 |
239 | if requires_bb:
240 | target_convex_mask_without_jaw = generate_convex_hull_bb(target_image, target_landmarks[start_idx:end_idx])
241 | else:
242 | target_convex_mask_without_jaw = generate_convex_hull(target_image, target_landmarks[start_idx:end_idx])
243 |
244 | target_convex_mask_without_jaw = enlarge_mask(target_convex_mask_without_jaw, enlargement=10)
245 |
246 | target_convex_mask_without_jaw = np.invert(target_convex_mask_without_jaw)
247 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image)
248 | target_without_face = apply_mask(target_convex_mask, target_image)
249 |
250 | if poisson_blend_required:
251 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.)
252 | else:
253 | if target_without_face_apply:
254 | combined_image = combine_images(target_without_face, source_face_transformed)
255 | else:
256 | combined_image = combine_images(target_image, source_face_transformed)
257 | # apply the transformed convex mask to the target face for sanity
258 | # target_masked = apply_mask(source_convex_mask_transformed, target_image)
259 |
260 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, source_convex_mask_no_enlargement, target_image, target_convex_mask, combined_image, target_without_face_features, source_image
261 |
262 | # generates the source face aligned to the target frame using rotation, scale, and translation estimated from the eye landmarks
263 | def generate_aligned_image(source_landmark_npz, target_landmark_npz,
264 | source_image_path, target_image_path,
265 | poisson_blend_required = False,
266 | require_full_mask = False):
267 |
268 | source_image = resize_frame(io.imread(source_image_path))
269 | target_image = resize_frame(io.imread(target_image_path))
270 |
271 | source_landmarks = np.load(source_landmark_npz)['landmark']
272 | source_rotation, source_center, source_distance = compute_rotation(source_landmarks)
273 |
274 | target_landmarks = np.load(target_landmark_npz)['landmark']
275 | target_rotation, target_center, target_distance = compute_rotation(target_landmarks)
276 |
277 |     # rotation of the target conditioned on the source orientation
278 | target_conditioned_source_rotation = source_rotation - target_rotation
279 |
280 | # calculate the scaling that needs to be applied on the source image
281 | scaling = target_distance / source_distance
282 |
283 | # apply the rotation on the source image
284 | height, width = 256, 256
285 | # print(f'Angle of rotation is : {target_conditioned_source_rotation}')
286 | rotate_matrix = cv2.getRotationMatrix2D(center=source_center, angle=target_conditioned_source_rotation, scale=scaling)
287 |
288 | # calculate the translation component of the matrix M
289 | rotate_matrix[0, 2] += (target_center[0] - source_center[0])
290 | rotate_matrix[1, 2] += (target_center[1] - source_center[1])
291 |
292 | if require_full_mask:
293 | source_convex_mask = generate_convex_hull(source_image, source_landmarks)
294 | else:
295 | if requires_bb:
296 | source_convex_mask = generate_convex_hull_bb(source_image, source_landmarks[start_idx:end_idx])
297 | else:
298 | source_convex_mask = generate_convex_hull(source_image, source_landmarks[start_idx:end_idx])
299 |
300 | # enlarge the convex mask using the enlargement
301 | source_convex_mask = enlarge_mask(source_convex_mask, enlargement=5)
302 |
303 | # apply the convex mask to the face
304 | source_face_segmented = apply_mask(source_convex_mask, source_image)
305 | source_face_transformed = cv2.warpAffine(source_face_segmented, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC)
306 | source_convex_mask_transformed = cv2.warpAffine(source_convex_mask, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC)
307 | source_image_transformed = cv2.warpAffine(source_image, rotate_matrix, (width, height), flags=cv2.INTER_CUBIC)
308 |
309 | # used for computing the target background
310 | if requires_bb:
311 | target_convex_mask = np.invert(generate_convex_hull_bb(target_image, target_landmarks))
312 | target_convex_mask_without_jaw = np.invert(generate_convex_hull_bb(target_image, target_landmarks[start_idx:end_idx]))
313 | else:
314 | target_convex_mask = np.invert(generate_convex_hull(target_image, target_landmarks))
315 |         target_convex_mask_without_jaw = np.invert(generate_convex_hull(target_image, target_landmarks[start_idx:end_idx]))
316 |
317 | # target_background = ((target_convex_mask/255.)*target_image).astype(np.uint8)
318 | # target_convex_mask_without_jaw = np.invert(generate_convex_hull(target_image, target_landmarks[start_idx:end_idx]))
319 |
320 | target_without_face_features = apply_mask(target_convex_mask_without_jaw, target_image)
321 | target_without_face = apply_mask(target_convex_mask, target_image)
322 |
323 | if poisson_blend_required:
324 | combined_image = poisson_blend(target_image/255., source_image/255., source_face_transformed/255.)
325 | else:
326 | if target_without_face_apply:
327 | combined_image = combine_images(target_without_face, source_face_transformed)
328 | else:
329 | combined_image = combine_images(target_image, source_face_transformed)
330 |
331 | return source_face_transformed, source_convex_mask_transformed, source_image_transformed, target_image, target_convex_mask, combined_image
332 |
--------------------------------------------------------------------------------
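
A small self-contained sketch of the bounding-box masking path added in this variant, using synthetic landmarks (purely illustrative; the real landmarks come from the `.npz` files used above).

```
import numpy as np
from datasets.face_translation_videos3_utils_bb import (estimate_bb_coordinates,
                                                        generate_convex_hull_bb)

# synthetic 68-point landmarks clustered near the centre of a 256x256 frame
landmarks = np.random.randint(90, 170, size=(68, 2))
frame = np.zeros((256, 256, 3), dtype=np.uint8)

min_x, max_x, min_y, max_y = estimate_bb_coordinates(landmarks, epsilon=10)
mask = generate_convex_hull_bb(frame, landmarks)

# the returned mask is white exactly inside the (padded) landmark bounding box
assert (mask[min_y:max_y, min_x:max_x] == 255).all()
```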
/disc_trainers/train_vqvae_mocogan_disc.py:
--------------------------------------------------------------------------------
1 | # This code uses the mocogan content and temporal discriminators
2 |
3 | import argparse
4 | import sys
5 | import os
6 | import random
7 | import os.path as osp
8 |
9 | import torch
10 | from torch import nn, optim
11 | import torch.nn.functional as F
12 |
13 | from torchvision import datasets, transforms, utils
14 |
15 | from TemporalAlignment.models import mocogan_discriminator
16 |
17 | from tqdm import tqdm
18 |
19 | from scheduler import CycleScheduler
20 |
21 | from utils import *
22 | from config import DATASET, LATENT_LOSS_WEIGHT, image_disc_weight, video_disc_weight, SAMPLE_SIZE_FOR_VISUALIZATION
23 |
24 | criterion = nn.MSELoss()
25 |
26 | gan_criterion = nn.BCEWithLogitsLoss()
27 |
28 | global_step = 0
29 |
30 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION
31 |
32 | dataset = DATASET
33 |
34 | CONST_FRAMES_TO_CHECK = 16
35 |
36 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results'
37 | # sample_folder = '/home2/bipasha31/python_scripts/CurrentWork/samples/{}'
38 |
39 |
40 | def run_step(model, data, device, run='train'):
41 | img, S, ground_truth, source_image = process_data(data, device, dataset)
42 |
43 | out, latent_loss = model(img)
44 |
45 | out = out[:, :3]
46 |
47 | recon_loss = criterion(out, ground_truth)
48 | latent_loss = latent_loss.mean()
49 |
50 | if run == 'train':
51 | return recon_loss, latent_loss, S, out, ground_truth
52 | else:
53 | return ground_truth, img, out, source_image
54 |
55 | def run_step_custom(model, data, device, run='train'):
56 | img, S, ground_truth = process_data(data, device, dataset)
57 |
58 | out, latent_loss = model(img)
59 |
60 | out = out[:, :3] # first 3 channels of the prediction
61 |
62 | return out, ground_truth
63 |
64 |
65 | def jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder):
66 | for i, data in enumerate(tqdm(val_loader)):
67 | with torch.no_grad():
68 | source_images, input, prediction, source_images_original = run_step(model, data, device, run='val')
69 |
70 | source_hulls = input[:, :3]
71 | background = input[:, 3:]
72 |
73 | saves = {
74 | 'source': source_hulls,
75 | 'background': background,
76 | 'prediction': prediction,
77 | 'source_images': source_images,
78 | 'source_original': source_images_original
79 | }
80 |
81 | # if i % (len(val_loader) // 10) == 0 or run_type != 'train':
82 | if True:
83 | def denormalize(x):
84 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
85 |
86 | for name in saves:
87 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4"
88 | frames = saves[name].detach().cpu()
89 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames]
90 |
91 | # os.makedirs(sample_folder, exist_ok=True)
92 | save_frames_as_video(frames, saveas, fps=25)
93 |
94 | def base_validation(model, val_loader, device, epoch, i, run_type, sample_folder):
95 | def get_proper_shape(x):
96 | shape = x.shape
97 | return x.view(shape[0], -1, 3, shape[2], shape[3]).view(-1, 3, shape[2], shape[3])
98 |
99 | for val_i, data in enumerate(tqdm(val_loader)):
100 | with torch.no_grad():
101 | sample, _, out, source_img = run_step(model, data, device, 'val')
102 |
103 | if val_i % (len(val_loader)//10) == 0: # save 10 results
104 |
105 | def denormalize(x):
106 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
107 |
108 | save_as_sample = f"{sample_folder}/{val_i}_sample.mp4"
109 | save_as_out = f"{sample_folder}/{val_i}_out.mp4"
110 |
111 | sample = sample.detach().cpu()
112 | sample = [denormalize(x).permute(1, 2, 0).numpy() for x in sample]
113 |
114 | out = out.detach().cpu()
115 | out = [denormalize(x).permute(1, 2, 0).numpy() for x in out]
116 |
117 | save_frames_as_video(sample, save_as_sample, fps=25)
118 | save_frames_as_video(out, save_as_out, fps=25)
119 |
120 | def validation(model, val_loader, device, epoch, i, sample_folder, run_type='train'):
121 | if dataset >= 6:
122 | jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder)
123 | else:
124 | base_validation(model, val_loader, device, epoch, i, run_type, sample_folder)
125 |
126 |
127 | def flip_video(x):
128 | num = random.randint(0, 1)
129 | if num == 0:
130 | return torch.flip(x, [2])
131 | else:
132 | return x
133 |
134 |
135 | # trains the given discriminator for one step on a real batch and a detached fake batch
136 | def train_discriminator(opt, discriminator, real_batch, fake_batch):
137 | opt.zero_grad()
138 |
139 | real_disc_preds, _ = discriminator(real_batch)
140 | fake_disc_preds, _ = discriminator(fake_batch.detach())
141 |
142 | ones = torch.ones_like(real_disc_preds)
143 | zeros = torch.zeros_like(fake_disc_preds)
144 |
145 | l_discriminator = (gan_criterion(real_disc_preds, ones) + gan_criterion(fake_disc_preds, zeros))/2
146 |
147 | l_discriminator.backward()
148 | opt.step()
149 |
150 | return l_discriminator
151 |
152 | def train_generator(opt, image_discriminator, video_discriminator,
153 | fake_batch, recon_loss, latent_loss):
154 | opt.zero_grad()
155 |
156 | # image disc predictions
157 | fake_image_disc_preds, _ = image_discriminator(fake_batch)
158 | all_ones = torch.ones_like(fake_image_disc_preds)
159 | fake_image_disc_loss = gan_criterion(fake_image_disc_preds, all_ones)
160 |
161 | # video disc predictions
162 | fake_video_batch = fake_batch.unsqueeze(0).permute(0, 2, 1, 3, 4)
163 | fake_video_disc_preds, _ = video_discriminator(fake_video_batch)
164 | all_ones = torch.ones_like(fake_video_disc_preds)
165 | fake_video_disc_loss = gan_criterion(fake_video_disc_preds, all_ones)
166 |
167 | gen_loss = recon_loss \
168 | + LATENT_LOSS_WEIGHT * latent_loss \
169 | + image_disc_weight * fake_image_disc_loss \
170 | + video_disc_weight * fake_video_disc_loss
171 | gen_loss.backward(retain_graph=True)
172 | opt.step()
173 |
174 | return gen_loss
175 |
176 | # training is done in an alternating fashion
177 | # disc and gen alternate their training
178 | def train(model, patch_image_disc, patch_video_disc,
179 | gen_optim, image_disc_optim, video_disc_optim,
180 | loader, val_loader, scheduler, device,
181 | epoch, validate_at, checkpoint_dir, sample_folder):
182 |
183 | SAMPLE_FRAMES = 16 # sample 16 frames for the discriminator
184 |
185 | for i, data in enumerate(loader):
186 |
187 | global global_step
188 |
189 | global_step += 1
190 |
191 | model.train()
192 | patch_image_disc.train()
193 | patch_video_disc.train()
194 |
195 | model.zero_grad()
196 | patch_image_disc.zero_grad()
197 | patch_video_disc.zero_grad()
198 |
199 | # train video generator
200 | recon_loss, latent_loss, S, out, ground_truth = run_step(model, data, device)
201 |
202 | # print(f'Frames : {out.shape[0]}')
203 |
204 | # skip if the number of frames is less than SAMPLE_FRAMES
205 | if out.shape[0] < SAMPLE_FRAMES:
206 | print(f'Encountered {out.shape[0]} frames which is less than {SAMPLE_FRAMES}. Continuing ...')
207 | continue
208 |
209 | # sample SAMPLE_FRAMES frames from out
210 | fake_sampled = out[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256
211 | real_sampled = ground_truth[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256
212 |
213 | # print(f'fake sampled : {fake_sampled.shape}')
214 | gen_loss = train_generator(gen_optim, patch_image_disc, patch_video_disc, fake_sampled, recon_loss, latent_loss)
215 |
216 | # train image discriminator
217 | l_image_dis = train_discriminator(image_disc_optim, patch_image_disc, real_sampled, fake_sampled)
218 |
219 | # train video discriminator - adding the batch dimension
220 | l_video_dis = train_discriminator(video_disc_optim, patch_video_disc,
221 | real_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4),
222 | fake_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4))
223 |
224 | lr = gen_optim.param_groups[0]["lr"]
225 |
226 |         # logged after both the generator and discriminator steps have been performed
227 | print(f'Epoch : {epoch+1}, step : {global_step}, gen loss : {gen_loss.item():.5f}, image disc loss : {l_image_dis.item():.5f}, video disc loss : {l_video_dis.item():.5f}, lr : {lr:.5f}')
228 |
229 | # check if validation is required
230 | if i%validate_at == 0:
231 | # set the model to eval and generate the predictions
232 | model.eval()
233 |
234 | validation(model, val_loader, device, epoch, i, sample_folder)
235 |
236 | os.makedirs(checkpoint_dir, exist_ok=True)
237 |
238 | # save the vqvae2 generator weights
239 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(global_step).zfill(4)}.pt")
240 |
241 | # save the discriminator weights
242 | # save the video disc weights
243 | torch.save(patch_video_disc.state_dict(), f"{checkpoint_dir}/patch_video_disc_{epoch+1}_{str(global_step).zfill(4)}.pt")
244 |
245 | # save the image/content disc weights
246 | torch.save(patch_image_disc.state_dict(), f"{checkpoint_dir}/patch_image_disc_{epoch+1}_{str(global_step).zfill(4)}.pt")
247 |
248 |             # restore the model to training mode
249 | model.train()
250 |
251 |
252 | def main(args):
253 | device = "cuda"
254 |
255 | default_transform = transforms.Compose(
256 | [
257 | transforms.ToPILImage(),
258 | transforms.Resize(args.size),
259 | transforms.ToTensor(),
260 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
261 | ]
262 | )
263 |
264 | train_loader, val_loader, model, patch_image_disc, patch_video_disc, \
265 | image_disc_optim, video_disc_optim = \
266 | get_loaders_and_models(args, dataset, default_transform, device, test=args.test)
267 |
268 |
269 | # load the models on the gpu
270 | model = model.to(device)
271 | patch_image_disc = patch_image_disc.to(device)
272 | patch_video_disc = patch_video_disc.to(device)
273 |
274 | # loading the pretrained generator (vqvae2) model weights
275 | if args.ckpt:
276 | print(f'Loading pretrained generator model : {args.ckpt}')
277 | state_dict = torch.load(args.ckpt)
278 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() }
279 | try:
280 | model.module.load_state_dict(state_dict)
281 | except:
282 | model.load_state_dict(state_dict)
283 |
284 | # load the image and video disc models if required
285 | if args.load_disc:
286 | # image_disc_path = 'patch_image_disc_' + args.ckpt.split('_', 1)[1]
287 | # video_disc_path = 'patch_video_disc_' + args.ckpt.split('_', 1)[1]
288 |
289 | image_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_image_disc_' + osp.basename(args.ckpt).split('_', 1)[1])
290 | video_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_video_disc_' + osp.basename(args.ckpt).split('_', 1)[1])
291 |
292 | print(f'Loading pretrained disc models : {image_disc_path}, {video_disc_path}')
293 |
294 | image_disc_state_dict = torch.load(image_disc_path)
295 | image_disc_state_dict = { k.replace('module.', ''): v for k, v in image_disc_state_dict.items() }
296 |
297 | video_disc_state_dict = torch.load(video_disc_path)
298 | video_disc_state_dict = { k.replace('module.', ''): v for k, v in video_disc_state_dict.items() }
299 |
300 | try:
301 | patch_image_disc.module.load_state_dict(image_disc_state_dict)
302 | except:
303 | patch_image_disc.load_state_dict(image_disc_state_dict)
304 |
305 | try:
306 | patch_video_disc.module.load_state_dict(video_disc_state_dict)
307 | except:
308 | patch_video_disc.load_state_dict(video_disc_state_dict)
309 |
310 | if args.test:
311 | validation(model, val_loader, device, 0, 0, args.sample_folder, 'val')
312 | else:
313 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
314 |
315 | scheduler = None
316 |
317 | if args.sched == "cycle":
318 | scheduler = CycleScheduler(
319 | optimizer,
320 | args.lr,
321 | n_iter=len(train_loader) * args.epoch,
322 | momentum=None,
323 | warmup_proportion=0.05,
324 | )
325 |
326 | for i in range(args.epoch):
327 |
328 | train(model, patch_image_disc, patch_video_disc,
329 | optimizer, image_disc_optim, video_disc_optim,
330 | train_loader, val_loader, scheduler, device, i,
331 | args.validate_at, args.checkpoint_dir, args.sample_folder)
332 |
333 |
334 | def get_random_name(cipher_length=5):
335 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
336 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)])
337 |
338 | if __name__ == "__main__":
339 | parser = argparse.ArgumentParser()
340 | parser.add_argument("--n_gpu", type=int, default=1)
341 |
342 | port = random.randint(51000, 52000)
343 |
344 | # parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}")
345 | parser.add_argument("--batch_size", type=int, default=32)
346 | parser.add_argument("--size", type=int, default=256)
347 | parser.add_argument("--epoch", type=int, default=560)
348 | parser.add_argument("--lr", type=float, default=3e-4)
349 | parser.add_argument("--sched", type=str)
350 | parser.add_argument("--checkpoint_suffix", type=str, default='')
351 | parser.add_argument("--validate_at", type=int, default=1024)
352 | parser.add_argument("--ckpt", required=False)
353 | parser.add_argument("--test", action='store_true', required=False)
354 | parser.add_argument("--gray", action='store_true', required=False)
355 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty')
356 | parser.add_argument("--crossid", action='store_true', required=False)
357 | parser.add_argument("--sample_folder", type=str, default='samples')
358 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint')
359 | parser.add_argument("--validation_folder", type=str, default=None)
360 | parser.add_argument("--custom_validation", action='store_true', required=False)
361 | parser.add_argument("--load_disc", action='store_true', required=False)
362 |
363 | args = parser.parse_args()
364 |
365 | # args.n_gpu = torch.cuda.device_count()
366 | current_run = get_random_name()
367 |
368 | # sample_folder = sample_folder.format(args.checkpoint_suffix)
369 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run)
370 | os.makedirs(args.sample_folder, exist_ok=True)
371 |
372 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run)
373 | # os.makedirs(args.checkpoint_dir, exist_ok=True)
374 |
375 | # checkpoint_dir = checkpoint_dir.format(args.checkpoint_suffix)
376 |
377 | print(args, flush=True)
378 |
379 | print(f'Weight configuration used, latent loss : {LATENT_LOSS_WEIGHT}, image disc weight : {image_disc_weight}, video disc weight : {video_disc_weight}')
380 |
381 | # print(f'Weight configuration used : + \
382 | # {LATENT_LOSS_WEIGHT}, 2d disc weight : {G_LOSS_2D_WEIGHT}, + \
383 | # temporal disc weight : {G_LOSS_3D_WEIGHT}')
384 |
385 | # dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,))
386 | main(args)
--------------------------------------------------------------------------------
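
The alternating generator/discriminator schedule used by the trainer above, reduced to a minimal self-contained sketch. The modules, shapes, and losses here are toy stand-ins; only the update order and the ones/zeros/`detach()` pattern mirror `train_generator` and `train_discriminator`.

```
import torch
from torch import nn, optim

gan_criterion = nn.BCEWithLogitsLoss()

# toy stand-ins for the VQVAE generator and a patch image discriminator
gen = nn.Conv2d(3, 3, 3, padding=1)
disc = nn.Sequential(nn.Conv2d(3, 1, 4, stride=4), nn.Flatten())
gen_opt = optim.Adam(gen.parameters(), lr=3e-4)
disc_opt = optim.Adam(disc.parameters(), lr=3e-4)

real = torch.rand(16, 3, 64, 64)   # 16 sampled frames, as with SAMPLE_FRAMES
fake = gen(real)

# 1) generator step: reconstruction term + "fool the discriminator" term
gen_opt.zero_grad()
fake_logits = disc(fake)
g_loss = ((fake - real) ** 2).mean() + gan_criterion(fake_logits, torch.ones_like(fake_logits))
g_loss.backward()
gen_opt.step()

# 2) discriminator step: real frames -> 1, detached fake frames -> 0
disc_opt.zero_grad()
real_logits, fake_logits = disc(real), disc(fake.detach())
d_loss = 0.5 * (gan_criterion(real_logits, torch.ones_like(real_logits))
                + gan_criterion(fake_logits, torch.zeros_like(fake_logits)))
d_loss.backward()
disc_opt.step()
```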
/disc_trainers/train_vqvae_mocogan_disc_perceptual.py:
--------------------------------------------------------------------------------
1 | # This code uses the mocogan content and temporal discriminators with perceptual loss
2 |
3 | import argparse
4 | import sys
5 | import os
6 | import random
7 | import os.path as osp
8 |
9 | import torch
10 | from torch import nn, optim
11 | import torch.nn.functional as F
12 |
13 | from torchvision import datasets, transforms, utils
14 |
15 | from TemporalAlignment.models import mocogan_discriminator
16 |
17 | from tqdm import tqdm
18 |
19 | from scheduler import CycleScheduler
20 |
21 | from utils import *
22 | from config import DATASET, LATENT_LOSS_WEIGHT, PERCEPTUAL_LOSS_WEIGHT, image_disc_weight, video_disc_weight, SAMPLE_SIZE_FOR_VISUALIZATION
23 |
24 | criterion = nn.MSELoss()
25 |
26 | gan_criterion = nn.BCEWithLogitsLoss()
27 |
28 | global_step = 0
29 |
30 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION
31 |
32 | dataset = DATASET
33 |
34 | CONST_FRAMES_TO_CHECK = 16
35 |
36 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results'
37 | # sample_folder = '/home2/bipasha31/python_scripts/CurrentWork/samples/{}'
38 |
39 |
40 | def run_step(model, vqlpips, data, device, run='train'):
41 | img, S, ground_truth, source_image = process_data(data, device, dataset)
42 |
43 | out, latent_loss = model(img)
44 |
45 | out = out[:, :3]
46 |
47 | recon_loss = criterion(out, ground_truth)
48 | latent_loss = latent_loss.mean()
49 |
50 | perceptual_loss = vqlpips(ground_truth, out)
51 |
52 | if run == 'train':
53 | return recon_loss, latent_loss, S, out, ground_truth, perceptual_loss
54 | else:
55 | return ground_truth, img, out, source_image
56 |
57 | def run_step_custom(model, data, device, run='train'):
58 | img, S, ground_truth = process_data(data, device, dataset)
59 |
60 | out, latent_loss = model(img)
61 |
62 | out = out[:, :3] # first 3 channels of the prediction
63 |
64 | return out, ground_truth
65 |
66 |
67 | def jitter_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder):
68 | for i, data in enumerate(tqdm(val_loader)):
69 | with torch.no_grad():
70 | source_images, input, prediction, source_images_original = run_step(model, vqlpips, data, device, run='val')
71 |
72 | source_hulls = input[:, :3]
73 | background = input[:, 3:]
74 |
75 | saves = {
76 | 'source': source_hulls,
77 | 'background': background,
78 | 'prediction': prediction,
79 | 'source_images': source_images,
80 | 'source_original': source_images_original
81 | }
82 |
83 | # if i % (len(val_loader) // 10) == 0 or run_type != 'train':
84 | if True:
85 | def denormalize(x):
86 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
87 |
88 | for name in saves:
89 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4"
90 | frames = saves[name].detach().cpu()
91 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames]
92 |
93 | # os.makedirs(sample_folder, exist_ok=True)
94 | save_frames_as_video(frames, saveas, fps=25)
95 |
96 | def base_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder):
97 | def get_proper_shape(x):
98 | shape = x.shape
99 | return x.view(shape[0], -1, 3, shape[2], shape[3]).view(-1, 3, shape[2], shape[3])
100 |
101 | for val_i, data in enumerate(tqdm(val_loader)):
102 | with torch.no_grad():
103 | sample, _, out, source_img = run_step(model, vqlpips, data, device, 'val')
104 |
105 | if val_i % (len(val_loader)//10) == 0: # save 10 results
106 |
107 | def denormalize(x):
108 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
109 |
110 | save_as_sample = f"{sample_folder}/{val_i}_sample.mp4"
111 | save_as_out = f"{sample_folder}/{val_i}_out.mp4"
112 |
113 | sample = sample.detach().cpu()
114 | sample = [denormalize(x).permute(1, 2, 0).numpy() for x in sample]
115 |
116 | out = out.detach().cpu()
117 | out = [denormalize(x).permute(1, 2, 0).numpy() for x in out]
118 |
119 | save_frames_as_video(sample, save_as_sample, fps=25)
120 | save_frames_as_video(out, save_as_out, fps=25)
121 |
122 | def validation(model, vqlpips, val_loader, device, epoch, i, sample_folder, run_type='train'):
123 | if dataset >= 6:
124 | jitter_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder)
125 | else:
126 | base_validation(model, vqlpips, val_loader, device, epoch, i, run_type, sample_folder)
127 |
128 |
129 | def flip_video(x):
130 | num = random.randint(0, 1)
131 | if num == 0:
132 | return torch.flip(x, [2])
133 | else:
134 | return x
135 |
136 |
137 | # trains the given discriminator for one step on a real batch and a detached fake batch
138 | def train_discriminator(opt, discriminator, real_batch, fake_batch):
139 | opt.zero_grad()
140 |
141 | real_disc_preds, _ = discriminator(real_batch)
142 | fake_disc_preds, _ = discriminator(fake_batch.detach())
143 |
144 | ones = torch.ones_like(real_disc_preds)
145 | zeros = torch.zeros_like(fake_disc_preds)
146 |
147 | l_discriminator = (gan_criterion(real_disc_preds, ones) + gan_criterion(fake_disc_preds, zeros))/2
148 |
149 | l_discriminator.backward()
150 | opt.step()
151 |
152 | return l_discriminator
153 |
154 | def train_generator(opt, image_discriminator, video_discriminator,
155 | fake_batch, recon_loss, latent_loss, perceptual_loss):
156 | opt.zero_grad()
157 |
158 | # image disc predictions
159 | fake_image_disc_preds, _ = image_discriminator(fake_batch)
160 | all_ones = torch.ones_like(fake_image_disc_preds)
161 | fake_image_disc_loss = gan_criterion(fake_image_disc_preds, all_ones)
162 |
163 | # video disc predictions
164 | fake_video_batch = fake_batch.unsqueeze(0).permute(0, 2, 1, 3, 4)
165 | fake_video_disc_preds, _ = video_discriminator(fake_video_batch)
166 | all_ones = torch.ones_like(fake_video_disc_preds)
167 | fake_video_disc_loss = gan_criterion(fake_video_disc_preds, all_ones)
168 |
169 | gen_loss = recon_loss \
170 | + LATENT_LOSS_WEIGHT * latent_loss \
171 | + image_disc_weight * fake_image_disc_loss \
172 | + video_disc_weight * fake_video_disc_loss \
173 | + PERCEPTUAL_LOSS_WEIGHT * perceptual_loss
174 |
175 | gen_loss.backward(retain_graph=True)
176 | opt.step()
177 |
178 | return gen_loss
179 |
180 | # training is done in an alternating fashion
181 | # disc and gen alternate their training
182 | def train(model, patch_image_disc, patch_video_disc,
183 | gen_optim, image_disc_optim, video_disc_optim,
184 | loader, val_loader, scheduler, device,
185 | epoch, validate_at, checkpoint_dir, sample_folder, vqlpips):
186 |
187 | SAMPLE_FRAMES = 16 # sample 16 frames for the discriminator
188 |
189 | for i, data in enumerate(loader):
190 |
191 | global global_step
192 |
193 | global_step += 1
194 |
195 | model.train()
196 | patch_image_disc.train()
197 | patch_video_disc.train()
198 |
199 | model.zero_grad()
200 | patch_image_disc.zero_grad()
201 | patch_video_disc.zero_grad()
202 |
203 | # train video generator
204 | recon_loss, latent_loss, S, out, ground_truth, perceptual_loss = run_step(model, vqlpips, data, device)
205 |
206 | # print(f'Frames : {out.shape[0]}')
207 |
208 | # skip if the number of frames is less than SAMPLE_FRAMES
209 | if out.shape[0] < SAMPLE_FRAMES:
210 | print(f'Encountered {out.shape[0]} frames which is less than {SAMPLE_FRAMES}. Continuing ...')
211 | continue
212 |
213 | # sample SAMPLE_FRAMES frames from out
214 | fake_sampled = out[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256
215 | real_sampled = ground_truth[:SAMPLE_FRAMES] # dim -> SAMPLE_FRAMES x 3 x 256 x 256
216 |
217 | # print(f'fake sampled : {fake_sampled.shape}')
218 | gen_loss = train_generator(gen_optim, patch_image_disc, patch_video_disc, fake_sampled, recon_loss, latent_loss, perceptual_loss)
219 |
220 | # train image discriminator
221 | l_image_dis = train_discriminator(image_disc_optim, patch_image_disc, real_sampled, fake_sampled)
222 |
223 | # train video discriminator - adding the batch dimension
224 | l_video_dis = train_discriminator(video_disc_optim, patch_video_disc,
225 | real_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4),
226 | fake_sampled.unsqueeze(0).permute(0, 2, 1, 3, 4))
227 |
228 | lr = gen_optim.param_groups[0]["lr"]
229 |
230 |         # logged after both the generator and discriminator steps have been performed
231 | print(f'Epoch : {epoch+1}, step : {global_step}, gen loss : {gen_loss.item():.5f}, image disc loss : {l_image_dis.item():.5f}, video disc loss : {l_video_dis.item():.5f}, lr : {lr:.5f}')
232 |
233 | # check if validation is required
234 | if i%validate_at == 0:
235 | # set the model to eval and generate the predictions
236 | model.eval()
237 |
238 | validation(model, vqlpips, val_loader, device, epoch, i, sample_folder)
239 |
240 | os.makedirs(checkpoint_dir, exist_ok=True)
241 |
242 | # save the vqvae2 generator weights
243 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(global_step).zfill(4)}.pt")
244 |
245 | # save the discriminator weights
246 | # save the video disc weights
247 | torch.save(patch_video_disc.state_dict(), f"{checkpoint_dir}/patch_video_disc_{epoch+1}_{str(global_step).zfill(4)}.pt")
248 |
249 | # save the image/content disc weights
250 | torch.save(patch_image_disc.state_dict(), f"{checkpoint_dir}/patch_image_disc_{epoch+1}_{str(global_step).zfill(4)}.pt")
251 |
252 |             # restore the model to training mode
253 | model.train()
254 |
255 |
256 | def main(args):
257 | device = "cuda"
258 |
259 | default_transform = transforms.Compose(
260 | [
261 | transforms.ToPILImage(),
262 | transforms.Resize(args.size),
263 | transforms.ToTensor(),
264 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
265 | ]
266 | )
267 |
268 | train_loader, val_loader, model, patch_image_disc, patch_video_disc, \
269 | image_disc_optim, video_disc_optim, vqlpips = \
270 | get_loaders_and_models(args, dataset, default_transform, device, test=args.test)
271 |
272 |
273 | # load the models on the gpu
274 | model = model.to(device)
275 | patch_image_disc = patch_image_disc.to(device)
276 | patch_video_disc = patch_video_disc.to(device)
277 | vqlpips = vqlpips.to(device)
278 |
279 | # loading the pretrained generator (vqvae2) model weights
280 | if args.ckpt:
281 | print(f'Loading pretrained generator model : {args.ckpt}')
282 | state_dict = torch.load(args.ckpt)
283 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() }
284 | try:
285 | model.module.load_state_dict(state_dict)
286 | except:
287 | model.load_state_dict(state_dict)
288 |
289 | # load the image and video disc models if required
290 | if args.load_disc:
291 | # image_disc_path = 'patch_image_disc_' + args.ckpt.split('_', 1)[1]
292 | # video_disc_path = 'patch_video_disc_' + args.ckpt.split('_', 1)[1]
293 |
294 | image_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_image_disc_' + osp.basename(args.ckpt).split('_', 1)[1])
295 | video_disc_path = osp.join(osp.dirname(args.ckpt), 'patch_video_disc_' + osp.basename(args.ckpt).split('_', 1)[1])
296 |
297 | print(f'Loading pretrained disc models : {image_disc_path}, {video_disc_path}')
298 |
299 | image_disc_state_dict = torch.load(image_disc_path)
300 | image_disc_state_dict = { k.replace('module.', ''): v for k, v in image_disc_state_dict.items() }
301 |
302 | video_disc_state_dict = torch.load(video_disc_path)
303 | video_disc_state_dict = { k.replace('module.', ''): v for k, v in video_disc_state_dict.items() }
304 |
305 | try:
306 | patch_image_disc.module.load_state_dict(image_disc_state_dict)
307 | except:
308 | patch_image_disc.load_state_dict(image_disc_state_dict)
309 |
310 | try:
311 | patch_video_disc.module.load_state_dict(video_disc_state_dict)
312 | except:
313 | patch_video_disc.load_state_dict(video_disc_state_dict)
314 |
315 | if args.test:
316 | validation(model, vqlpips, val_loader, device, 0, 0, args.sample_folder, 'val')
317 | else:
318 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
319 |
320 | scheduler = None
321 |
322 | if args.sched == "cycle":
323 | scheduler = CycleScheduler(
324 | optimizer,
325 | args.lr,
326 | n_iter=len(train_loader) * args.epoch,
327 | momentum=None,
328 | warmup_proportion=0.05,
329 | )
330 |
331 | for i in range(args.epoch):
332 |
333 | train(model, patch_image_disc, patch_video_disc,
334 | optimizer, image_disc_optim, video_disc_optim,
335 | train_loader, val_loader, scheduler, device, i,
336 | args.validate_at, args.checkpoint_dir, args.sample_folder, vqlpips)
337 |
338 |
339 | def get_random_name(cipher_length=5):
340 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
341 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)])
342 |
343 | if __name__ == "__main__":
344 | parser = argparse.ArgumentParser()
345 | parser.add_argument("--n_gpu", type=int, default=1)
346 |
347 | port = random.randint(51000, 52000)
348 |
349 | # parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}")
350 | parser.add_argument("--batch_size", type=int, default=32)
351 | parser.add_argument("--size", type=int, default=256)
352 | parser.add_argument("--epoch", type=int, default=560)
353 | parser.add_argument("--lr", type=float, default=3e-4)
354 | parser.add_argument("--sched", type=str)
355 | parser.add_argument("--checkpoint_suffix", type=str, default='')
356 | parser.add_argument("--validate_at", type=int, default=1024)
357 | parser.add_argument("--ckpt", required=False)
358 | parser.add_argument("--test", action='store_true', required=False)
359 | parser.add_argument("--gray", action='store_true', required=False)
360 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty')
361 | parser.add_argument("--crossid", action='store_true', required=False)
362 | parser.add_argument("--sample_folder", type=str, default='samples')
363 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint')
364 | parser.add_argument("--validation_folder", type=str, default=None)
365 | parser.add_argument("--custom_validation", action='store_true', required=False)
366 | parser.add_argument("--load_disc", action='store_true', required=False)
367 |
368 | args = parser.parse_args()
369 |
370 | # args.n_gpu = torch.cuda.device_count()
371 | current_run = get_random_name()
372 |
373 | # sample_folder = sample_folder.format(args.checkpoint_suffix)
374 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run)
375 | os.makedirs(args.sample_folder, exist_ok=True)
376 |
377 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run)
378 | # os.makedirs(args.checkpoint_dir, exist_ok=True)
379 |
380 | # checkpoint_dir = checkpoint_dir.format(args.checkpoint_suffix)
381 |
382 | print(args, flush=True)
383 |
384 | print(f'Weight configuration used, latent loss : {LATENT_LOSS_WEIGHT}, \
385 | image disc weight : {image_disc_weight}, video disc weight : {video_disc_weight} \
386 | perceptual loss weight : {PERCEPTUAL_LOSS_WEIGHT}')
387 |
388 | main(args)
--------------------------------------------------------------------------------
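
This trainer adds a perceptual term computed by `vqlpips` (see `models/lpips.py`). The sketch below is a generic VGG16 feature-matching loss illustrating the same idea; it is not the repository's LPIPS implementation and assumes torchvision >= 0.13 for the `weights=` API.

```
import torch
from torch import nn
from torchvision import models

class VGGPerceptualLoss(nn.Module):
    """Illustrative feature-matching loss; a stand-in for the LPIPS loss in models/lpips.py."""
    def __init__(self, num_layers=16):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features[:num_layers].eval()
        for p in vgg.parameters():
            p.requires_grad_(False)
        self.vgg = vgg

    def forward(self, target, prediction):
        # the trainers work in [-1, 1]; map to [0, 1] before the VGG features
        # (ImageNet normalisation omitted for brevity)
        t, p = (target + 1) / 2, (prediction + 1) / 2
        return nn.functional.mse_loss(self.vgg(t), self.vgg(p))
```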
/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed import (
2 | get_rank,
3 | get_local_rank,
4 | is_primary,
5 | synchronize,
6 | get_world_size,
7 | all_reduce,
8 | all_gather,
9 | reduce_dict,
10 | data_sampler,
11 | LOCAL_PROCESS_GROUP,
12 | )
13 | from .launch import launch
14 |
--------------------------------------------------------------------------------
/distributed/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils import data
7 |
8 |
9 | LOCAL_PROCESS_GROUP = None
10 |
11 |
12 | def is_primary():
13 | return get_rank() == 0
14 |
15 |
16 | def get_rank():
17 | if not dist.is_available():
18 | return 0
19 |
20 | if not dist.is_initialized():
21 | return 0
22 |
23 | return dist.get_rank()
24 |
25 |
26 | def get_local_rank():
27 | if not dist.is_available():
28 | return 0
29 |
30 | if not dist.is_initialized():
31 | return 0
32 |
33 | if LOCAL_PROCESS_GROUP is None:
34 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None")
35 |
36 | return dist.get_rank(group=LOCAL_PROCESS_GROUP)
37 |
38 |
39 | def synchronize():
40 | if not dist.is_available():
41 | return
42 |
43 | if not dist.is_initialized():
44 | return
45 |
46 | world_size = dist.get_world_size()
47 |
48 | if world_size == 1:
49 | return
50 |
51 | dist.barrier()
52 |
53 |
54 | def get_world_size():
55 | if not dist.is_available():
56 | return 1
57 |
58 | if not dist.is_initialized():
59 | return 1
60 |
61 | return dist.get_world_size()
62 |
63 |
64 | def all_reduce(tensor, op=dist.ReduceOp.SUM):
65 | world_size = get_world_size()
66 |
67 | if world_size == 1:
68 | return tensor
69 |
70 | dist.all_reduce(tensor, op=op)
71 |
72 | return tensor
73 |
74 |
75 | def all_gather(data):
76 | world_size = get_world_size()
77 |
78 | if world_size == 1:
79 | return [data]
80 |
81 | buffer = pickle.dumps(data)
82 | storage = torch.ByteStorage.from_buffer(buffer)
83 | tensor = torch.ByteTensor(storage).to("cuda")
84 |
85 | local_size = torch.IntTensor([tensor.numel()]).to("cuda")
86 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)]
87 | dist.all_gather(size_list, local_size)
88 | size_list = [int(size.item()) for size in size_list]
89 | max_size = max(size_list)
90 |
91 | tensor_list = []
92 | for _ in size_list:
93 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
94 |
95 | if local_size != max_size:
96 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
97 | tensor = torch.cat((tensor, padding), 0)
98 |
99 | dist.all_gather(tensor_list, tensor)
100 |
101 | data_list = []
102 |
103 | for size, tensor in zip(size_list, tensor_list):
104 | buffer = tensor.cpu().numpy().tobytes()[:size]
105 | data_list.append(pickle.loads(buffer))
106 |
107 | return data_list
108 |
109 |
110 | def reduce_dict(input_dict, average=True):
111 | world_size = get_world_size()
112 |
113 | if world_size < 2:
114 | return input_dict
115 |
116 | with torch.no_grad():
117 | keys = []
118 | values = []
119 |
120 | for k in sorted(input_dict.keys()):
121 | keys.append(k)
122 | values.append(input_dict[k])
123 |
124 | values = torch.stack(values, 0)
125 | dist.reduce(values, dst=0)
126 |
127 | if dist.get_rank() == 0 and average:
128 | values /= world_size
129 |
130 | reduced_dict = {k: v for k, v in zip(keys, values)}
131 |
132 | return reduced_dict
133 |
134 |
135 | def data_sampler(dataset, shuffle, distributed):
136 | if distributed:
137 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
138 |
139 | if shuffle:
140 | return data.RandomSampler(dataset)
141 |
142 | else:
143 | return data.SequentialSampler(dataset)
144 |
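145 | if __name__ == "__main__":
146 |     # Hedged usage sketch (not part of the training pipeline): in a single,
147 |     # non-distributed process, data_sampler falls back to a plain Random/
148 |     # SequentialSampler, so the same call site works with and without DDP.
149 |     dataset = data.TensorDataset(torch.arange(10).float().unsqueeze(1))
150 |     sampler = data_sampler(dataset, shuffle=True, distributed=False)
151 |     loader = data.DataLoader(dataset, batch_size=4, sampler=sampler)
152 |     for (batch,) in loader:
153 |         print(batch.flatten().tolist())
154 | 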
--------------------------------------------------------------------------------
/distributed/launch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import distributed as dist
5 | from torch import multiprocessing as mp
6 |
7 | import distributed as dist_fn
8 |
9 |
10 | def find_free_port():
11 | import socket
12 |
13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
14 |
15 | sock.bind(("", 0))
16 | port = sock.getsockname()[1]
17 | sock.close()
18 |
19 | return port
20 |
21 |
22 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
23 | world_size = n_machine * n_gpu_per_machine
24 |
25 | if world_size > 1:
26 | if "OMP_NUM_THREADS" not in os.environ:
27 | os.environ["OMP_NUM_THREADS"] = "1"
28 |
29 | if dist_url == "auto":
30 | if n_machine != 1:
31 | raise ValueError('dist_url="auto" not supported in multi-machine jobs')
32 |
33 | port = find_free_port()
34 | dist_url = f"tcp://127.0.0.1:{port}"
35 |
36 | if n_machine > 1 and dist_url.startswith("file://"):
37 | raise ValueError(
38 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
39 | )
40 |
41 | mp.spawn(
42 | distributed_worker,
43 | nprocs=n_gpu_per_machine,
44 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
45 | daemon=False,
46 | )
47 |
48 | else:
49 | fn(*args)
50 |
51 |
52 | def distributed_worker(
53 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
54 | ):
55 | if not torch.cuda.is_available():
56 | raise OSError("CUDA is not available. Please check your environments")
57 |
58 | global_rank = machine_rank * n_gpu_per_machine + local_rank
59 |
60 | try:
61 | dist.init_process_group(
62 | backend="NCCL",
63 | init_method=dist_url,
64 | world_size=world_size,
65 | rank=global_rank,
66 | )
67 |
68 | except Exception:
69 | raise OSError("failed to initialize NCCL groups")
70 |
71 | dist_fn.synchronize()
72 |
73 | if n_gpu_per_machine > torch.cuda.device_count():
74 | raise ValueError(
75 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
76 | )
77 |
78 | torch.cuda.set_device(local_rank)
79 |
80 | if dist_fn.LOCAL_PROCESS_GROUP is not None:
81 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")
82 |
83 | n_machine = world_size // n_gpu_per_machine
84 |
85 | for i in range(n_machine):
86 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
87 | pg = dist.new_group(ranks_on_i)
88 |
89 | if i == machine_rank:
90 | dist_fn.distributed.LOCAL_PROCESS_GROUP = pg
91 |
92 | fn(*args)
93 |
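94 | if __name__ == "__main__":
95 |     # Hedged usage sketch: with n_gpu_per_machine=1 launch() simply calls the
96 |     # function in-process; with more GPUs it spawns one worker per GPU and
97 |     # initialises NCCL through dist_url ("auto" picks a free local TCP port).
98 |     def _demo():
99 |         print(f"rank {dist_fn.get_rank()} of {dist_fn.get_world_size()}")
100 | 
101 |     launch(_demo, n_gpu_per_machine=1, dist_url="auto")
102 | 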
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: faceoff
2 | channels:
3 | - anaconda
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=2_gnu
9 | - backcall=0.2.0=pyhd3eb1b0_0
10 | - ca-certificates=2022.07.19=h06a4308_0
11 | - debugpy=1.5.1=py37h295c915_0
12 | - decorator=5.1.1=pyhd3eb1b0_0
13 | - entrypoints=0.4=py37h06a4308_0
14 | - ipykernel=6.9.1=py37h06a4308_0
15 | - ipython=7.31.1=py37h06a4308_1
16 | - jedi=0.18.1=py37h06a4308_1
17 | - jupyter_client=7.2.2=py37h06a4308_0
18 | - jupyter_core=4.10.0=py37h06a4308_0
19 | - ld_impl_linux-64=2.39=hcc3a1bd_1
20 | - libffi=3.4.2=h7f98852_5
21 | - libgcc-ng=12.2.0=h65d4601_19
22 | - libgomp=12.2.0=h65d4601_19
23 | - libnsl=2.0.0=h7f98852_0
24 | - libsodium=1.0.18=h7b6447c_0
25 | - libsqlite=3.40.0=h753d276_0
26 | - libstdcxx-ng=12.2.0=h46fd767_19
27 | - libzlib=1.2.13=h166bdaf_4
28 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
29 | - ncurses=6.3=h27087fc_1
30 | - nest-asyncio=1.5.5=py37h06a4308_0
31 | - openssl=3.0.7=h0b41bf4_1
32 | - parso=0.8.3=pyhd3eb1b0_0
33 | - pexpect=4.8.0=pyhd3eb1b0_3
34 | - pickleshare=0.7.5=pyhd3eb1b0_1003
35 | - pip=22.3.1=pyhd8ed1ab_0
36 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
37 | - ptyprocess=0.7.0=pyhd3eb1b0_2
38 | - pygments=2.11.2=pyhd3eb1b0_0
39 | - python=3.7.12=hf930737_100_cpython
40 | - python-dateutil=2.8.2=pyhd3eb1b0_0
41 | - pyzmq=23.2.0=py37h6a678d5_0
42 | - readline=8.1.2=h0f457ee_0
43 | - setuptools=65.6.3=pyhd8ed1ab_0
44 | - six=1.16.0=pyhd3eb1b0_1
45 | - sqlite=3.40.0=h4ff8645_0
46 | - tk=8.6.12=h27826a3_0
47 | - tornado=6.1=py37h27cfd23_0
48 | - traitlets=5.1.1=pyhd3eb1b0_0
49 | - wcwidth=0.2.5=pyhd3eb1b0_0
50 | - wheel=0.38.4=pyhd8ed1ab_0
51 | - xz=5.2.6=h166bdaf_0
52 | - zeromq=4.3.4=h2531618_0
53 | - pip:
54 | - bchlib==0.14.0
55 | - certifi==2022.12.7
56 | - charset-normalizer==2.1.1
57 | - cycler==0.11.0
58 | - fonttools==4.38.0
59 | - idna==3.4
60 | - imageio==2.24.0
61 | - joblib==1.2.0
62 | - kiwisolver==1.4.4
63 | - matplotlib==3.5.3
64 | - networkx==2.6.3
65 | - numpy==1.21.6
66 | - nvidia-cublas-cu11==11.10.3.66
67 | - nvidia-cuda-nvrtc-cu11==11.7.99
68 | - nvidia-cuda-runtime-cu11==11.7.99
69 | - nvidia-cudnn-cu11==8.5.0.96
70 | - opencv-python==4.6.0.66
71 | - packaging==22.0
72 | - pillow==9.3.0
73 | - pyparsing==3.0.9
74 | - pywavelets==1.3.0
75 | - requests==2.28.1
76 | - scikit-image==0.19.3
77 | - scikit-learn==1.0.2
78 | - scipy==1.7.3
79 | - sklearn==0.0.post1
80 | - stn==1.0.1
81 | - threadpoolctl==3.1.0
82 | - tifffile==2021.11.2
83 | - torch==1.13.1
84 | - torchvision==0.14.1
85 | - tqdm==4.64.1
86 | - typing-extensions==4.4.0
87 | - urllib3==1.26.13
88 | - wand==0.6.11
89 | prefix: /home2/aditya1/miniconda3/envs/faceoff
90 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.lpips import LPIPS
6 | from models.discriminator import NLayerDiscriminator, weights_init
7 |
8 | criterion = nn.L1Loss()
9 |
10 | def adopt_weight(weight, global_step, threshold=0, value=0.):
11 | if global_step < threshold:
12 | weight = value
13 | return weight
14 |
15 | def hinge_d_loss(logits_real, logits_fake):
16 | loss_real = torch.mean(F.relu(1. - logits_real))
17 | loss_fake = torch.mean(F.relu(1. + logits_fake))
18 | d_loss = 0.5 * (loss_real + loss_fake)
19 | return d_loss
20 |
21 | def vanilla_d_loss(logits_real, logits_fake):
22 | d_loss = 0.5 * (
23 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
24 | torch.mean(torch.nn.functional.softplus(logits_fake)))
25 | return d_loss
26 |
27 | class VQLPIPS(nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 | self.perceptual_loss = LPIPS().eval()
31 |
32 | def forward(self, targets, reconstructions):
33 | return self.perceptual_loss(targets.contiguous(), reconstructions.contiguous()).mean()
34 |
35 | class VQLPIPSWithDiscriminator(nn.Module):
36 | def __init__(self, disc_start,
37 | disc_num_layers=3, disc_in_channels=3,
38 | disc_factor=1.0, disc_weight=0.8, use_actnorm=False,
39 | disc_ndf=64, disc_loss="hinge"):
40 | super().__init__()
41 | assert disc_loss in ["hinge", "vanilla"]
42 |
43 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
44 | n_layers=disc_num_layers,
45 | use_actnorm=use_actnorm,
46 | ndf=disc_ndf
47 | ).apply(weights_init)
48 |
49 | self.discriminator_iter_start = disc_start
50 |
51 | self.perceptual_loss = LPIPS().eval()
52 | if disc_loss == "hinge":
53 | self.disc_loss = hinge_d_loss
54 | elif disc_loss == "vanilla":
55 | self.disc_loss = vanilla_d_loss
56 | else:
57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59 | self.disc_factor = disc_factor
60 | self.discriminator_weight = disc_weight
61 |
62 | self.perceptual_weight = 1.0
63 |
64 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
65 | if last_layer is not None:
66 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
67 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
68 | else:
69 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
70 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
71 |
72 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
73 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
74 | d_weight = d_weight * self.discriminator_weight
75 | return d_weight
76 |
77 | def forward(self, rec_loss, targets, reconstructions):
78 | p_loss = self.perceptual_loss(targets.contiguous(), reconstructions.contiguous())
79 | rec_loss = rec_loss + self.perceptual_weight * p_loss
80 |
81 | nll_loss = torch.mean(rec_loss)
82 |
83 | return p_loss.mean(), nll_loss
84 |
85 | def second_forward(self, rec_loss, targets, reconstructions, optimizer_idx,
86 | global_step, perceptual_loss=False, last_layer=None):
87 |
88 | if perceptual_loss:
89 | p_loss = self.perceptual_loss(targets.contiguous(), reconstructions.contiguous())
90 | rec_loss = rec_loss + self.perceptual_weight * p_loss
91 | else:
92 | p_loss = torch.zeros(1, device=reconstructions.device)  # keep a tensor so p_loss.mean() below stays valid
93 |
94 | nll_loss = torch.mean(rec_loss)
95 |
96 | # now the GAN part
97 | if optimizer_idx == 0:
98 | # generator update
99 | logits_fake = self.discriminator(reconstructions.contiguous())
100 | # g_loss = -torch.mean(logits_fake)
101 | g_loss = criterion(logits_fake, torch.ones_like(logits_fake, device=logits_fake.device))
102 |
103 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
104 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
105 | return d_weight * disc_factor * g_loss, p_loss.mean(), nll_loss
106 |
107 | if optimizer_idx == 1:
108 | # second pass for discriminator update
109 | logits_real = self.discriminator(targets.contiguous().detach())
110 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
111 |
112 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
113 | # d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
114 | loss_real = criterion(logits_real, torch.ones_like(logits_real, device=logits_real.device))
115 | loss_fake = criterion(logits_fake, torch.zeros_like(logits_fake, device=logits_fake.device))
116 | d_loss = disc_factor * (loss_real + loss_fake).mean()
117 |
118 | return d_loss, p_loss.mean(), nll_loss
119 |
120 | class ContrastiveLoss(torch.nn.Module):
121 | """
122 | Contrastive loss function.
123 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
124 | """
125 |
126 | def __init__(self, margin=2.0):
127 | super(ContrastiveLoss, self).__init__()
128 | self.margin = margin
129 |
130 | def forward(self, output1, output2, label):
131 | euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)
132 | loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
133 | (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
134 |
135 |
136 | return loss_contrastive
137 |
138 | class SiameseNetworkFaceSimilarity(nn.Module):
139 | def __init__(self):
140 | super(SiameseNetworkFaceSimilarity, self).__init__()
141 | self.cnn1 = nn.Sequential(
142 | nn.ReflectionPad2d(1),
143 | nn.Conv2d(1, 4, kernel_size=3),
144 | nn.ReLU(inplace=True),
145 | nn.BatchNorm2d(4),
146 |
147 | nn.ReflectionPad2d(1),
148 | nn.Conv2d(4, 8, kernel_size=3),
149 | nn.ReLU(inplace=True),
150 | nn.BatchNorm2d(8),
151 |
152 | nn.ReflectionPad2d(1),
153 | nn.Conv2d(8, 8, kernel_size=3),
154 | nn.ReLU(inplace=True),
155 | nn.BatchNorm2d(8),
156 | )
157 |
158 | self.fc1 = nn.Sequential(
159 | nn.Linear(8*256*256, 500),
160 | nn.ReLU(inplace=True),
161 |
162 | nn.Linear(500, 500),
163 | nn.ReLU(inplace=True),
164 |
165 | nn.Linear(500, 5))
166 |
167 | def forward_once(self, x):
168 | x = x.unsqueeze(1)
169 | output = self.cnn1(x)
170 | output = output.view(output.size()[0], -1)
171 | output = self.fc1(output)
172 | return output
173 |
174 | def forward(self, input1, input2):
175 | output1 = self.forward_once(input1)
176 | output2 = self.forward_once(input2)
177 | return F.pairwise_distance(output1, output2).mean()
178 |
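179 | if __name__ == "__main__":
180 |     # Hedged sanity check that avoids the pretrained LPIPS/VGG download:
181 |     # hinge_d_loss is ~0 when real/fake logits are well separated and grows
182 |     # when the discriminator is fooled; ContrastiveLoss pulls label-0 pairs
183 |     # together and pushes label-1 pairs beyond the margin.
184 |     real, fake = torch.full((4, 1), 3.0), torch.full((4, 1), -3.0)
185 |     print(hinge_d_loss(real, fake).item())  # 0.0
186 |     print(hinge_d_loss(fake, real).item())  # 4.0
187 | 
188 |     contrastive = ContrastiveLoss(margin=2.0)
189 |     a, b = torch.randn(8, 16), torch.randn(8, 16)
190 |     same, different = torch.zeros(8, 1), torch.ones(8, 1)
191 |     print(contrastive(a, b, same).item(), contrastive(a, b, different).item())
192 | 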
--------------------------------------------------------------------------------
/models/actnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
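132 | if __name__ == "__main__":
133 |     # Hedged sanity check: the first forward pass in training mode initialises
134 |     # loc/scale from the batch statistics, so the output is roughly zero-mean,
135 |     # unit-variance per channel regardless of the input's scale and offset.
136 |     layer = ActNorm(num_features=8)
137 |     x = torch.randn(4, 8, 16, 16) * 3.0 + 1.5
138 |     y = layer(x)
139 |     print(round(y.mean().item(), 3), round(y.std().item(), 3))  # ~0.0, ~1.0
140 |     print(count_params(layer))  # 2 * num_features = 16 learnable parameters
141 | 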
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from models.actnorm import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return nn.Sigmoid()(self.main(input))
68 |
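69 | if __name__ == "__main__":
70 |     # Hedged shape check: with the default 3-layer PatchGAN, a 256x256 input
71 |     # yields a 30x30 grid of per-patch real/fake probabilities (sigmoid output).
72 |     import torch
73 | 
74 |     disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
75 |     probs = disc(torch.randn(2, 3, 256, 256))
76 |     print(probs.shape)  # torch.Size([2, 1, 30, 30])
77 | 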
--------------------------------------------------------------------------------
/models/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | import os, hashlib
9 | import requests
10 | from tqdm import tqdm
11 |
12 | URL_MAP = {
13 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
14 | }
15 |
16 | CKPT_MAP = {
17 | "vgg_lpips": "vgg.pth"
18 | }
19 |
20 | MD5_MAP = {
21 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
22 | }
23 |
24 | def download(url, local_path, chunk_size=1024):
25 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
26 | with requests.get(url, stream=True) as r:
27 | total_size = int(r.headers.get("content-length", 0))
28 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
29 | with open(local_path, "wb") as f:
30 | for data in r.iter_content(chunk_size=chunk_size):
31 | if data:
32 | f.write(data)
33 | pbar.update(chunk_size)
34 |
35 | def md5_hash(path):
36 | with open(path, "rb") as f:
37 | content = f.read()
38 | return hashlib.md5(content).hexdigest()
39 |
40 | def get_ckpt_path(name, root, check=False):
41 | assert name in URL_MAP
42 | path = os.path.join(root, CKPT_MAP[name])
43 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
44 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
45 | download(URL_MAP[name], path)
46 | md5 = md5_hash(path)
47 | assert md5 == MD5_MAP[name], md5
48 | return path
49 |
50 | class LPIPS(nn.Module):
51 | # Learned perceptual metric
52 | def __init__(self, use_dropout=True):
53 | super().__init__()
54 | self.scaling_layer = ScalingLayer()
55 | self.chns = [64, 128, 256, 512, 512] # vgg16 feature channels
56 | self.net = vgg16(pretrained=True, requires_grad=False)
57 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
58 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
59 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
60 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
61 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
62 | self.load_from_pretrained()
63 | for param in self.parameters():
64 | param.requires_grad = False
65 |
66 | def load_from_pretrained(self, name="vgg_lpips"):
67 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
68 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
69 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
70 |
71 | @classmethod
72 | def from_pretrained(cls, name="vgg_lpips"):
73 | if name != "vgg_lpips":
74 | raise NotImplementedError
75 | model = cls()
76 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
77 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
78 | return model
79 |
80 | def forward(self, input, target):
81 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
82 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
83 | feats0, feats1, diffs = {}, {}, {}
84 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
85 | for kk in range(len(self.chns)):
86 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
87 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
88 |
89 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
90 | val = res[0]
91 | for l in range(1, len(self.chns)):
92 | val += res[l]
93 | return val
94 |
95 |
96 | class ScalingLayer(nn.Module):
97 | def __init__(self):
98 | super(ScalingLayer, self).__init__()
99 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
100 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
101 |
102 | def forward(self, inp):
103 | return (inp - self.shift) / self.scale
104 |
105 |
106 | class NetLinLayer(nn.Module):
107 | """ A single linear layer which does a 1x1 conv """
108 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
109 | super(NetLinLayer, self).__init__()
110 | layers = [nn.Dropout(), ] if (use_dropout) else []
111 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
112 | self.model = nn.Sequential(*layers)
113 |
114 |
115 | class vgg16(torch.nn.Module):
116 | def __init__(self, requires_grad=False, pretrained=True):
117 | super(vgg16, self).__init__()
118 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
119 | self.slice1 = torch.nn.Sequential()
120 | self.slice2 = torch.nn.Sequential()
121 | self.slice3 = torch.nn.Sequential()
122 | self.slice4 = torch.nn.Sequential()
123 | self.slice5 = torch.nn.Sequential()
124 | self.N_slices = 5
125 | for x in range(4):
126 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
127 | for x in range(4, 9):
128 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
129 | for x in range(9, 16):
130 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
131 | for x in range(16, 23):
132 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
133 | for x in range(23, 30):
134 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
135 | if not requires_grad:
136 | for param in self.parameters():
137 | param.requires_grad = False
138 |
139 | def forward(self, X):
140 | h = self.slice1(X)
141 | h_relu1_2 = h
142 | h = self.slice2(h)
143 | h_relu2_2 = h
144 | h = self.slice3(h)
145 | h_relu3_3 = h
146 | h = self.slice4(h)
147 | h_relu4_3 = h
148 | h = self.slice5(h)
149 | h_relu5_3 = h
150 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
151 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
152 | return out
153 |
154 |
155 | def normalize_tensor(x,eps=1e-10):
156 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
157 | return x/(norm_factor+eps)
158 |
159 |
160 | def spatial_average(x, keepdim=True):
161 | return x.mean([2,3],keepdim=keepdim)
162 |
163 |
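164 | if __name__ == "__main__":
165 |     # Hedged usage sketch: LPIPS expects image batches scaled to roughly [-1, 1]
166 |     # and returns one perceptual distance per image. Note that constructing
167 |     # LPIPS() downloads the VGG-16 weights and the linear-layer checkpoint.
168 |     perceptual = LPIPS().eval()
169 |     a = torch.rand(2, 3, 256, 256) * 2 - 1
170 |     b = torch.rand(2, 3, 256, 256) * 2 - 1
171 |     with torch.no_grad():
172 |         print(perceptual(a, b).flatten())  # larger values = more perceptually different
173 | 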
--------------------------------------------------------------------------------
/models/vqvae_conv3d_latent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from tqdm import tqdm
6 | import numpy as np
7 | import distributed as dist_fn
8 |
9 | import warnings
10 | warnings.filterwarnings("ignore")
11 |
12 | import matplotlib.pyplot as plt
13 |
14 | # Copyright 2018 The Sonnet Authors. All Rights Reserved.
15 | #
16 | # Licensed under the Apache License, Version 2.0 (the "License");
17 | # you may not use this file except in compliance with the License.
18 | # You may obtain a copy of the License at
19 | #
20 | # http://www.apache.org/licenses/LICENSE-2.0
21 | #
22 | # Unless required by applicable law or agreed to in writing, software
23 | # distributed under the License is distributed on an "AS IS" BASIS,
24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 | # See the License for the specific language governing permissions and
26 | # limitations under the License.
27 | # ============================================================================
28 |
29 |
30 | # Borrowed from https://github.com/deepmind/sonnet and ported it to PyTorch
31 |
32 |
33 | class Quantize(nn.Module):
34 | def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
35 | super().__init__()
36 |
37 | self.dim = dim
38 | self.n_embed = n_embed
39 | self.decay = decay
40 | self.eps = eps
41 |
42 | embed = torch.randn(dim, n_embed)
43 | self.register_buffer("embed", embed)
44 | self.register_buffer("cluster_size", torch.zeros(n_embed))
45 | self.register_buffer("embed_avg", embed.clone())
46 |
47 | def forward(self, input):
48 | flatten = input.reshape(-1, self.dim)
49 | dist = (
50 | flatten.pow(2).sum(1, keepdim=True)
51 | - 2 * flatten @ self.embed
52 | + self.embed.pow(2).sum(0, keepdim=True)
53 | )
54 | _, embed_ind = (-dist).max(1)
55 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
56 | embed_ind = embed_ind.view(*input.shape[:-1])
57 | quantize = self.embed_code(embed_ind)
58 |
59 | if self.training:
60 | embed_onehot_sum = embed_onehot.sum(0)
61 | embed_sum = flatten.transpose(0, 1) @ embed_onehot
62 |
63 | dist_fn.all_reduce(embed_onehot_sum)
64 | dist_fn.all_reduce(embed_sum)
65 |
66 | self.cluster_size.data.mul_(self.decay).add_(
67 | embed_onehot_sum, alpha=1 - self.decay
68 | )
69 | self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
70 | n = self.cluster_size.sum()
71 | cluster_size = (
72 | (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
73 | )
74 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
75 | self.embed.data.copy_(embed_normalized)
76 |
77 | diff = (quantize.detach() - input).pow(2).mean()
78 | quantize = input + (quantize - input).detach()
79 |
80 | return quantize, diff, embed_ind
81 |
82 | def embed_code(self, embed_id):
83 | return F.embedding(embed_id, self.embed.transpose(0, 1))
84 |
85 |
86 | class ResBlock(nn.Module):
87 | def __init__(self, in_channel, channel):
88 | super().__init__()
89 |
90 | self.conv = nn.Sequential(
91 | nn.ReLU(),
92 | nn.Conv2d(in_channel, channel, 3, padding=1),
93 | nn.ReLU(inplace=True),
94 | nn.Conv2d(channel, in_channel, 1),
95 | )
96 |
97 | def forward(self, input):
98 | out = self.conv(input)
99 | out += input
100 |
101 | return out
102 |
103 | class Encoder(nn.Module):
104 | def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
105 | super().__init__()
106 |
107 | if stride == 4:
108 | blocks = [
109 | nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
110 | nn.ReLU(inplace=True),
111 | nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
112 | nn.ReLU(inplace=True),
113 | nn.Conv2d(channel, channel, 3, padding=1),
114 | ]
115 |
116 | elif stride == 2:
117 | blocks = [
118 | nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
119 | nn.ReLU(inplace=True),
120 | nn.Conv2d(channel // 2, channel, 3, padding=1),
121 | ]
122 |
123 | for i in range(n_res_block):
124 | blocks.append(ResBlock(channel, n_res_channel))
125 |
126 | blocks.append(nn.ReLU(inplace=True))
127 |
128 | self.blocks = nn.Sequential(*blocks)
129 |
130 | def forward(self, input):
131 | return self.blocks(input)
132 |
133 |
134 | class Decoder(nn.Module):
135 | def __init__(
136 | self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
137 | ):
138 | super().__init__()
139 |
140 | blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
141 |
142 | for i in range(n_res_block):
143 | blocks.append(ResBlock(channel, n_res_channel))
144 |
145 | blocks.append(nn.ReLU(inplace=True))
146 |
147 | if stride == 4:
148 | blocks.extend(
149 | [
150 | nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
151 | nn.ReLU(inplace=True),
152 | nn.ConvTranspose2d(
153 | channel // 2, out_channel, 4, stride=2, padding=1
154 | ),
155 | ]
156 | )
157 |
158 | elif stride == 2:
159 | blocks.append(
160 | nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
161 | )
162 |
163 | self.blocks = nn.Sequential(*blocks)
164 |
165 | def forward(self, input):
166 | return self.blocks(input)
167 |
168 | # applies a sequence of conv3d layers with relu activation for preserving temporal coherence
169 | class Conv3dLatentPostnet(nn.Module):
170 | def __init__(self, channels):
171 | super().__init__()
172 | self.conv3d = nn.Sequential(
173 | self.conv3d_layer(channels=channels),
174 | self.conv3d_layer(channels=channels),
175 | self.conv3d_layer(channels=channels, is_final=True)
176 | )
177 |
178 | def conv3d_layer(self, channels=128, kernel_size=3, padding=1, is_final=False):
179 | if is_final:
180 | return nn.Sequential(
181 | nn.Conv3d(channels, channels, kernel_size, padding=padding)
182 | )
183 | else:
184 | return nn.Sequential(
185 | nn.Conv3d(channels, channels, kernel_size, padding=padding),
186 | nn.ReLU()
187 | )
188 |
189 | def forward(self, input):
190 | return self.conv3d(input)
191 |
192 | class VQVAE(nn.Module):
193 | def __init__(
194 | self,
195 | in_channel=3,
196 | channel=128,
197 | n_res_block=2,
198 | n_res_channel=32,
199 | embed_dim=64,
200 | n_embed=512,
201 | decay=0.99,
202 | residual=False,
203 | ):
204 | super().__init__()
205 |
206 | self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
207 | self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
208 | self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
209 | self.quantize_t = Quantize(embed_dim, n_embed)
210 | self.dec_t = Decoder(
211 | embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
212 | )
213 | self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
214 | self.quantize_b = Quantize(embed_dim, n_embed)
215 | self.upsample_t = nn.ConvTranspose2d(
216 | embed_dim, embed_dim, 4, stride=2, padding=1
217 | )
218 | self.dec = Decoder(
219 | embed_dim + embed_dim,
220 | in_channel,
221 | channel,
222 | n_res_block,
223 | n_res_channel,
224 | stride=4,
225 | )
226 |
227 | self.residual = residual
228 |
229 | # bottom encoded dimension - batch_size x 128 x 64 x 64
230 | self.conv3d_encoded_b = Conv3dLatentPostnet(128)
231 | self.conv3d_encoded_t = Conv3dLatentPostnet(128)
232 |
233 | # self.conv3d_encoded_b = nn.Conv3d(128, 128, 3, padding=1)
234 | # self.conv3d_encoded_t = nn.Conv3d(128, 128, 3, padding=1)
235 |
236 |
237 | def only_encode(self, input):
238 | enc_b = self.enc_b(input)
239 | enc_t = self.enc_t(enc_b)
240 |
241 | return enc_b, enc_t
242 |
243 | def forward(self, input):
244 | enc_b, enc_t = self.only_encode(input)
245 |
246 | # enc_b dimension -> batch_size x 128 x 64 x 64
247 | enc_b, enc_t = enc_b.unsqueeze(0).permute(0, 2, 1, 3, 4), enc_t.unsqueeze(0).permute(0, 2, 1, 3, 4)
248 |
249 | # apply the 3d conv on the encoded representations
250 | enc_b_conv, enc_t_conv = self.conv3d_encoded_b(enc_b), self.conv3d_encoded_t(enc_t)
251 | enc_b_conv, enc_t_conv = enc_b_conv.squeeze(0).permute(1, 0, 2, 3), enc_t_conv.squeeze(0).permute(1, 0, 2, 3)
252 |
253 | # generate the quantized representation
254 | quant_t, quant_b, diff, _, _ = self.encode_quantized(enc_b_conv, enc_t_conv)
255 |
256 | # generate the decoded representation
257 | dec = self.decode(quant_t, quant_b)
258 |
259 | return dec, diff
260 |
261 | def encode_quantized(self, enc_b, enc_t):
262 | # enc_b = self.enc_b(input)
263 | # enc_t = self.enc_t(enc_b)
264 |
265 | quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
266 | quant_t, diff_t, id_t = self.quantize_t(quant_t)
267 | quant_t = quant_t.permute(0, 3, 1, 2)
268 | diff_t = diff_t.unsqueeze(0)
269 |
270 | dec_t = self.dec_t(quant_t)
271 | enc_b = torch.cat([dec_t, enc_b], 1)
272 |
273 | quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
274 | quant_b, diff_b, id_b = self.quantize_b(quant_b)
275 | quant_b = quant_b.permute(0, 3, 1, 2)
276 | diff_b = diff_b.unsqueeze(0)
277 |
278 | return quant_t, quant_b, diff_t + diff_b, id_t, id_b
279 |
280 | def decode(self, quant_t, quant_b):
281 | upsample_t = self.upsample_t(quant_t)
282 | quant = torch.cat([upsample_t, quant_b], 1)
283 | dec = self.dec(quant)
284 |
285 | return dec
286 |
287 | def decode_code(self, code_t, code_b):
288 | quant_t = self.quantize_t.embed_code(code_t)
289 | quant_t = quant_t.permute(0, 3, 1, 2)
290 | quant_b = self.quantize_b.embed_code(code_b)
291 | quant_b = quant_b.permute(0, 3, 1, 2)
292 |
293 | dec = self.decode(quant_t, quant_b)
294 |
295 | return dec
296 |
297 | # ============================================================================
298 | # ============================= VQVAE BLOB2FULL ==============================
299 | # ============================================================================
300 |
301 | class Encode(nn.Module):
302 | def __init__(
303 | self,
304 | in_channel,
305 | channel,
306 | n_res_block,
307 | n_res_channel,
308 | embed_dim,
309 | n_embed,
310 | decay,
311 | ):
312 | super().__init__()
313 |
314 | self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
315 | self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
316 | self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
317 | self.quantize_t = Quantize(embed_dim, n_embed)
318 | self.dec_t = Decoder(
319 | embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
320 | )
321 | self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
322 | self.quantize_b = Quantize(embed_dim, n_embed)
323 |
324 | def forward(self, input):
325 | enc_b = self.enc_b(input)
326 | enc_t = self.enc_t(enc_b)
327 |
328 | quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
329 | quant_t, diff_t, id_t = self.quantize_t(quant_t)
330 | quant_t = quant_t.permute(0, 3, 1, 2)
331 | diff_t = diff_t.unsqueeze(0)
332 |
333 | dec_t = self.dec_t(quant_t)
334 | enc_b = torch.cat([dec_t, enc_b], 1)
335 |
336 | quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
337 | quant_b, diff_b, id_b = self.quantize_b(quant_b)
338 | quant_b = quant_b.permute(0, 3, 1, 2)
339 | diff_b = diff_b.unsqueeze(0)
340 |
341 | return quant_t, quant_b, diff_t + diff_b, id_t, id_b
342 |
343 | class VQVAE_B2F(nn.Module):
344 | def __init__(
345 | self,
346 | in_channel=3,
347 | channel=128,
348 | n_res_block=2,
349 | n_res_channel=32,
350 | embed_dim=64,
351 | n_embed=512,
352 | decay=0.99,
353 | ):
354 | super().__init__()
355 |
356 | self.encode_face = Encode(in_channel,
357 | channel,
358 | n_res_block,
359 | n_res_channel,
360 | embed_dim,
361 | n_embed,
362 | decay)
363 |
364 | self.encode_rhand = Encode(in_channel,
365 | channel,
366 | n_res_block,
367 | n_res_channel,
368 | embed_dim,
369 | n_embed,
370 | decay)
371 |
372 | self.encode_lhand = Encode(in_channel,
373 | channel,
374 | n_res_block,
375 | n_res_channel,
376 | embed_dim,
377 | n_embed,
378 | decay)
379 |
380 | self.upsample_t = nn.ConvTranspose2d(
381 | embed_dim, embed_dim, 4, stride=2, padding=1
382 | )
383 | self.dec = Decoder(
384 | embed_dim + embed_dim,
385 | in_channel,
386 | channel,
387 | n_res_block,
388 | n_res_channel,
389 | stride=4,
390 | )
391 |
392 | def forward(self, input, save_idx=None, visual_folder=None):
393 | face, rhand, lhand = input
394 |
395 | quant_t_face, quant_b_face, diff_face, _, _ = self.encode_face(face)
396 | quant_t_rhand, quant_b_rhand, diff_rhand, _, _ = self.encode_rhand(rhand)
397 | quant_t_lhand, quant_b_lhand, diff_lhand, _, _ = self.encode_lhand(lhand)
398 |
399 | quant_t = quant_t_face + quant_t_rhand + quant_t_lhand
400 | quant_b = quant_b_face + quant_b_rhand + quant_b_lhand
401 | diff = diff_face + diff_rhand + diff_lhand
402 |
403 | dec = self.decode(quant_t, quant_b)
404 |
405 | if save_idx is not None:
406 | def save_img(img, save_idx, i, dtype):
407 | img = (img.detach().cpu() + 0.5).numpy()
408 | img = np.transpose(img, (1,2,0))
409 | fig = plt.imshow(img, interpolation='nearest')
410 | fig.axes.get_xaxis().set_visible(False)
411 | fig.axes.get_yaxis().set_visible(False)
412 | plt.savefig('{}/{}_{}_{}.jpg'.\
413 | format(visual_folder, save_idx, i, dtype), bbox_inches='tight')
414 |
415 | for i in tqdm(range(min(face.shape[0], 8))):
416 | save_img(face[i], save_idx, i, 'face')
417 | save_img(rhand[i], save_idx, i, 'rhand')
418 | save_img(lhand[i], save_idx, i, 'lhand')
419 | save_img(dec[i], save_idx, i, 'reconstructed')
420 |
421 | return dec, diff
422 |
423 | def decode(self, quant_t, quant_b):
424 | upsample_t = self.upsample_t(quant_t)
425 | quant = torch.cat([upsample_t, quant_b], 1)
426 | dec = self.dec(quant)
427 |
428 | return dec
429 |
430 | def decode_code(self, code_t, code_b):
431 | quant_t = self.encode_face.quantize_t.embed_code(code_t)  # note: assumes the codes come from the face encoder's codebooks
432 | quant_t = quant_t.permute(0, 3, 1, 2)
433 | quant_b = self.encode_face.quantize_b.embed_code(code_b)
434 | quant_b = quant_b.permute(0, 3, 1, 2)
435 |
436 | dec = self.decode(quant_t, quant_b)
437 |
438 | return dec
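439 | 
440 | if __name__ == "__main__":
441 |     # Hedged shape check for the temporal VQVAE: a clip of T frames is treated
442 |     # as a batch of images, the stacked latents are passed through the Conv3d
443 |     # post-nets as 1 x C x T x H x W, and the decoder returns per-frame outputs.
444 |     model = VQVAE()
445 |     frames = torch.randn(8, 3, 256, 256)  # T=8 frames of one 256x256 clip
446 |     recon, latent_diff = model(frames)
447 |     print(recon.shape, latent_diff.shape)  # torch.Size([8, 3, 256, 256]) torch.Size([1])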
--------------------------------------------------------------------------------
/preprocessing/preprocess_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code for preprocessing the video dataset
3 | Processes a video, generates face crop, and creates a constant crop for all face crops generated
4 | Writes the generated face crops (as frames) to a video
5 | '''
6 |
7 | import math
8 | import os
9 | import os.path as osp
10 | from tqdm import tqdm
11 | import gc
12 | from glob import glob
13 |
14 | import cv2
15 | import matplotlib.pyplot as plt
16 | import mediapipe as mp
17 |
18 | def display_image(image, requires_colorization=True):
19 | plt.figure()
20 | if requires_colorization:
21 | plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
22 | else:
23 | plt.imshow(image)
24 |
25 | def get_iou(bb1, bb2):
26 | """
27 | Calculate the Intersection over Union (IoU) of two bounding boxes.
28 |
29 | Parameters
30 | ----------
31 | bb1 : dict
32 | Keys: {'x1', 'x2', 'y1', 'y2'}
33 | The (x1, y1) position is at the top left corner,
34 | the (x2, y2) position is at the bottom right corner
35 | bb2 : dict
36 | Keys: {'x1', 'x2', 'y1', 'y2'}
37 | The (x, y) position is at the top left corner,
38 | the (x2, y2) position is at the bottom right corner
39 |
40 | Returns
41 | -------
42 | float
43 | in [0, 1]
44 | """
45 | assert bb1['x1'] < bb1['x2']
46 | assert bb1['y1'] < bb1['y2']
47 | assert bb2['x1'] < bb2['x2']
48 | assert bb2['y1'] < bb2['y2']
49 |
50 | # determine the coordinates of the intersection rectangle
51 | x_left = max(bb1['x1'], bb2['x1'])
52 | y_top = max(bb1['y1'], bb2['y1'])
53 | x_right = min(bb1['x2'], bb2['x2'])
54 | y_bottom = min(bb1['y2'], bb2['y2'])
55 |
56 | if x_right < x_left or y_bottom < y_top:
57 | return 0.0
58 |
59 | # The intersection of two axis-aligned bounding boxes is always an
60 | # axis-aligned bounding box
61 | intersection_area = (x_right - x_left) * (y_bottom - y_top)
62 |
63 | # compute the area of both AABBs
64 | bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
65 | bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])
66 |
67 | '''
68 | compute the IoU as the intersection area divided by the union (sum of both box areas minus the intersection)
69 | '''
70 |
71 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
72 | assert iou >= 0.0
73 | assert iou <= 1.0
74 | return iou
75 |
76 |
77 | '''
78 | Function for writing frames to video
79 | '''
80 | def write_frames_to_video(frames, current_frames, video_order_id, VIDEO_DIR, fps=30):
81 | height, width, _ = frames[0].shape
82 | # the frame coordinates have to be taken from the bounding box
83 | video_path = osp.join(VIDEO_DIR, str(video_order_id).zfill(5) + '.mp4')
84 | video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
85 | for frame_index in current_frames:
86 | video.write(frames[frame_index])
87 |
88 | cv2.destroyAllWindows()
89 | video.release()
90 |
91 | print(f'Video {video_path} written successfully')
92 |
93 |
94 | def crop_get_video(frames, current_indexes, bounding_box, VIDEO_DIR, video_order_id, fps=30):
95 | left, top, right, down = bounding_box['x1'], bounding_box['y1'], bounding_box['x2'], bounding_box['y2']
96 | width, height = right - left, down - top
97 | video_path = osp.join(VIDEO_DIR, str(video_order_id).zfill(5) + '.mp4')
98 |
99 | video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
100 |
101 | for index in current_indexes:
102 | current_frame = frames[index]
103 | cropped = current_frame[top : down, left : right]
104 |
105 | video.write(cropped)
106 |
107 | print(f'Writing to file : {video_order_id}')
108 | cv2.destroyAllWindows()
109 | video.release()  # finalise the writer so the mp4 is playable
110 |
111 | '''
112 | Function to return the bb coordinates of the cropped face
113 | '''
114 | def crop_face_coordinates(image, x_px, y_px, width_px, height_px):
115 | # using different thresholds/bounds for the upper and lower faces
116 | image_height, image_width, _ = image.shape
117 | lower_face_buffer, upper_face_buffer = 0.25, 0.65
118 | min_x, min_y, max_x, max_y = x_px, y_px, x_px + width_px, y_px + height_px
119 |
120 | x_left = max(0, int(min_x - (max_x - min_x) * lower_face_buffer))
121 | x_right = min(image_width, int(max_x + (max_x - min_x) * lower_face_buffer))
122 | y_top = max(0, int(min_y - (max_y - min_y) * upper_face_buffer))
123 | y_down = min(image_height, int(max_y + (max_y - min_y) * lower_face_buffer))
124 |
125 | size = max(x_right - x_left, y_down - y_top)
126 | sw = int((x_left + x_right)/2 - size // 2)
127 |
128 | return sw, y_top, sw+size, y_down
129 |
130 |
131 | '''
132 | function to return the bb coordinates given the image
133 | '''
134 | def bb_coordinates(image):
135 | mp_face_detection = mp.solutions.face_detection
136 | face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)
137 |
138 | results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
139 |
140 | result = results.detections
141 | image_height, image_width, _ = image.shape
142 |
143 | if result is None:
144 | return -1, -1, -1, -1
145 |
146 | bb_values = result[0].location_data.relative_bounding_box
147 |
148 | normalized_x, normalized_y, normalized_width, normalized_height = \
149 | bb_values.xmin, bb_values.ymin, bb_values.width, bb_values.height
150 |
151 | # the bounding box coordinates are given as normalized values, unnormalize them by multiplying by height and width
152 | x_px = min(math.floor(normalized_x * image_width), image_width - 1)
153 | y_px = min(math.floor(normalized_y * image_height), image_height - 1)
154 | width_px = min(math.floor(normalized_width * image_width), image_width - 1)
155 | height_px = min(math.floor(normalized_height * image_height), image_height - 1)
156 |
157 | return x_px, y_px, width_px, height_px
158 |
159 | '''
160 | the IoU criterion behaves differently at different resolutions because it depends
161 | on pixel coordinates rather than on what is actually inside the frames
162 | frames are read until either i) the end of the video is reached, or ii) the frame_reading_threshold (memory limit) is hit
163 | '''
164 |
165 | def process_frames(frames, VIDEO_DIR):
166 | # declare global scope for the video_order_id counter (initialised in the __main__ block)
167 | global video_order_id
168 |
169 | current_frames = list()
170 | mean_bounding_box = dict()
171 | iou_threshold = 0.7
172 | frame_count = 0
173 | frame_writing_threshold = 30 # minimum number of frames to write
174 | bb_prev_mean = dict()
175 |
176 | for index, frame in tqdm(enumerate(frames)):
177 | image_height, image_width, _ = frame.shape
178 | x_px, y_px, width_px, height_px = bb_coordinates(frame)
179 |
180 | if x_px == -1:
181 | if len(current_frames) > frame_writing_threshold:
182 | crop_get_video(frames, current_frames, mean_bounding_box, VIDEO_DIR, video_order_id)
183 | video_order_id += 1
184 |
185 | # reset
186 | current_frames = list()
187 | frame_count = 0
188 | mean_bounding_box = dict()
189 | bb_prev_mean = dict()
190 |
191 | else:
192 | left, top, right, bottom = crop_face_coordinates(frame, x_px, y_px, width_px, height_px)
193 | current_bounding_box = {'x1' : left, 'x2' : right, 'y1' : top, 'y2' : bottom}
194 |
195 | if len(mean_bounding_box) == 0:
196 |
197 | mean_bounding_box = current_bounding_box.copy()
198 | bb_prev_mean = current_bounding_box.copy()
199 |
200 | frame_count += 1
201 | current_frames.append(index)
202 |
203 | else:
204 | # UPDATE - compute the iou between the current bounding box and the mean bounding box
205 | iou = get_iou(bb_prev_mean, current_bounding_box)
206 |
207 | if iou < iou_threshold:
208 | mean_left, mean_right, mean_top, mean_down = mean_bounding_box['x1'], mean_bounding_box['x2'], mean_bounding_box['y1'], mean_bounding_box['y2']
209 |
210 | if len(current_frames) > frame_writing_threshold:
211 | crop_get_video(frames, current_frames, mean_bounding_box, VIDEO_DIR, video_order_id)
212 | video_order_id += 1
213 |
214 | current_frames = list()
215 | frame_count = 0
216 | mean_bounding_box = dict()
217 |
218 | # Add the current bounding box to the list of bounding boxes and compute the mean
219 | else:
220 | mean_bounding_box['x1'] = min(mean_bounding_box['x1'], current_bounding_box['x1'])
221 | mean_bounding_box['y1'] = min(mean_bounding_box['y1'], current_bounding_box['y1'])
222 | mean_bounding_box['x2'] = max(mean_bounding_box['x2'], current_bounding_box['x2'])
223 | mean_bounding_box['y2'] = max(mean_bounding_box['y2'], current_bounding_box['y2'])
224 |
225 | # update the coordinates of the mean bounding box
226 | for item in bb_prev_mean.keys():
227 | bb_prev_mean[item] = int((bb_prev_mean[item] * frame_count + current_bounding_box[item])/(frame_count + 1))
228 |
229 | frame_count += 1
230 | current_frames.append(index)
231 |
232 | if len(current_frames) > frame_writing_threshold:
233 | crop_get_video(frames, current_frames, mean_bounding_box, VIDEO_DIR, video_order_id)
234 | video_order_id += 1
235 |
236 |
237 | '''
238 | method for processing a single video
239 | reads video frames and calls function to process the frames
240 | '''
241 | def process_video(video_file, processed_videos_dir):
242 | video_stream = cv2.VideoCapture(video_file)
243 | print(f'Total number of frames in the current video : {video_stream.get(cv2.CAP_PROP_FRAME_COUNT)}')
244 | print(f'Processing video file {video_file}')
245 | frames = list()
246 |
247 | # keep reading the frames if the processing of the current frames is over
248 | frame_reading_threshold = 8000
249 | frames_processed = 0
250 | frames_processed_threshold = 2000
251 |
252 | video_file_name = osp.basename(video_file).split('.')[0]
253 | VIDEO_DIR = osp.join(processed_videos_dir, video_file_name)
254 |
255 | os.makedirs(VIDEO_DIR, exist_ok=True)
256 |
257 | ret, frame = video_stream.read()
258 | while ret:
259 | frames.append(frame)
260 | frames_processed += 1
261 |
262 | if frames_processed % frames_processed_threshold == 0:
263 | print(f'{frames_processed} frames read')
264 |
265 | if frames_processed%frame_reading_threshold == 0:
266 | # perform processing and generate frame videos
267 | print(f'Processing the frames for the current batch')
268 | process_frames(frames, VIDEO_DIR)
269 | print(f'Done with processing of the current batch')
270 |
271 | del frames
272 | gc.collect()
273 | frames = list() # clear out the frames
274 |
275 | ret, frame = video_stream.read()
276 |
277 | # check if more frames that were read need to be processed
278 | if len(frames) != 0:
279 | print(f'Processing remaining frames')
280 | process_frames(frames, VIDEO_DIR)
281 |
282 | video_stream.release()
283 |
284 |
285 | '''
286 | method to process multiple videos
287 | '''
288 | def process_videos(video_dir, processed_videos_dir):
289 | video_files = glob(video_dir + '/*.mp4')
290 |
291 | for video_file in tqdm(video_files):
292 | process_video(video_file, processed_videos_dir)
293 |
294 |
295 | if __name__ == '__main__':
296 | video_dir = 'videos_dir'  # placeholder: directory containing the raw .mp4 videos
297 | processed_videos_dir = 'processed_videos_dir'  # placeholder: output directory for the cropped clips
298 | video_order_id = 0  # global counter used by process_frames when naming the written clips
299 | process_videos(video_dir, processed_videos_dir)
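300 | 
301 | # Hedged worked example for get_iou(): bb1 = {'x1': 0, 'y1': 0, 'x2': 10, 'y2': 10}
302 | # and bb2 = {'x1': 5, 'y1': 5, 'x2': 15, 'y2': 15} overlap in a 5x5 region, so
303 | # IoU = 25 / (100 + 100 - 25) ~= 0.143 -- well below the 0.7 threshold used in
304 | # process_frames, which would therefore close the current clip and start a new one.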
--------------------------------------------------------------------------------
/results/inference_pipeline.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/inference_pipeline.gif
--------------------------------------------------------------------------------
/results/training_pipeline.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/training_pipeline.gif
--------------------------------------------------------------------------------
/results/v2v_comparisons1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_comparisons1.gif
--------------------------------------------------------------------------------
/results/v2v_comparisons31.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_comparisons31.gif
--------------------------------------------------------------------------------
/results/v2v_face_swapping.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_face_swapping.gif
--------------------------------------------------------------------------------
/results/v2v_faceswapping_looped2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_faceswapping_looped2.gif
--------------------------------------------------------------------------------
/results/v2v_more_result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_more_result.gif
--------------------------------------------------------------------------------
/results/v2v_results1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results1.gif
--------------------------------------------------------------------------------
/results/v2v_results2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results2.gif
--------------------------------------------------------------------------------
/results/v2v_results3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results3.gif
--------------------------------------------------------------------------------
/results/v2v_results4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_results4.gif
--------------------------------------------------------------------------------
/results/v2v_same_identity1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_same_identity1.gif
--------------------------------------------------------------------------------
/results/v2v_same_identity2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_same_identity2.gif
--------------------------------------------------------------------------------
/results/v2v_same_identity3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skymanaditya1/FaceOff/61069cee8c51229d5c29f848ae618b5640153c48/results/v2v_same_identity3.gif
--------------------------------------------------------------------------------
/sample/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 |
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | from math import cos, pi, floor, sin
2 |
3 | from torch.optim import lr_scheduler
4 |
5 |
6 | class CosineLR(lr_scheduler._LRScheduler):
7 | def __init__(self, optimizer, lr_min, lr_max, step_size):
8 | self.lr_min = lr_min
9 | self.lr_max = lr_max
10 | self.step_size = step_size
11 | self.iteration = 0
12 |
13 | super().__init__(optimizer, -1)
14 |
15 | def get_lr(self):
16 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
17 | 1 + cos(self.iteration / self.step_size * pi)
18 | )
19 | self.iteration += 1
20 |
21 | if self.iteration == self.step_size:
22 | self.iteration = 0
23 |
24 | return [lr for base_lr in self.base_lrs]
25 |
26 |
27 | class PowerLR(lr_scheduler._LRScheduler):
28 | def __init__(self, optimizer, lr_min, lr_max, warmup):
29 | self.lr_min = lr_min
30 | self.lr_max = lr_max
31 | self.warmup = warmup
32 | self.iteration = 0
33 |
34 | super().__init__(optimizer, -1)
35 |
36 | def get_lr(self):
37 | if self.iteration < self.warmup:
38 | lr = (
39 | self.lr_min + (self.lr_max - self.lr_min) / self.warmup * self.iteration
40 | )
41 |
42 | else:
43 | lr = self.lr_max * (self.iteration - self.warmup + 1) ** -0.5
44 |
45 | self.iteration += 1
46 |
47 | return [lr for base_lr in self.base_lrs]
48 |
49 |
50 | class SineLR(lr_scheduler._LRScheduler):
51 | def __init__(self, optimizer, lr_min, lr_max, step_size):
52 | self.lr_min = lr_min
53 | self.lr_max = lr_max
54 | self.step_size = step_size
55 | self.iteration = 0
56 |
57 | super().__init__(optimizer, -1)
58 |
59 | def get_lr(self):
60 | lr = self.lr_min + (self.lr_max - self.lr_min) * sin(
61 | self.iteration / self.step_size * pi
62 | )
63 | self.iteration += 1
64 |
65 | if self.iteration == self.step_size:
66 | self.iteration = 0
67 |
68 | return [lr for base_lr in self.base_lrs]
69 |
70 |
71 | class LinearLR(lr_scheduler._LRScheduler):
72 | def __init__(self, optimizer, lr_min, lr_max, warmup, step_size):
73 | self.lr_min = lr_min
74 | self.lr_max = lr_max
75 | self.step_size = step_size
76 | self.warmup = warmup
77 | self.iteration = 0
78 |
79 | super().__init__(optimizer, -1)
80 |
81 | def get_lr(self):
82 | if self.iteration < self.warmup:
83 | lr = self.lr_max
84 |
85 | else:
86 | lr = self.lr_max + (self.iteration - self.warmup) * (
87 | self.lr_min - self.lr_max
88 | ) / (self.step_size - self.warmup)
89 | self.iteration += 1
90 |
91 | if self.iteration == self.step_size:
92 | self.iteration = 0
93 |
94 | return [lr for base_lr in self.base_lrs]
95 |
96 |
97 | class CLR(lr_scheduler._LRScheduler):
98 | def __init__(self, optimizer, lr_min, lr_max, step_size):
99 | self.epoch = 0
100 | self.lr_min = lr_min
101 | self.lr_max = lr_max
102 | self.current_lr = lr_min
103 | self.step_size = step_size
104 |
105 | super().__init__(optimizer, -1)
106 |
107 | def get_lr(self):
108 | cycle = floor(1 + self.epoch / (2 * self.step_size))
109 | x = abs(self.epoch / self.step_size - 2 * cycle + 1)
110 | lr = self.lr_min + (self.lr_max - self.lr_min) * max(0, 1 - x)
111 | self.current_lr = lr
112 |
113 | self.epoch += 1
114 |
115 | return [lr for base_lr in self.base_lrs]
116 |
117 |
118 | class Warmup(lr_scheduler._LRScheduler):
119 | def __init__(self, optimizer, model_dim, factor=1, warmup=16000):
120 | self.optimizer = optimizer
121 | self.model_dim = model_dim
122 | self.factor = factor
123 | self.warmup = warmup
124 | self.iteration = 0
125 |
126 | super().__init__(optimizer, -1)
127 |
128 | def get_lr(self):
129 | self.iteration += 1
130 | lr = (
131 | self.factor
132 | * self.model_dim ** (-0.5)
133 | * min(self.iteration ** (-0.5), self.iteration * self.warmup ** (-1.5))
134 | )
135 |
136 | return [lr for base_lr in self.base_lrs]
137 |
138 |
139 | # Copyright 2019 fastai
140 |
141 | # Licensed under the Apache License, Version 2.0 (the "License");
142 | # you may not use this file except in compliance with the License.
143 | # You may obtain a copy of the License at
144 |
145 | # http://www.apache.org/licenses/LICENSE-2.0
146 |
147 | # Unless required by applicable law or agreed to in writing, software
148 | # distributed under the License is distributed on an "AS IS" BASIS,
149 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
150 | # See the License for the specific language governing permissions and
151 | # limitations under the License.
152 |
153 |
154 | # Borrowed from https://github.com/fastai/fastai and changed to make it run like a PyTorch lr scheduler
155 |
156 |
157 | class CycleAnnealScheduler:
158 | def __init__(
159 | self, optimizer, lr_max, lr_divider, cut_point, step_size, momentum=None
160 | ):
161 | self.lr_max = lr_max
162 | self.lr_divider = lr_divider
163 | self.cut_point = step_size // cut_point
164 | self.step_size = step_size
165 | self.iteration = 0
166 | self.cycle_step = int(step_size * (1 - cut_point / 100) / 2)
167 | self.momentum = momentum
168 | self.optimizer = optimizer
169 |
170 | def get_lr(self):
171 | if self.iteration > 2 * self.cycle_step:
172 | cut = (self.iteration - 2 * self.cycle_step) / (
173 | self.step_size - 2 * self.cycle_step
174 | )
175 | lr = self.lr_max * (1 + (cut * (1 - 100) / 100)) / self.lr_divider
176 |
177 | elif self.iteration > self.cycle_step:
178 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step
179 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider
180 |
181 | else:
182 | cut = self.iteration / self.cycle_step
183 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider
184 |
185 | return lr
186 |
187 | def get_momentum(self):
188 | if self.iteration > 2 * self.cycle_step:
189 | momentum = self.momentum[0]
190 |
191 | elif self.iteration > self.cycle_step:
192 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step
193 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0])
194 |
195 | else:
196 | cut = self.iteration / self.cycle_step
197 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0])
198 |
199 | return momentum
200 |
201 | def step(self):
202 | lr = self.get_lr()
203 |
204 | if self.momentum is not None:
205 | momentum = self.get_momentum()
206 |
207 | self.iteration += 1
208 |
209 | if self.iteration == self.step_size:
210 | self.iteration = 0
211 |
212 | for group in self.optimizer.param_groups:
213 | group['lr'] = lr
214 |
215 | if self.momentum is not None:
216 | group['betas'] = (momentum, group['betas'][1])
217 |
218 | return lr
219 |
220 |
221 | def anneal_linear(start, end, proportion):
222 | return start + proportion * (end - start)
223 |
224 |
225 | def anneal_cos(start, end, proportion):
226 | cos_val = cos(pi * proportion) + 1
227 |
228 | return end + (start - end) / 2 * cos_val
229 |
230 |
231 | class Phase:
232 | def __init__(self, start, end, n_iter, anneal_fn):
233 | self.start, self.end = start, end
234 | self.n_iter = n_iter
235 | self.anneal_fn = anneal_fn
236 | self.n = 0
237 |
238 | def step(self):
239 | self.n += 1
240 |
241 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter)
242 |
243 | def reset(self):
244 | self.n = 0
245 |
246 | @property
247 | def is_done(self):
248 | return self.n >= self.n_iter
249 |
250 |
251 | class CycleScheduler:
252 | def __init__(
253 | self,
254 | optimizer,
255 | lr_max,
256 | n_iter,
257 | momentum=(0.95, 0.85),
258 | divider=25,
259 | warmup_proportion=0.3,
260 | phase=('linear', 'cos'),
261 | ):
262 | self.optimizer = optimizer
263 |
264 | phase1 = int(n_iter * warmup_proportion)
265 | phase2 = n_iter - phase1
266 | lr_min = lr_max / divider
267 |
268 | phase_map = {'linear': anneal_linear, 'cos': anneal_cos}
269 |
270 | self.lr_phase = [
271 | Phase(lr_min, lr_max, phase1, phase_map[phase[0]]),
272 | Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]),
273 | ]
274 |
275 | self.momentum = momentum
276 |
277 | if momentum is not None:
278 | mom1, mom2 = momentum
279 | self.momentum_phase = [
280 | Phase(mom1, mom2, phase1, phase_map[phase[0]]),
281 | Phase(mom2, mom1, phase2, phase_map[phase[1]]),
282 | ]
283 |
284 | else:
285 | self.momentum_phase = []
286 |
287 | self.phase = 0
288 |
289 | def step(self):
290 | lr = self.lr_phase[self.phase].step()
291 |
292 | if self.momentum is not None:
293 | momentum = self.momentum_phase[self.phase].step()
294 |
295 | else:
296 | momentum = None
297 |
298 | for group in self.optimizer.param_groups:
299 | group['lr'] = lr
300 |
301 | if self.momentum is not None:
302 | if 'betas' in group:
303 | group['betas'] = (momentum, group['betas'][1])
304 |
305 | else:
306 | group['momentum'] = momentum
307 |
308 | if self.lr_phase[self.phase].is_done:
309 | self.phase += 1
310 |
311 | if self.phase >= len(self.lr_phase):
312 | for phase in self.lr_phase:
313 | phase.reset()
314 |
315 | for phase in self.momentum_phase:
316 | phase.reset()
317 |
318 | self.phase = 0
319 |
320 | return lr, momentum
321 |
322 |
323 | class LRFinder(lr_scheduler._LRScheduler):
324 | def __init__(self, optimizer, lr_min, lr_max, step_size, linear=False):
325 | ratio = lr_max / lr_min
326 | self.linear = linear
327 | self.lr_min = lr_min
328 | self.lr_mult = (ratio / step_size) if linear else ratio ** (1 / step_size)
329 | self.iteration = 0
330 | self.lrs = []
331 | self.losses = []
332 |
333 | super().__init__(optimizer, -1)
334 |
335 | def get_lr(self):
336 | lr = (
337 | self.lr_mult * self.iteration
338 | if self.linear
339 | else self.lr_mult ** self.iteration
340 | )
341 | lr = self.lr_min + lr if self.linear else self.lr_min * lr
342 |
343 | self.iteration += 1
344 | self.lrs.append(lr)
345 |
346 | return [lr for base_lr in self.base_lrs]
347 |
348 | def record(self, loss):
349 | self.losses.append(loss)
350 |
351 | def save(self, filename):
352 | with open(filename, 'w') as f:
353 | for lr, loss in zip(self.lrs, self.losses):
354 | f.write('{},{}\n'.format(lr, loss))
355 |
--------------------------------------------------------------------------------
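The trainers below drive these schedulers once per optimizer step. As a minimal, hedged sketch of how `CycleScheduler` is wired into a training loop (the toy model, loader size, and epoch count here are placeholders, not part of the repository):

```
import torch
from torch import nn, optim

from scheduler import CycleScheduler

# toy model and run length, for illustration only
model = nn.Linear(16, 16)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
steps_per_epoch, epochs = 100, 10

# one-cycle schedule over the whole run, mirroring its use in train_faceoff.py
scheduler = CycleScheduler(
    optimizer,
    lr_max=3e-4,
    n_iter=steps_per_epoch * epochs,
    momentum=None,
    warmup_proportion=0.05,
)

for step in range(steps_per_epoch * epochs):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    model.zero_grad()
    loss.backward()
    scheduler.step()   # adjust the learning rate before the optimizer step
    optimizer.step()
```

Note that the trainers call `scheduler.step()` before `optimizer.step()`, as in the loop above.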
/train_faceoff.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import random
5 | import os.path as osp
6 |
7 | import torch
8 | from torch import nn, optim
9 | from torch.utils.data import DataLoader
10 |
11 | from torchvision import datasets, transforms, utils
12 |
13 | from tqdm import tqdm
14 |
15 | from scheduler import CycleScheduler
16 | import distributed as dist
17 |
18 | from utils import *
19 | from config import DATASET, LATENT_LOSS_WEIGHT, SAMPLE_SIZE_FOR_VISUALIZATION
20 |
21 | criterion = nn.MSELoss()
22 |
23 | global_step = 0
24 |
25 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION
26 |
27 | dataset = DATASET
28 |
29 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results'
30 |
31 | def run_step(model, data, device, run='train'):
32 | img, S, ground_truth, source_images_original = process_data(data, device, dataset)
33 |
34 | out, latent_loss = model(img)
35 |
36 | out = out[:, :3]
37 |
38 | recon_loss = criterion(out, ground_truth)
39 | latent_loss = latent_loss.mean()
40 |
41 | if run == 'train':
42 | return recon_loss, latent_loss, S
43 | else:
44 | return ground_truth, img, out
45 |
46 | def blob2full_validation(model, img, gt, epoch, i):
47 | face, rhand, lhand = img
48 |
49 | face = face[:sample_size]
50 | rhand = rhand[:sample_size]
51 | lhand = lhand[:sample_size]
52 | sample = face, rhand, lhand
53 |
54 | gt = gt[:sample_size]
55 |
56 | with torch.no_grad():
57 | out, _ = model(sample)
58 |
59 | save_image(torch.cat([face, rhand, lhand, out, gt], 0),
60 | f"sample/{epoch + 1}_{i}.png")
61 |
62 | def jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder):
63 | for i, data in enumerate(tqdm(val_loader)):
64 | with torch.no_grad():
65 | source_images, input, prediction = run_step(model, data, device, run='val')
66 |
67 | source_hulls = input[:, :3]
68 | background = input[:, 3:]
69 |
70 | saves = {
71 | 'source': source_hulls,
72 | 'background': background,
73 | 'prediction': prediction,
74 | 'source_images': source_images
75 | }
76 |
77 | if i % (len(val_loader) // 10) == 0 or run_type != 'train':
78 | def denormalize(x):
79 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
80 |
81 | for name in saves:
82 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4"
83 | frames = saves[name].detach().cpu()
84 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames]
85 |
86 | # os.makedirs(sample_folder, exist_ok=True)
87 | save_frames_as_video(frames, saveas, fps=25)
88 |
89 | def base_validation(model, val_loader, device, epoch, i, run_type, sample_folder):
90 | def get_proper_shape(x):
91 | shape = x.shape
92 | return x.view(shape[0], -1, 3, shape[2], shape[3]).view(-1, 3, shape[2], shape[3])
93 |
94 | for val_i, data in enumerate(tqdm(val_loader)):
95 | with torch.no_grad():
96 | sample, _, out = run_step(model, data, device, 'val')
97 |
98 | if val_i % (len(val_loader)//10) == 0: # save 10 results
99 |
100 | def denormalize(x):
101 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
102 |
103 | # save_image(
104 | # torch.cat([sample[:3*3], out[:3*3]], 0),
105 | # f"{sample_folder}/{epoch + 1}_{i}_{val_i}.png")
106 |
107 | save_as_sample = f"{sample_folder}/{epoch+1}_{global_step}_{i}_sample.mp4"
108 | save_as_out = f"{sample_folder}/{epoch+1}_{global_step}_{i}_out.mp4"
109 |
110 | # save_as_sample = f"{sample_folder}/{val_i}_sample.mp4"
111 | # save_as_out = f"{sample_folder}/{val_i}_out.mp4"
112 |
113 | sample = sample.detach().cpu()
114 | sample = [denormalize(x).permute(1, 2, 0).numpy() for x in sample]
115 |
116 | out = out.detach().cpu()
117 | out = [denormalize(x).permute(1, 2, 0).numpy() for x in out]
118 |
119 | save_frames_as_video(sample, save_as_sample, fps=25)
120 | save_frames_as_video(out, save_as_out, fps=25)
121 |
122 | def validation(model, val_loader, device, epoch, i, sample_folder, run_type='train'):
123 | if dataset >= 6:
124 | jitter_validation(model, val_loader, device, epoch, i, run_type, sample_folder)
125 | else:
126 | base_validation(model, val_loader, device, epoch, i, run_type, sample_folder)
127 |
128 | def train(model, loader, val_loader, optimizer, scheduler, device, epoch, validate_at, checkpoint_dir, sample_folder):
129 | if dist.is_primary():
130 | loader = tqdm(loader, file=sys.stdout)
131 |
132 | mse_sum = 0
133 | mse_n = 0
134 |
135 | for i, data in enumerate(loader):
136 | model.zero_grad()
137 |
138 | recon_loss, latent_loss, S = run_step(model, data, device)
139 |
140 | loss = recon_loss + LATENT_LOSS_WEIGHT * latent_loss
141 |
142 | loss.backward()
143 |
144 | if scheduler is not None:
145 | scheduler.step()
146 |
147 | optimizer.step()
148 |
149 | global global_step
150 |
151 | global_step += 1
152 |
153 | part_mse_sum = recon_loss.item() * S
154 | part_mse_n = S
155 |
156 | comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
157 | comm = dist.all_gather(comm)
158 |
159 | for part in comm:
160 | mse_sum += part["mse_sum"]
161 | mse_n += part["mse_n"]
162 |
163 | if dist.is_primary():
164 | lr = optimizer.param_groups[0]["lr"]
165 |
166 | loader.set_description(
167 | (
168 | f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
169 | f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
170 | f"lr: {lr:.5f}"
171 | )
172 | )
173 |
174 | if i % validate_at == 0:
175 | model.eval()
176 |
177 |             # run validation, then checkpoint the model on the primary process
178 |             print('Running validation')
179 |
180 | validation(model, val_loader, device, epoch, i, sample_folder)
181 |
182 | if dist.is_primary():
183 | os.makedirs(checkpoint_dir, exist_ok=True)
184 |
185 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(i + 1).zfill(4)}.pt")
186 |
187 | model.train()
188 |
189 | def main(args):
190 | device = "cuda"
191 |
192 | args.distributed = dist.get_world_size() > 1
193 |
194 | default_transform = transforms.Compose(
195 | [
196 | transforms.ToPILImage(),
197 | transforms.Resize(args.size),
198 | transforms.ToTensor(),
199 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
200 | ]
201 | )
202 |
203 | loader, val_loader, model = get_loaders_and_models(
204 | args, dataset, default_transform, device, test=args.test)
205 |
206 | if args.distributed:
207 | model = nn.parallel.DistributedDataParallel(
208 | model,
209 | device_ids=[dist.get_local_rank()],
210 | output_device=dist.get_local_rank(),
211 | )
212 |
213 | if args.ckpt:
214 | print(f'Loading pretrained checkpoint - {args.ckpt}')
215 | state_dict = torch.load(args.ckpt)
216 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() }
217 | try:
218 | model.module.load_state_dict(state_dict)
219 | except:
220 | model.load_state_dict(state_dict)
221 |
222 | if args.test:
223 | # test(loader, model, device)
224 | validation(model, val_loader, device, 0, 0, args.sample_folder, 'val')
225 | else:
226 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
227 |
228 | scheduler = None
229 |
230 | if args.sched == "cycle":
231 | scheduler = CycleScheduler(
232 | optimizer,
233 | args.lr,
234 | n_iter=len(loader) * args.epoch,
235 | momentum=None,
236 | warmup_proportion=0.05,
237 | )
238 |
239 | for i in range(args.epoch):
240 | train(model, loader, val_loader, optimizer, scheduler, device, i, args.validate_at, args.checkpoint_dir, args.sample_folder)
241 |
242 | def get_random_name(cipher_length=5):
243 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
244 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)])
245 |
246 | if __name__ == "__main__":
247 | parser = argparse.ArgumentParser()
248 | parser.add_argument("--n_gpu", type=int, default=1)
249 |
250 | # port = (
251 | # 2 ** 15
252 | # + 2 ** 14
253 | # + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
254 | # )
255 |
256 | port = random.randint(51000, 52000)
257 |
258 | parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}")
259 | parser.add_argument("--batch_size", type=int, default=32)
260 | parser.add_argument("--size", type=int, default=256)
261 | parser.add_argument("--epoch", type=int, default=560)
262 | parser.add_argument("--lr", type=float, default=3e-4)
263 | parser.add_argument("--sched", type=str)
264 | parser.add_argument("--checkpoint_suffix", type=str, default='')
265 | parser.add_argument("--validate_at", type=int, default=1024)
266 | parser.add_argument("--ckpt", required=False)
267 | parser.add_argument("--test", action='store_true', required=False)
268 | parser.add_argument("--gray", action='store_true', required=False)
269 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty')
270 | parser.add_argument("--crossid", action='store_true', required=False)
271 | parser.add_argument("--sample_folder", type=str, default='samples')
272 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint')
273 |
274 | args = parser.parse_args()
275 |
276 | # args.n_gpu = torch.cuda.device_count()
277 | current_run = get_random_name()
278 |
279 | # sample_folder = sample_folder.format(args.checkpoint_suffix)
280 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run)
281 | os.makedirs(args.sample_folder, exist_ok=True)
282 |
283 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run)
284 | # os.makedirs(args.checkpoint_dir, exist_ok=True)
285 |
286 | # checkpoint_dir = checkpoint_dir.format(args.checkpoint_suffix)
287 |
288 | print(args, flush=True)
289 |
290 | dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,))
291 |
--------------------------------------------------------------------------------
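A minimal sketch of loading one of the checkpoints written by the trainer above into a bare `VQVAE` for inference; the checkpoint path is a placeholder, and the `'module.'` prefix stripping mirrors the loading logic in the script:

```
import torch

from models.vqvae_conv3d_latent import VQVAE

# placeholder path; checkpoints are written as {checkpoint_dir}/vqvae_{epoch}_{step}.pt
ckpt_path = '/path/to/checkpoint_xxxxx/vqvae_10_1025.pt'

# 6 input channels: source face hull + background, as set up in utils.py
model = VQVAE(in_channel=3 * 2)

state_dict = torch.load(ckpt_path, map_location='cpu')
# checkpoints saved from a DistributedDataParallel model carry a 'module.' prefix
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
```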
/train_faceoff_perceptual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import random
5 | import os.path as osp
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | import torch
10 | from torch import nn, optim
11 | from torch.utils.data import DataLoader
12 | from torchvision import datasets, transforms, utils
13 |
14 | from scheduler import CycleScheduler
15 | import distributed as dist
16 |
17 | from utils import *
18 | from config import DATASET, LATENT_LOSS_WEIGHT, PERCEPTUAL_LOSS_WEIGHT, SAMPLE_SIZE_FOR_VISUALIZATION
19 |
20 |
21 | criterion = nn.MSELoss()
22 |
23 | global_step = 0
24 |
25 | sample_size = SAMPLE_SIZE_FOR_VISUALIZATION
26 |
27 | dataset = DATASET
28 |
29 | BASE = '/ssd_scratch/cvit/aditya1/video_vqvae2_results'
30 |
31 |
32 | def run_step(model, vqlpips, data, device, run='train'):
33 | img, S, ground_truth, source_image = process_data(data, device, dataset)
34 |
35 | out, latent_loss = model(img)
36 |
37 | out = out[:, :3]
38 |
39 | recon_loss = criterion(out, ground_truth)
40 | latent_loss = latent_loss.mean()
41 |
42 | perceptual_loss = vqlpips(ground_truth, out)
43 |
44 | if run == 'train':
45 | return recon_loss, latent_loss, perceptual_loss, S
46 | else:
47 | return ground_truth, img, out, source_image
48 |
49 |
50 | '''
51 | runs the validation step, saves the blended output video
52 | '''
53 | def validation(model, vqlpips, val_loader, device, epoch, i, sample_folder, run_type='train'):
54 |     print('Running validation and saving blended outputs')
55 | for i, data in enumerate(tqdm(val_loader)):
56 | with torch.no_grad():
57 | source_images, input, prediction, source_images_original = run_step(model, vqlpips, data, device, run='val')
58 |
59 | source_hulls = input[:, :3]
60 | background = input[:, 3:]
61 |
62 | saves = {
63 | 'source': source_hulls,
64 | 'background': background,
65 | 'prediction': prediction,
66 | 'source_images': source_images,
67 | 'source_original': source_images_original,
68 | }
69 |
70 | if True:
71 | def denormalize(x):
72 | return (x.clamp(min=-1.0, max=1.0) + 1)/2
73 |
74 | for name in saves:
75 | saveas = f"{sample_folder}/{epoch + 1}_{global_step}_{i}_{name}.mp4"
76 | frames = saves[name].detach().cpu()
77 | frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames]
78 |
79 | save_frames_as_video(frames, saveas, fps=25)
80 |
81 | '''
82 | trainer code -- loss is computed as the weighted sum of reconstruction loss, latent loss, and perceptual loss
83 | '''
84 | def train(model, vqlpips, loader, val_loader, optimizer, scheduler, device, epoch, validate_at, checkpoint_dir, sample_folder):
85 | if dist.is_primary():
86 | loader = tqdm(loader, file=sys.stdout)
87 |
88 | mse_sum = 0
89 | mse_n = 0
90 | perceptual_losses = []
91 |
92 | for i, data in enumerate(loader):
93 | model.zero_grad()
94 |
95 | recon_loss, latent_loss, perceptual_loss, S = run_step(model, vqlpips, data, device)
96 |
97 | # loss is a weighted sum of i) reconstruction loss, ii) latent loss, and iii) perceptual loss
98 | loss = recon_loss + LATENT_LOSS_WEIGHT * latent_loss + PERCEPTUAL_LOSS_WEIGHT * perceptual_loss
99 |
100 | loss.backward()
101 |
102 | perceptual_losses.append(perceptual_loss.item())
103 |
104 | if scheduler is not None:
105 | scheduler.step()
106 |
107 | optimizer.step()
108 |
109 | global global_step
110 |
111 | global_step += 1
112 |
113 | part_mse_sum = recon_loss.item() * S
114 | part_mse_n = S
115 |
116 | comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
117 | comm = dist.all_gather(comm)
118 |
119 | for part in comm:
120 | mse_sum += part["mse_sum"]
121 | mse_n += part["mse_n"]
122 |
123 | if dist.is_primary():
124 | lr = optimizer.param_groups[0]["lr"]
125 |
126 | loader.set_description(
127 | (
128 | f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
129 | f"perceptual: {np.array(perceptual_losses).mean():.3f} "
130 | f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
131 | f"lr: {lr:.5f}"
132 | )
133 | )
134 |
135 | if i % validate_at == 0:
136 | model.eval()
137 |
138 | validation(model, vqlpips, val_loader, device, epoch, i, sample_folder)
139 |
140 | if dist.is_primary():
141 | os.makedirs(checkpoint_dir, exist_ok=True)
142 |
143 | torch.save(model.state_dict(), f"{checkpoint_dir}/vqvae_{epoch+1}_{str(i + 1).zfill(4)}.pt")
144 |
145 | model.train()
146 |
147 | def main(args):
148 | device = "cuda"
149 |
150 | args.distributed = dist.get_world_size() > 1
151 |
152 | default_transform = transforms.Compose(
153 | [
154 | transforms.ToPILImage(),
155 | transforms.Resize(args.size),
156 | transforms.ToTensor(),
157 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
158 | ]
159 | )
160 |
161 | loader, val_loader, model, vqlpips = get_loaders_and_models(
162 | args, device)
163 |
164 | if args.distributed:
165 | model = nn.parallel.DistributedDataParallel(
166 | model,
167 | device_ids=[dist.get_local_rank()],
168 | output_device=dist.get_local_rank(),
169 | )
170 |
171 | vqlpips = nn.parallel.DistributedDataParallel(
172 | vqlpips,
173 | device_ids=[dist.get_local_rank()],
174 | output_device=dist.get_local_rank(),
175 | )
176 |
177 | # load the pretrained checkpoints if available
178 | if args.ckpt:
179 | print(f'Loading pretrained checkpoint - {args.ckpt}')
180 | state_dict = torch.load(args.ckpt)
181 | state_dict = { k.replace('module.', ''): v for k, v in state_dict.items() }
182 | try:
183 | model.module.load_state_dict(state_dict)
184 | except:
185 | model.load_state_dict(state_dict)
186 |
187 | if args.test:
188 | validation(model, vqlpips, val_loader, device, 0, 0, args.sample_folder, 'val')
189 | else:
190 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
191 |
192 | scheduler = None
193 |
194 | if args.sched == "cycle":
195 | scheduler = CycleScheduler(
196 | optimizer,
197 | args.lr,
198 | n_iter=len(loader) * args.epoch,
199 | momentum=None,
200 | warmup_proportion=0.05,
201 | )
202 |
203 | for i in range(args.epoch):
204 | train(model, vqlpips, loader, val_loader, optimizer, scheduler, device, i, args.validate_at, args.checkpoint_dir, args.sample_folder)
205 |
206 | def get_random_name(cipher_length=5):
207 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
208 | return ''.join([chars[random.randint(0, len(chars)-1)] for i in range(cipher_length)])
209 |
210 | if __name__ == "__main__":
211 | parser = argparse.ArgumentParser()
212 | parser.add_argument("--n_gpu", type=int, default=1)
213 |
214 | # port = (
215 | # 2 ** 15
216 | # + 2 ** 14
217 | # + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
218 | # )
219 |
220 | port = random.randint(51000, 52000)
221 |
222 | parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}")
223 | parser.add_argument("--batch_size", type=int, default=32)
224 | parser.add_argument("--size", type=int, default=256, help='image resolution')
225 | parser.add_argument("--epoch", type=int, default=560, help='total epochs')
226 | parser.add_argument("--lr", type=float, default=3e-4, help='learning rate')
227 | parser.add_argument("--sched", type=str)
228 | parser.add_argument("--checkpoint_suffix", type=str, default='')
229 | parser.add_argument("--validate_at", type=int, default=1024, help='validate after (number of steps)')
230 | parser.add_argument("--ckpt", required=False)
231 | parser.add_argument("--test", action='store_true', required=False)
232 | parser.add_argument("--gray", action='store_true', required=False)
233 | parser.add_argument("--colorjit", type=str, default='', help='const or random or empty')
234 | parser.add_argument("--crossid", action='store_true', required=False)
235 | parser.add_argument("--custom_validation", action='store_true', required=False, help='perform custom validation')
236 | parser.add_argument("--sample_folder", type=str, default='samples')
237 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoint', help='dir path for saving the checkpoints')
238 |     parser.add_argument("--validation_folder", type=str, default=None, help='dir path containing the custom validation datapoints')
239 |
240 | args = parser.parse_args()
241 |
242 | # args.n_gpu = torch.cuda.device_count()
243 | current_run = get_random_name()
244 |
245 | args.sample_folder = osp.join(BASE, args.sample_folder + '_' + current_run)
246 | os.makedirs(args.sample_folder, exist_ok=True)
247 |
248 | args.checkpoint_dir = osp.join(BASE, args.checkpoint_dir + '_' + current_run)
249 |
250 | print(args, flush=True)
251 |
252 | print(f'Weight configuration used - perceptual loss weight : {PERCEPTUAL_LOSS_WEIGHT}, latent loss weight : {LATENT_LOSS_WEIGHT}')
253 |
254 | dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,))
--------------------------------------------------------------------------------
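For clarity, a small self-contained sketch of how the three loss terms above are combined; the weight values shown are placeholders (the trainer reads the real ones from `config.py`), and the tensors stand in for the VQ-VAE output, ground truth, latent loss, and VQLPIPS output:

```
import torch
from torch import nn

# placeholder weights, for illustration; the trainer reads the real values from config.py
LATENT_LOSS_WEIGHT = 0.25
PERCEPTUAL_LOSS_WEIGHT = 1.0

criterion = nn.MSELoss()

# stand-ins for the trainer's tensors: `out` and `ground_truth` are (S, 3, H, W) frame stacks,
# `latent_loss` comes from the VQ-VAE and `perceptual_loss` from vqlpips(ground_truth, out)
out = torch.randn(8, 3, 256, 256, requires_grad=True)
ground_truth = torch.randn(8, 3, 256, 256)
latent_loss = torch.tensor(0.1)
perceptual_loss = torch.tensor(0.2)

recon_loss = criterion(out, ground_truth)
loss = recon_loss + LATENT_LOSS_WEIGHT * latent_loss + PERCEPTUAL_LOSS_WEIGHT * perceptual_loss
loss.backward()
```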
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torchvision import utils
7 |
8 |
9 | def save_frames_as_video(frames, video_path, fps=30):
10 | height, width, layers = frames[0].shape
11 |
12 | video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
13 | for frame in frames:
14 | video.write(cv2.cvtColor((frame*255).astype(np.uint8), cv2.COLOR_RGB2BGR))
15 |
16 | cv2.destroyAllWindows()
17 | video.release()
18 |
19 | def save_image(data, saveas, video=False):
20 | utils.save_image(
21 | data,
22 | saveas,
23 | nrow=data.shape[0]//2,
24 | normalize=True,
25 | range=(-1, 1),
26 | )
27 |
28 |
29 | def process_data(data, device, dataset):
30 | source, target, background, source_images, source_images_original = data
31 |
32 | img = torch.cat([source, background], axis=2).squeeze(0).to(device)
33 | source = source.squeeze(0)
34 | ground_truth = source_images.squeeze(0).to(device)
35 |
36 | S = source.shape[0]
37 |
38 | return img, S, ground_truth, source_images_original.squeeze(0).to(device)
39 |
40 |
41 | '''
42 | Conv3D based temporal module is added before the quantization step
43 | LPIPS based perceptual loss is added
44 | '''
45 | def get_facetranslation_latent_conv_perceptual(args, device):
46 | from TemporalAlignment.dataset import TemporalAlignmentDataset
47 | from models.vqvae_conv3d_latent import VQVAE
48 | from loss import VQLPIPS
49 |
50 |     print('Using a Conv3D temporal module before quantization with LPIPS perceptual loss')
51 |
52 | model = VQVAE(in_channel=3*2).to(device)
53 | vqlpips = VQLPIPS().to(device)
54 |
55 | train_dataset = TemporalAlignmentDataset(
56 | 'train', 30,
57 | color_jitter_type=args.colorjit,
58 | grayscale_required=args.gray)
59 |
60 | val_dataset = TemporalAlignmentDataset(
61 | 'val', 50,
62 | color_jitter_type=args.colorjit,
63 | cross_identity_required=args.crossid,
64 | grayscale_required=args.gray,
65 | custom_validation_required=args.custom_validation,
66 | validation_datapoints=args.validation_folder)
67 |
68 | try:
69 | train_loader = DataLoader(
70 | train_dataset,
71 | batch_size=1,
72 | shuffle=True,
73 | num_workers=2)
74 | except:
75 | train_loader = None
76 |
77 | val_loader = DataLoader(
78 | val_dataset,
79 | shuffle=False,
80 | batch_size=1,
81 | num_workers=2)
82 |
83 | return train_loader, val_loader, model, vqlpips
84 |
85 |
86 | '''
87 | get the loaders and the models
88 | '''
89 | def get_loaders_and_models(args, device):
90 | return get_facetranslation_latent_conv_perceptual(args, device)
--------------------------------------------------------------------------------
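A minimal usage sketch for `save_frames_as_video`, mirroring the denormalize-and-permute pattern used in the trainers; the frame tensor and output filename are placeholders:

```
import torch

from utils import save_frames_as_video

# hypothetical batch of 25 generated frames in [-1, 1], shaped (S, 3, H, W)
frames_tensor = torch.rand(25, 3, 256, 256) * 2 - 1

def denormalize(x):
    return (x.clamp(min=-1.0, max=1.0) + 1) / 2

# save_frames_as_video expects a list of H x W x 3 float arrays in [0, 1]
frames = [denormalize(x).permute(1, 2, 0).numpy() for x in frames_tensor]
save_frames_as_video(frames, 'sample_output.mp4', fps=25)
```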
/valid_videos.txt:
--------------------------------------------------------------------------------
1 | c3tTE3hzON8
2 | XGZMe5DsnRU
3 | iOqgDmaYdqA
4 | C93trWdL3Rg
5 | _J0XtWtloWE
6 | MTgyYqY3XTE
7 | TZpP05GHWnw
8 | y9lSgYW2ObQ
9 | ozEAkq3qSi0
10 | rYVGOvyAs-I
11 | 0uNEEfQKfwk
12 | WOEcMOEmbC8
13 | IGl4w4a_pcw
14 | wuS8Hq1Yhro
15 |
--------------------------------------------------------------------------------