├── .gitignore
├── LICENSE
├── README.md
├── pix2pix.py
├── requirements.txt
├── run.py
├── run.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | data/
3 | testing_output_images/
4 | test_output/
5 | input.mp4
6 | output.mp4
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Karan Vivek Bhargava
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ObamaNet : Lip Sync from Audio
2 |
15 | ### List of Contents
16 |
17 | * [Dependencies](https://github.com/karanvivekbhargava/obamanet#dependencies)
18 | * [Requirements](https://github.com/karanvivekbhargava/obamanet#requirements)
19 | * [Data Extraction](https://github.com/karanvivekbhargava/obamanet#data-extraction)
20 | * [Pretrained Model](https://github.com/karanvivekbhargava/obamanet#pretrained-model)
21 | * [Running sample wav file](https://github.com/karanvivekbhargava/obamanet#running-sample-wav-file)
22 | * [Citation](https://github.com/karanvivekbhargava/obamanet#citation)
23 | * [FAQs](https://github.com/karanvivekbhargava/obamanet#faqs)
25 |
26 |
27 | ### Dependencies
28 |
29 | - [miniconda](https://docs.conda.io/en/latest/miniconda.html)
30 | - python 3.7 bash installer
31 | - ffmpeg (with x264 enabled)
32 | - `sudo apt-get install ffmpeg`
33 | - `brew install -i ffmpeg`
34 | - `./configure --enable-gpl --enable-libx264` (when building ffmpeg from source with x264 support)
35 |
36 | ### Requirements
37 |
38 | You may install the requirements by running the following commands:
39 | ```
40 | conda init
41 |
42 | conda deactivate
43 | conda env remove -n obamanet
44 | conda create -n obamanet
45 | conda activate obamanet
46 | conda install python=3.7 pip
47 | conda install numpy scikit-learn scipy tqdm cmake
48 | conda run pip install -r requirements.txt
49 | ```
50 |
51 | The project is built for Python 3.5 and above. The other key libraries are listed below:
52 | * OpenCV (`sudo pip3 install opencv-contrib-python`)
53 | * Dlib (`sudo pip3 install dlib`) with [this](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) file unzipped in the data folder
54 | * Python Speech Features (`sudo pip3 install python-speech-features`)
55 |
56 | For a complete list refer to `requirements.txt` file.
57 |
58 | I used the tools below to extract and manipulate the data:
59 | * [YouTube-dl](https://github.com/rg3/youtube-dl#video-selection)
60 |
61 | ### Data Extraction
62 | ---
63 | I extracted the data from YouTube using youtube-dl, which is arguably the best YouTube downloader on Linux. The commands for extracting the particular streams are given below.
64 |
65 | * Subtitle Extraction
66 | ```
67 | youtube-dl --sub-lang en --skip-download --write-sub --output '~/obamanet/data/captions/%(autonumber)s.%(ext)s' --batch-file ~/obamanet/data/obama_addresses.txt --ignore-config
68 | ```
69 | * Video Extraction
70 | ```
71 | youtube-dl --batch-file ~/obamanet/data/obama_addresses.txt -o '~/obamanet/data/videos/%(autonumber)s.%(ext)s' -f "best[height=720]" --autonumber-start 1
72 | ```
73 | (Videos not available in 720p: 165)
74 | * Video to Audio Conversion (`vid2wav.py` is not included in this repository; a sketch is given after this list)
75 | ```
76 | python3 vid2wav.py
77 | ```
78 | * Video to Images
79 | ```
80 | ffmpeg -i 00001.mp4 -r 1/5 -vf scale=-1:720 images/00001-$filename%05d.bmp
81 | ```
82 |
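The `vid2wav.py` script referenced above is not included in this repository. A minimal sketch of what it might look like, assuming the videos sit in `data/videos/` and the audio should land in `data/audios/` (folder names and sample rate are assumptions):

```
# Hypothetical vid2wav.py: extract a mono 16 kHz wav track from every
# downloaded video with ffmpeg. Folder names and sample rate are assumptions.
import glob
import os
import subprocess

os.makedirs('data/audios', exist_ok=True)
for video in sorted(glob.glob('data/videos/*.mp4')):
    name = os.path.splitext(os.path.basename(video))[0]
    wav = os.path.join('data/audios', name + '.wav')
    # -vn drops the video stream, -ac 1 makes it mono, -ar 16000 resamples to 16 kHz
    subprocess.call(['ffmpeg', '-y', '-i', video, '-vn', '-ac', '1', '-ar', '16000', wav])
```
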
83 | To convert from BMP to JPG format, run the following in the image directory:
84 | ```
85 | mogrify -format jpg *.bmp
86 | rm -rf *.bmp
87 | ```
88 |
89 | Copy the patched images into folder `a` and the cropped images into folder `b`, then combine and split them with the pix2pix tools (a sketch of what the `combine` step produces follows these commands):
90 | ```
91 | python3 tools/process.py --input_dir a --b_dir b --operation combine --output_dir c
92 | python3 tools/split.py --dir c
93 | ```
94 |
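For reference, the `combine` operation in [pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow)'s `tools/process.py` places each image from `a` next to its counterpart from `b` in one side-by-side frame, which is the paired format `pix2pix.py` expects (it splits every example at `width//2`). A rough sketch of that pairing step, assuming both folders contain identically named 256x256 JPGs (an assumption):

```
# Sketch of the A|B pairing that the combine operation produces.
# Assumes a/ and b/ hold identically named 256x256 images.
import glob
import os
import cv2
import numpy as np

os.makedirs('c', exist_ok=True)
for path_a in sorted(glob.glob('a/*.jpg')):
    name = os.path.basename(path_a)
    img_a = cv2.imread(path_a)
    img_b = cv2.imread(os.path.join('b', name))
    # concatenate horizontally: left half = a (input), right half = b (target)
    cv2.imwrite(os.path.join('c', name), np.hstack((img_a, img_b)))
```
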
95 | You may use [this](https://drive.google.com/open?id=1zKip_rlNY2Dk14fzzOHQm03HJNvJTjGT) pretrained model or train pix2pix from scratch using [this](https://drive.google.com/open?id=1sJBp5bYe3XSyE7ys5i7ABORFZctWEQhW) dataset. Unzip the dataset into the [pix2pix](https://github.com/affinelayer/pix2pix-tensorflow) main directory.
96 |
97 | ```
98 | python3 pix2pix.py --mode train --output_dir output --max_epochs 200 --input_dir c/train/ --which_direction AtoB
99 | ```
100 |
101 | To run the trained pix2pix model:
102 | ```
103 | python3 pix2pix.py --mode test --output_dir test_out/ --input_dir c_test/ --checkpoint output/
104 | ```
105 |
106 | To convert the output images into a video:
107 | ```
108 | ffmpeg -r 30 -f image2 -s 256x256 -i %d-targets.png -vcodec libx264 -crf 25 ../targets.mp4
109 | ```
110 |
111 |
112 | ### Pretrained Model
113 |
114 | The pretrained model and a subset of the data are available [here](https://drive.google.com/drive/folders/1QDRCWmqr87E3LWmYztqE7cpBZ3fALo-x?usp=sharing).
115 |
116 | You will also need the dlib [landmarks](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) file, unzipped into the data folder.
117 |
118 | Download and extract the checkpoints and the data folders into the repository. The file structure should look as shown below.
119 |
120 | ```
121 | obamanet
122 | |
123 | └─ data
124 | | | audios
125 | | | a2key_data
126 | | | shape_predictor_68_face_landmarks.dat
127 | | ...
128 | |
129 | └─ checkpoints
130 | | | output
131 | | | my_model.h5
132 | | ...
133 | └─ train.py
134 | └─ run.py
135 | └─ run.sh
136 | ...
137 | ```
138 |
139 |
140 | ### Running sample wav file
141 |
142 | Run the following command:
143 | ```
144 | conda run ./run.sh
145 | ```
146 | Example:
147 | ```
148 | conda run ./run.sh data/audios/karan.wav
149 | ```
150 | Feel free to experiment with different voices. However, the results will depend on how close your voice is to that of the subject the model was trained on.
151 |
152 |
153 | ### Citation
154 | ---
155 | If you use this code for your research, please cite the paper it is based on, [ObamaNet: Photo-realistic lip-sync from text](https://arxiv.org/pdf/1801.01442), and also acknowledge the excellent [pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow) repository by affinelayer.
156 | ```
157 | Cite as arXiv:1801.01442v1 [cs.CV]
158 | ```
159 |
160 |
161 | ### FAQs
162 |
163 | * [What is target delay for RNN/LSTM?](https://stats.stackexchange.com/questions/154814/what-is-target-delay-in-the-context-of-rnn-lstm)
164 | * [Keras implementation of time delayed LSTM?](https://github.com/keras-team/keras/issues/6063)
165 | * [Another link for the above](https://github.com/joncox123/Cortexsys/issues/4)
--------------------------------------------------------------------------------
/pix2pix.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | import argparse
8 | import os
9 | import json
10 | import glob
11 | import random
12 | import collections
13 | import math
14 | import time
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--input_dir", help="path to folder containing images")
18 | parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
19 | parser.add_argument("--output_dir", required=True, help="where to put output files")
20 | parser.add_argument("--seed", type=int)
21 | parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")
22 |
23 | parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
24 | parser.add_argument("--max_epochs", type=int, help="number of training epochs")
25 | parser.add_argument("--summary_freq", type=int, default=100, help="update summaries every summary_freq steps")
26 | parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
27 | parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
28 | parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps")
29 | parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
30 |
31 | parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
32 | parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
33 | parser.add_argument("--lab_colorization", action="store_true", help="split input image into brightness (A) and color (B)")
34 | parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
35 | parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
36 | parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
37 | parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
38 | parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256")
39 | parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
40 | parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
41 | parser.set_defaults(flip=True)
42 | parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
43 | parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
44 | parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
45 | parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")
46 |
47 | # export options
48 | parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
49 | a = parser.parse_args()
50 |
51 | EPS = 1e-12
52 | CROP_SIZE = 256
53 |
54 | Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
55 | Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
56 |
57 |
58 | def preprocess(image):
59 | with tf.name_scope("preprocess"):
60 | # [0, 1] => [-1, 1]
61 | return image * 2 - 1
62 |
63 |
64 | def deprocess(image):
65 | with tf.name_scope("deprocess"):
66 | # [-1, 1] => [0, 1]
67 | return (image + 1) / 2
68 |
69 |
70 | def preprocess_lab(lab):
71 | with tf.name_scope("preprocess_lab"):
72 | L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
73 | # L_chan: black and white with input range [0, 100]
74 | # a_chan/b_chan: color channels with input range ~[-110, 110], not exact
75 | # [0, 100] => [-1, 1], ~[-110, 110] => [-1, 1]
76 | return [L_chan / 50 - 1, a_chan / 110, b_chan / 110]
77 |
78 |
79 | def deprocess_lab(L_chan, a_chan, b_chan):
80 | with tf.name_scope("deprocess_lab"):
81 | # this is axis=3 instead of axis=2 because we process individual images but deprocess batches
82 | return tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)
83 |
84 |
85 | def augment(image, brightness):
86 | # (a, b) color channels, combine with L channel and convert to rgb
87 | a_chan, b_chan = tf.unstack(image, axis=3)
88 | L_chan = tf.squeeze(brightness, axis=3)
89 | lab = deprocess_lab(L_chan, a_chan, b_chan)
90 | rgb = lab_to_rgb(lab)
91 | return rgb
92 |
93 |
94 | def discrim_conv(batch_input, out_channels, stride):
95 | padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
96 | return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02))
97 |
98 |
99 | def gen_conv(batch_input, out_channels):
100 | # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
101 | initializer = tf.random_normal_initializer(0, 0.02)
102 | if a.separable_conv:
103 | return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
104 | else:
105 | return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
106 |
107 |
108 | def gen_deconv(batch_input, out_channels):
109 | # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
110 | initializer = tf.random_normal_initializer(0, 0.02)
111 | if a.separable_conv:
112 | _b, h, w, _c = batch_input.shape
113 | resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
114 | return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
115 | else:
116 | return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
117 |
118 |
119 | def lrelu(x, a):
120 | with tf.name_scope("lrelu"):
121 | # adding these together creates the leak part and linear part
122 | # then cancels them out by subtracting/adding an absolute value term
123 | # leak: a*x/2 - a*abs(x)/2
124 | # linear: x/2 + abs(x)/2
125 |
126 | # this block looks like it has 2 inputs on the graph unless we do this
127 | x = tf.identity(x)
128 | return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
129 |
130 |
131 | def batchnorm(inputs):
132 | return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
133 |
134 |
135 | def check_image(image):
136 | assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
137 | with tf.control_dependencies([assertion]):
138 | image = tf.identity(image)
139 |
140 | if image.get_shape().ndims not in (3, 4):
141 | raise ValueError("image must be either 3 or 4 dimensions")
142 |
143 | # make the last dimension 3 so that you can unstack the colors
144 | shape = list(image.get_shape())
145 | shape[-1] = 3
146 | image.set_shape(shape)
147 | return image
148 |
149 | # based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
150 | def rgb_to_lab(srgb):
151 | with tf.name_scope("rgb_to_lab"):
152 | srgb = check_image(srgb)
153 | srgb_pixels = tf.reshape(srgb, [-1, 3])
154 |
155 | with tf.name_scope("srgb_to_xyz"):
156 | linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
157 | exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
158 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
159 | rgb_to_xyz = tf.constant([
160 | # X Y Z
161 | [0.412453, 0.212671, 0.019334], # R
162 | [0.357580, 0.715160, 0.119193], # G
163 | [0.180423, 0.072169, 0.950227], # B
164 | ])
165 | xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
166 |
167 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
168 | with tf.name_scope("xyz_to_cielab"):
169 | # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
170 |
171 | # normalize for D65 white point
172 | xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
173 |
174 | epsilon = 6/29
175 | linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
176 | exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
177 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
178 |
179 | # convert to lab
180 | fxfyfz_to_lab = tf.constant([
181 | # l a b
182 | [ 0.0, 500.0, 0.0], # fx
183 | [116.0, -500.0, 200.0], # fy
184 | [ 0.0, 0.0, -200.0], # fz
185 | ])
186 | lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
187 |
188 | return tf.reshape(lab_pixels, tf.shape(srgb))
189 |
190 |
191 | def lab_to_rgb(lab):
192 | with tf.name_scope("lab_to_rgb"):
193 | lab = check_image(lab)
194 | lab_pixels = tf.reshape(lab, [-1, 3])
195 |
196 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
197 | with tf.name_scope("cielab_to_xyz"):
198 | # convert to fxfyfz
199 | lab_to_fxfyfz = tf.constant([
200 | # fx fy fz
201 | [1/116.0, 1/116.0, 1/116.0], # l
202 | [1/500.0, 0.0, 0.0], # a
203 | [ 0.0, 0.0, -1/200.0], # b
204 | ])
205 | fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
206 |
207 | # convert to xyz
208 | epsilon = 6/29
209 | linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
210 | exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
211 | xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
212 |
213 | # denormalize for D65 white point
214 | xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
215 |
216 | with tf.name_scope("xyz_to_srgb"):
217 | xyz_to_rgb = tf.constant([
218 | # r g b
219 | [ 3.2404542, -0.9692660, 0.0556434], # x
220 | [-1.5371385, 1.8760108, -0.2040259], # y
221 | [-0.4985314, 0.0415560, 1.0572252], # z
222 | ])
223 | rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
224 | # avoid a slightly negative number messing up the conversion
225 | rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
226 | linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
227 | exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
228 | srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
229 |
230 | return tf.reshape(srgb_pixels, tf.shape(lab))
231 |
232 |
233 | def load_examples():
234 | if a.input_dir is None or not os.path.exists(a.input_dir):
235 | raise Exception("input_dir does not exist")
236 |
237 | input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
238 | decode = tf.image.decode_jpeg
239 | if len(input_paths) == 0:
240 | input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
241 | decode = tf.image.decode_png
242 |
243 | if len(input_paths) == 0:
244 | raise Exception("input_dir contains no image files")
245 |
246 | def get_name(path):
247 | name, _ = os.path.splitext(os.path.basename(path))
248 | return name
249 |
250 | # if the image names are numbers, sort by the value rather than asciibetically
251 | # having sorted inputs means that the outputs are sorted in test mode
252 | if all(get_name(path).isdigit() for path in input_paths):
253 | input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
254 | else:
255 | input_paths = sorted(input_paths)
256 |
257 | with tf.name_scope("load_images"):
258 | path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
259 | reader = tf.WholeFileReader()
260 | paths, contents = reader.read(path_queue)
261 | raw_input = decode(contents)
262 | raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)
263 |
264 | assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")
265 | with tf.control_dependencies([assertion]):
266 | raw_input = tf.identity(raw_input)
267 |
268 | raw_input.set_shape([None, None, 3])
269 |
270 | if a.lab_colorization:
271 | # load color and brightness from image, no B image exists here
272 | lab = rgb_to_lab(raw_input)
273 | L_chan, a_chan, b_chan = preprocess_lab(lab)
274 | a_images = tf.expand_dims(L_chan, axis=2)
275 | b_images = tf.stack([a_chan, b_chan], axis=2)
276 | else:
277 | # break apart image pair and move to range [-1, 1]
278 | width = tf.shape(raw_input)[1] # [height, width, channels]
279 | a_images = preprocess(raw_input[:,:width//2,:])
280 | b_images = preprocess(raw_input[:,width//2:,:])
281 |
282 | if a.which_direction == "AtoB":
283 | inputs, targets = [a_images, b_images]
284 | elif a.which_direction == "BtoA":
285 | inputs, targets = [b_images, a_images]
286 | else:
287 | raise Exception("invalid direction")
288 |
289 | # synchronize seed for image operations so that we do the same operations to both
290 | # input and output images
291 | seed = random.randint(0, 2**31 - 1)
292 | def transform(image):
293 | r = image
294 | if a.flip:
295 | r = tf.image.random_flip_left_right(r, seed=seed)
296 |
297 | # area produces a nice downscaling, but does nearest neighbor for upscaling
298 | # assume we're going to be doing downscaling here
299 | r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
300 |
301 | offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
302 | if a.scale_size > CROP_SIZE:
303 | r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
304 | elif a.scale_size < CROP_SIZE:
305 | raise Exception("scale size cannot be less than crop size")
306 | return r
307 |
308 | with tf.name_scope("input_images"):
309 | input_images = transform(inputs)
310 |
311 | with tf.name_scope("target_images"):
312 | target_images = transform(targets)
313 |
314 | paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
315 | steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))
316 |
317 | return Examples(
318 | paths=paths_batch,
319 | inputs=inputs_batch,
320 | targets=targets_batch,
321 | count=len(input_paths),
322 | steps_per_epoch=steps_per_epoch,
323 | )
324 |
325 |
326 | def create_generator(generator_inputs, generator_outputs_channels):
327 | layers = []
328 |
329 | # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
330 | with tf.variable_scope("encoder_1"):
331 | output = gen_conv(generator_inputs, a.ngf)
332 | layers.append(output)
333 |
334 | layer_specs = [
335 | a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
336 | a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
337 | a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
338 | a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
339 | a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
340 | a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
341 | a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
342 | ]
343 |
344 | for out_channels in layer_specs:
345 | with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
346 | rectified = lrelu(layers[-1], 0.2)
347 | # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
348 | convolved = gen_conv(rectified, out_channels)
349 | output = batchnorm(convolved)
350 | layers.append(output)
351 |
352 | layer_specs = [
353 | (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
354 | (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
355 | (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
356 | (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
357 | (a.ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
358 | (a.ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
359 | (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
360 | ]
361 |
362 | num_encoder_layers = len(layers)
363 | for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
364 | skip_layer = num_encoder_layers - decoder_layer - 1
365 | with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
366 | if decoder_layer == 0:
367 | # first decoder layer doesn't have skip connections
368 | # since it is directly connected to the skip_layer
369 | input = layers[-1]
370 | else:
371 | input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
372 |
373 | rectified = tf.nn.relu(input)
374 | # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
375 | output = gen_deconv(rectified, out_channels)
376 | output = batchnorm(output)
377 |
378 | if dropout > 0.0:
379 | output = tf.nn.dropout(output, keep_prob=1 - dropout)
380 |
381 | layers.append(output)
382 |
383 | # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
384 | with tf.variable_scope("decoder_1"):
385 | input = tf.concat([layers[-1], layers[0]], axis=3)
386 | rectified = tf.nn.relu(input)
387 | output = gen_deconv(rectified, generator_outputs_channels)
388 | output = tf.tanh(output)
389 | layers.append(output)
390 |
391 | return layers[-1]
392 |
393 |
394 | def create_model(inputs, targets):
395 | def create_discriminator(discrim_inputs, discrim_targets):
396 | n_layers = 3
397 | layers = []
398 |
399 | # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
400 | input = tf.concat([discrim_inputs, discrim_targets], axis=3)
401 |
402 | # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
403 | with tf.variable_scope("layer_1"):
404 | convolved = discrim_conv(input, a.ndf, stride=2)
405 | rectified = lrelu(convolved, 0.2)
406 | layers.append(rectified)
407 |
408 | # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
409 | # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
410 | # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
411 | for i in range(n_layers):
412 | with tf.variable_scope("layer_%d" % (len(layers) + 1)):
413 | out_channels = a.ndf * min(2**(i+1), 8)
414 | stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
415 | convolved = discrim_conv(layers[-1], out_channels, stride=stride)
416 | normalized = batchnorm(convolved)
417 | rectified = lrelu(normalized, 0.2)
418 | layers.append(rectified)
419 |
420 | # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
421 | with tf.variable_scope("layer_%d" % (len(layers) + 1)):
422 | convolved = discrim_conv(rectified, out_channels=1, stride=1)
423 | output = tf.sigmoid(convolved)
424 | layers.append(output)
425 |
426 | return layers[-1]
427 |
428 | with tf.variable_scope("generator"):
429 | out_channels = int(targets.get_shape()[-1])
430 | outputs = create_generator(inputs, out_channels)
431 |
432 | # create two copies of discriminator, one for real pairs and one for fake pairs
433 | # they share the same underlying variables
434 | with tf.name_scope("real_discriminator"):
435 | with tf.variable_scope("discriminator"):
436 | # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
437 | predict_real = create_discriminator(inputs, targets)
438 |
439 | with tf.name_scope("fake_discriminator"):
440 | with tf.variable_scope("discriminator", reuse=True):
441 | # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
442 | predict_fake = create_discriminator(inputs, outputs)
443 |
444 | with tf.name_scope("discriminator_loss"):
445 | # minimizing -tf.log will try to get inputs to 1
446 | # predict_real => 1
447 | # predict_fake => 0
448 | discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
449 |
450 | with tf.name_scope("generator_loss"):
451 | # predict_fake => 1
452 | # abs(targets - outputs) => 0
453 | gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
454 | gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs)) # tf.nn.l2_loss(targets - outputs)
455 | gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight
456 |
457 | with tf.name_scope("discriminator_train"):
458 | discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
459 | discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
460 | discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
461 | discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
462 |
463 | with tf.name_scope("generator_train"):
464 | with tf.control_dependencies([discrim_train]):
465 | gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
466 | gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
467 | gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
468 | gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
469 |
470 | ema = tf.train.ExponentialMovingAverage(decay=0.99)
471 | update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
472 |
473 | global_step = tf.train.get_or_create_global_step()
474 | incr_global_step = tf.assign(global_step, global_step+1)
475 |
476 | return Model(
477 | predict_real=predict_real,
478 | predict_fake=predict_fake,
479 | discrim_loss=ema.average(discrim_loss),
480 | discrim_grads_and_vars=discrim_grads_and_vars,
481 | gen_loss_GAN=ema.average(gen_loss_GAN),
482 | gen_loss_L1=ema.average(gen_loss_L1),
483 | gen_grads_and_vars=gen_grads_and_vars,
484 | outputs=outputs,
485 | train=tf.group(update_losses, incr_global_step, gen_train),
486 | )
487 |
488 |
489 | def save_images(fetches, step=None):
490 | image_dir = os.path.join(a.output_dir, "images")
491 | if not os.path.exists(image_dir):
492 | os.makedirs(image_dir)
493 |
494 | filesets = []
495 | for i, in_path in enumerate(fetches["paths"]):
496 | name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
497 | fileset = {"name": name, "step": step}
498 | for kind in ["inputs", "outputs", "targets"]:
499 | filename = name + "-" + kind + ".png"
500 | if step is not None:
501 | filename = "%08d-%s" % (step, filename)
502 | fileset[kind] = filename
503 | out_path = os.path.join(image_dir, filename)
504 | contents = fetches[kind][i]
505 | with open(out_path, "wb") as f:
506 | f.write(contents)
507 | filesets.append(fileset)
508 | return filesets
509 |
510 |
511 | def append_index(filesets, step=False):
512 | index_path = os.path.join(a.output_dir, "index.html")
513 | if os.path.exists(index_path):
514 | index = open(index_path, "a")
515 | else:
516 | index = open(index_path, "w")
517 | index.write("<html><body><table><tr>")
518 | if step:
519 | index.write("<th>step</th>")
520 | index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")
521 |
522 | for fileset in filesets:
523 | index.write("<tr>")
524 |
525 | if step:
526 | index.write("<td>%d</td>" % fileset["step"])
527 | index.write("<td>%s</td>" % fileset["name"])
528 |
529 | for kind in ["inputs", "outputs", "targets"]:
530 | index.write("<td><img src='images/%s'></td>" % fileset[kind])
531 |
532 | index.write("</tr>")
533 | return index_path
534 |
535 |
536 | def main():
537 | if a.seed is None:
538 | a.seed = random.randint(0, 2**31 - 1)
539 |
540 | tf.set_random_seed(a.seed)
541 | np.random.seed(a.seed)
542 | random.seed(a.seed)
543 |
544 | if not os.path.exists(a.output_dir):
545 | os.makedirs(a.output_dir)
546 |
547 | if a.mode == "test" or a.mode == "export":
548 | if a.checkpoint is None:
549 | raise Exception("checkpoint required for test mode")
550 |
551 | # load some options from the checkpoint
552 | options = {"which_direction", "ngf", "ndf", "lab_colorization"}
553 | with open(os.path.join(a.checkpoint, "options.json")) as f:
554 | for key, val in json.loads(f.read()).items():
555 | if key in options:
556 | print("loaded", key, "=", val)
557 | setattr(a, key, val)
558 | # disable these features in test mode
559 | a.scale_size = CROP_SIZE
560 | a.flip = False
561 |
562 | for k, v in a._get_kwargs():
563 | print(k, "=", v)
564 |
565 | with open(os.path.join(a.output_dir, "options.json"), "w") as f:
566 | f.write(json.dumps(vars(a), sort_keys=True, indent=4))
567 |
568 | if a.mode == "export":
569 | # export the generator to a meta graph that can be imported later for standalone generation
570 | if a.lab_colorization:
571 | raise Exception("export not supported for lab_colorization")
572 |
573 | input = tf.placeholder(tf.string, shape=[1])
574 | input_data = tf.decode_base64(input[0])
575 | input_image = tf.image.decode_png(input_data)
576 |
577 | # remove alpha channel if present
578 | input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:,:,:3], lambda: input_image)
579 | # convert grayscale to RGB
580 | input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image)
581 |
582 | input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
583 | input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
584 | batch_input = tf.expand_dims(input_image, axis=0)
585 |
586 | with tf.variable_scope("generator"):
587 | batch_output = deprocess(create_generator(preprocess(batch_input), 3))
588 |
589 | output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
590 | if a.output_filetype == "png":
591 | output_data = tf.image.encode_png(output_image)
592 | elif a.output_filetype == "jpeg":
593 | output_data = tf.image.encode_jpeg(output_image, quality=80)
594 | else:
595 | raise Exception("invalid filetype")
596 | output = tf.convert_to_tensor([tf.encode_base64(output_data)])
597 |
598 | key = tf.placeholder(tf.string, shape=[1])
599 | inputs = {
600 | "key": key.name,
601 | "input": input.name
602 | }
603 | tf.add_to_collection("inputs", json.dumps(inputs))
604 | outputs = {
605 | "key": tf.identity(key).name,
606 | "output": output.name,
607 | }
608 | tf.add_to_collection("outputs", json.dumps(outputs))
609 |
610 | init_op = tf.global_variables_initializer()
611 | restore_saver = tf.train.Saver()
612 | export_saver = tf.train.Saver()
613 |
614 | with tf.Session() as sess:
615 | sess.run(init_op)
616 | print("loading model from checkpoint")
617 | checkpoint = tf.train.latest_checkpoint(a.checkpoint)
618 | restore_saver.restore(sess, checkpoint)
619 | print("exporting model")
620 | export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
621 | export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=False)
622 |
623 | return
624 |
625 | examples = load_examples()
626 | print("examples count = %d" % examples.count)
627 |
628 | # inputs and targets are [batch_size, height, width, channels]
629 | model = create_model(examples.inputs, examples.targets)
630 |
631 | # undo colorization splitting on images that we use for display/output
632 | if a.lab_colorization:
633 | if a.which_direction == "AtoB":
634 | # inputs is brightness, this will be handled fine as a grayscale image
635 | # need to augment targets and outputs with brightness
636 | targets = augment(examples.targets, examples.inputs)
637 | outputs = augment(model.outputs, examples.inputs)
638 | # inputs can be deprocessed normally and handled as if they are single channel
639 | # grayscale images
640 | inputs = deprocess(examples.inputs)
641 | elif a.which_direction == "BtoA":
642 | # inputs will be color channels only, get brightness from targets
643 | inputs = augment(examples.inputs, examples.targets)
644 | targets = deprocess(examples.targets)
645 | outputs = deprocess(model.outputs)
646 | else:
647 | raise Exception("invalid direction")
648 | else:
649 | inputs = deprocess(examples.inputs)
650 | targets = deprocess(examples.targets)
651 | outputs = deprocess(model.outputs)
652 |
653 | def convert(image):
654 | if a.aspect_ratio != 1.0:
655 | # upscale to correct aspect ratio
656 | size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
657 | image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)
658 |
659 | return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
660 |
661 | # reverse any processing on images so they can be written to disk or displayed to user
662 | with tf.name_scope("convert_inputs"):
663 | converted_inputs = convert(inputs)
664 |
665 | with tf.name_scope("convert_targets"):
666 | converted_targets = convert(targets)
667 |
668 | with tf.name_scope("convert_outputs"):
669 | converted_outputs = convert(outputs)
670 |
671 | with tf.name_scope("encode_images"):
672 | display_fetches = {
673 | "paths": examples.paths,
674 | "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
675 | "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
676 | "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
677 | }
678 |
679 | # summaries
680 | with tf.name_scope("inputs_summary"):
681 | tf.summary.image("inputs", converted_inputs)
682 |
683 | with tf.name_scope("targets_summary"):
684 | tf.summary.image("targets", converted_targets)
685 |
686 | with tf.name_scope("outputs_summary"):
687 | tf.summary.image("outputs", converted_outputs)
688 |
689 | with tf.name_scope("predict_real_summary"):
690 | tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
691 |
692 | with tf.name_scope("predict_fake_summary"):
693 | tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))
694 |
695 | tf.summary.scalar("discriminator_loss", model.discrim_loss)
696 | tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
697 | tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
698 |
699 | for var in tf.trainable_variables():
700 | tf.summary.histogram(var.op.name + "/values", var)
701 |
702 | for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
703 | tf.summary.histogram(var.op.name + "/gradients", grad)
704 |
705 | with tf.name_scope("parameter_count"):
706 | parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
707 |
708 | saver = tf.train.Saver(max_to_keep=1)
709 |
710 | logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
711 | sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
712 | with sv.managed_session() as sess:
713 | print("parameter_count =", sess.run(parameter_count))
714 |
715 | if a.checkpoint is not None:
716 | print("loading model from checkpoint")
717 | checkpoint = tf.train.latest_checkpoint(a.checkpoint)
718 | saver.restore(sess, checkpoint)
719 |
720 | max_steps = 2**32
721 | if a.max_epochs is not None:
722 | max_steps = examples.steps_per_epoch * a.max_epochs
723 | if a.max_steps is not None:
724 | max_steps = a.max_steps
725 |
726 | if a.mode == "test":
727 | # testing
728 | # at most, process the test data once
729 | start = time.time()
730 | max_steps = min(examples.steps_per_epoch, max_steps)
731 | for step in range(max_steps):
732 | results = sess.run(display_fetches)
733 | filesets = save_images(results)
734 | for i, f in enumerate(filesets):
735 | print("evaluated image", f["name"])
736 | index_path = append_index(filesets)
737 | print("wrote index at", index_path)
738 | print("rate", (time.time() - start) / max_steps)
739 | else:
740 | # training
741 | start = time.time()
742 |
743 | for step in range(max_steps):
744 | def should(freq):
745 | return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)
746 |
747 | options = None
748 | run_metadata = None
749 | if should(a.trace_freq):
750 | options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
751 | run_metadata = tf.RunMetadata()
752 |
753 | fetches = {
754 | "train": model.train,
755 | "global_step": sv.global_step,
756 | }
757 |
758 | if should(a.progress_freq):
759 | fetches["discrim_loss"] = model.discrim_loss
760 | fetches["gen_loss_GAN"] = model.gen_loss_GAN
761 | fetches["gen_loss_L1"] = model.gen_loss_L1
762 |
763 | if should(a.summary_freq):
764 | fetches["summary"] = sv.summary_op
765 |
766 | if should(a.display_freq):
767 | fetches["display"] = display_fetches
768 |
769 | results = sess.run(fetches, options=options, run_metadata=run_metadata)
770 |
771 | if should(a.summary_freq):
772 | print("recording summary")
773 | sv.summary_writer.add_summary(results["summary"], results["global_step"])
774 |
775 | if should(a.display_freq):
776 | print("saving display images")
777 | filesets = save_images(results["display"], step=results["global_step"])
778 | append_index(filesets, step=True)
779 |
780 | if should(a.trace_freq):
781 | print("recording trace")
782 | sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])
783 |
784 | if should(a.progress_freq):
785 | # global_step will have the correct step count if we resume from a checkpoint
786 | train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
787 | train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
788 | rate = (step + 1) * a.batch_size / (time.time() - start)
789 | remaining = (max_steps - step) * a.batch_size / rate
790 | print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
791 | print("discrim_loss", results["discrim_loss"])
792 | print("gen_loss_GAN", results["gen_loss_GAN"])
793 | print("gen_loss_L1", results["gen_loss_L1"])
794 |
795 | if should(a.save_freq):
796 | print("saving model")
797 | saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)
798 |
799 | if sv.should_stop():
800 | break
801 |
802 |
803 | main()
804 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | keras==2.1.3
2 | webvtt-py==0.4.1
3 | tensorflow==1.15.0
4 | tensorboard==1.15.0
5 | python-speech-features==0.6
6 | dlib==19.9.0
7 | opencv-contrib-python==3.4.2.16
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 |
2 | from keras.models import Sequential
3 | from keras.layers import Dense, LSTM, Dropout, Embedding, Lambda, TimeDistributed
4 | import keras.backend as K
5 | from keras.preprocessing.sequence import pad_sequences
6 | from keras.models import load_model
7 | import keras
8 | from sklearn.preprocessing import MinMaxScaler
9 | import numpy as np
10 | from tqdm import tqdm
11 | import pickle as pkl
12 | from keras.callbacks import TensorBoard
13 | from time import time
14 | import cv2
15 | import scipy.io.wavfile as wav
16 | from python_speech_features import logfbank
17 | import subprocess
18 | import argparse
19 |
20 | #########################################################################################
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--sf", help="path to wav file")
24 | a = parser.parse_args()
25 |
26 | key_audio = a.sf # '00003' # '00001-003' # 'karan' # '00002-002' # '00002-007' #
27 | time_delay = 20
28 | look_back = 50
29 | n_epoch = 50
30 | outputFolder = 'testing_output_images/'
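# look_back: number of consecutive audio feature frames fed to the LSTM per prediction.
# time_delay: the predicted mouth keypoints correspond to the frame that sits
# `time_delay` steps before the end of that audio window (the "target delay"
# linked in the README FAQs). n_epoch is unused here; it mirrors train.py.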
31 |
32 | #########################################################################################
33 |
34 | cmd = 'rm -rf ' + outputFolder + ' && mkdir ' + outputFolder
35 | subprocess.call(cmd, shell=True)
36 |
37 | #########################################################################################
38 |
39 | model = load_model('checkpoints/my_model.h5')
40 |
41 | #########################################################################################
42 |
43 | def subsample(y, fps_from = 100.0, fps_to = 29.97):
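# The audio features arrive at ~100 frames/sec (logfbank's default 10 ms hop),
# while the video runs at roughly 30 fps, so keep only every `factor`-th frame.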
44 | factor = int(np.ceil(fps_from/fps_to))
45 | # Subsample the points
46 | new_y = np.zeros((int(y.shape[0]/factor), 20, 2)) #(timesteps, 20) = (500, 20x2)
47 | for idx in range(new_y.shape[0]):
48 | if not (idx*factor > y.shape[0]-1):
49 | # Get into (x, y) format
50 | new_y[idx, :, 0] = y[idx*factor, 0:20]
51 | new_y[idx, :, 1] = y[idx*factor, 20:]
52 | else:
53 | break
54 | # print('Subsampled y:', new_y.shape)
55 | new_y = [np.array(each) for each in new_y.tolist()]
56 | # print(len(new_y))
57 | return new_y
58 |
59 | def drawLips(keypoints, new_img, c = (255, 255, 255), th = 1, show = False):
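# dlib's 68-point landmark layout: indices 48-59 form the outer lip contour and
# 60-67 the inner lip contour; draw both as closed loops joined at the corners.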
60 |
61 | keypoints = np.float32(keypoints)
62 |
63 | for i in range(48, 59):
64 | cv2.line(new_img, tuple(keypoints[i]), tuple(keypoints[i+1]), color=c, thickness=th)
65 | cv2.line(new_img, tuple(keypoints[48]), tuple(keypoints[59]), color=c, thickness=th)
66 | cv2.line(new_img, tuple(keypoints[48]), tuple(keypoints[60]), color=c, thickness=th)
67 | cv2.line(new_img, tuple(keypoints[54]), tuple(keypoints[64]), color=c, thickness=th)
68 | cv2.line(new_img, tuple(keypoints[67]), tuple(keypoints[60]), color=c, thickness=th)
69 | for i in range(60, 67):
70 | cv2.line(new_img, tuple(keypoints[i]), tuple(keypoints[i+1]), color=c, thickness=th)
71 |
72 | if (show == True):
73 | cv2.imshow('lol', new_img)
74 | cv2.waitKey(10000)
75 |
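# Invert the normalization applied during keypoint extraction: scale the unit
# mouth keypoints back up by N, rotate them by the stored face tilt, and shift
# them to the mouth's mean position in the original frame.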
76 | def getOriginalKeypoints(kp_features_mouth, N, tilt, mean):
77 | # Denormalize the points
78 | kp_dn = N * kp_features_mouth
79 | # Add the tilt
80 | x, y = kp_dn[:, 0], kp_dn[:, 1]
81 | c, s = np.cos(tilt), np.sin(tilt)
82 | x_dash, y_dash = x*c + y*s, -x*s + y*c
83 | kp_tilt = np.hstack((x_dash.reshape((-1,1)), y_dash.reshape((-1,1))))
84 | # Shift to the mean
85 | kp = kp_tilt + mean
86 | return kp
87 |
88 | #########################################################################################
89 |
90 | # Load the files
91 | # with open('data/audio_kp/audio_kp1467_mel.pickle', 'rb') as pkl_file:
92 | # audio_kp = pkl.load(pkl_file)
93 | with open('data/pca/pkp1467.pickle', 'rb') as pkl_file:
94 | video_kp = pkl.load(pkl_file)
95 | with open('data/pca/pca1467.pickle', 'rb') as pkl_file:
96 | pca = pkl.load(pkl_file)
97 | # Get the original keypoints file
98 | with open('data/a2key_data/kp_test.pickle', 'rb') as pkl_file:
99 | kp = pkl.load(pkl_file)
100 |
101 | # Get the data
102 | X, y = [], [] # Create the empty lists
103 | # audio = audio_kp[key_audio]
104 | video = video_kp['00001-000']
105 |
106 |
107 | # Get audio features
108 | (rate, sig) = wav.read(key_audio)
109 | audio = logfbank(sig,rate)
110 |
111 |
112 | # if (len(audio) > len(video)):
113 | # audio = audio[0:len(video)]
114 | # else:
115 | # video = video[0:len(audio)]
116 | start = (time_delay-look_back) if (time_delay-look_back > 0) else 0
117 | for i in range(start, len(audio)-look_back):
118 | a = np.array(audio[i:i+look_back])
119 | # v = np.array(video[i+look_back-time_delay]).reshape((1, -1))
120 | X.append(a)
121 | # y.append(v)
122 |
123 | for i in range(start, len(video)-look_back):
124 | v = np.array(video[i+look_back-time_delay]).reshape((1, -1))
125 | y.append(v)
126 |
127 | X = np.array(X)
128 | y = np.array(y)
129 | shapeX = X.shape
130 | shapey = y.shape
131 | print('Shapes:', X.shape)
132 | X = X.reshape(-1, X.shape[2])
133 | y = y.reshape(-1, y.shape[2])
134 | print('Shapes:', X.shape)
135 |
136 | scalerX = MinMaxScaler(feature_range=(0, 1))
137 | scalery = MinMaxScaler(feature_range=(0, 1))
138 |
139 | X = scalerX.fit_transform(X)
140 | y = scalery.fit_transform(y)
141 |
142 |
143 | X = X.reshape(shapeX)
144 | # y = y.reshape(shapey[0], shapey[2])
145 |
146 | # print('Shapes:', X.shape, y.shape)
147 | # print('X mean:', np.mean(X), 'X var:', np.var(X))
148 | # print('y mean:', np.mean(y), 'y var:', np.var(y))
149 |
150 | y_pred = model.predict(X)
151 |
152 | # Scale it up
153 | y_pred = scalery.inverse_transform(y_pred)
154 | # y = scalery.inverse_transform(y)
155 |
156 | y_pred = pca.inverse_transform(y_pred)
157 | # y = pca.inverse_transform(y)
158 |
159 | print('Upsampled number:', len(y_pred))
160 |
161 | y_pred = subsample(y_pred, 100, 34)
162 |
163 | # y = subsample(y, 100, 100)
164 |
165 | # error = np.mean(np.square(np.array(y_pred) - np.array(y)))
166 |
167 | # print('Error:', error)
168 |
169 | print('Subsampled number:', len(y_pred))
170 |
171 | # Visualization
172 | # Cut the other stream according to whichever is smaller
173 | if (len(kp) < len(y_pred)):
174 | n = len(kp)
175 | y_pred = y_pred[:n]
176 | else:
177 | n = len(y_pred)
178 | kp = kp[:n]
179 |
180 |
181 | for idx, (x, k) in enumerate(zip(y_pred, kp)):
182 |
183 | unit_mouth_kp, N, tilt, mean, unit_kp, keypoints = k[0], k[1], k[2], k[3], k[4], k[5]
184 | kps = getOriginalKeypoints(x, N, tilt, mean)
185 | keypoints[48:68] = kps
186 |
187 | imgfile = 'data/a2key_data/images/' + str(idx+1).rjust(5, '0') + '.png'
188 | im = cv2.imread(imgfile)
189 | drawLips(keypoints, im, c = (255, 255, 255), th = 1, show = False)
190 |
191 | # make it pix2pix style
192 | im_out = np.zeros_like(im)
193 | im1 = np.hstack((im, im_out))
194 | # print('Shape: ', im1.shape)
195 | cv2.imwrite(outputFolder + str(idx) + '.png', im1)
196 |
197 | print('Done writing', n, 'images')
198 |
199 | # cmd = 'rm -rf input0.mp4 && rm -rf output.mp4'
200 | # subprocess.call(cmd ,shell=True)
201 |
202 | # cmd = 'ffmpeg -r 30 -f image2 -s 256x256 -i output_images/%d.png -vcodec libx264 -crf 25 input0.mp4 && ffmpeg -i input0.mp4 -i data/audios/00003.wav -c:v copy -c:a aac -strict experimental output.mp4 && rm -rf output_images/*.png'
203 | # subprocess.call(cmd ,shell=True)
204 |
205 | # cmd = 'rm -rf input0.mp4'
206 | # subprocess.call(cmd ,shell=True)
207 |
208 | # cmd = 'rm -rf output_images'
209 | # subprocess.call(cmd ,shell=True)
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
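#!/bin/bash
# Pipeline: run.py turns the given wav file into pix2pix-style input frames,
# pix2pix.py renders a face for each frame, and ffmpeg stitches the rendered
# frames and the audio back together into output.mp4 (and input.mp4 for reference).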
1 | SF="$1"
2 |
3 | rm -rf input.mp4
4 | rm -rf output.mp4
5 |
6 | python3 run.py --sf $SF &&
7 | python3 pix2pix.py --mode test --output_dir test_output/ --input_dir testing_output_images/ --checkpoint checkpoints/output/ &&
8 |
9 | ffmpeg -r 30 -f image2 -s 256x256 -i test_output/images/%d-outputs.png -vcodec libx264 -crf 25 output0.mp4 &&
10 | ffmpeg -r 30 -f image2 -s 256x256 -i test_output/images/%d-inputs.png -vcodec libx264 -crf 25 input0.mp4 &&
11 | # ffmpeg -i output0.mp4 -i audio_testing.wav -c:v copy -c:a aac -strict experimental output.mp4
12 | # audios/00001-000.wav
13 | ffmpeg -i $SF output_audio_trim.wav &&
14 |
15 | ffmpeg -i output0.mp4 -i output_audio_trim.wav -c:v copy -c:a aac output.mp4 &&
16 | ffmpeg -i input0.mp4 -i output_audio_trim.wav -c:v copy -c:a aac input.mp4 &&
17 |
18 | rm -rf testing_output_images
19 | rm -rf test_output
20 | rm -rf output0.mp4
21 | rm -rf input0.mp4
22 | rm -rf output_audio_trim.wav
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from keras.models import Sequential
2 | from keras.layers import Dense, LSTM, Dropout, Embedding, Lambda, TimeDistributed
3 | import keras.backend as K
4 | from keras.preprocessing.sequence import pad_sequences
5 | from keras.models import load_model
6 | import keras
7 | from sklearn.preprocessing import MinMaxScaler
8 | import numpy as np
9 | from tqdm import tqdm
10 | import pickle as pkl
11 | from keras.callbacks import TensorBoard
12 | from time import time
13 |
14 | #########################################################################################
15 |
16 | time_delay = 20 #0
17 | look_back = 50
18 | n_epoch = 50
19 | n_videos = 50
20 | tbCallback = TensorBoard(log_dir="logs/{}".format(time())) # TensorBoard(log_dir='./Graph', histogram_freq=0, batch_size=n_batch, write_graph=True, write_images=True)
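# look_back: length (in frames) of the audio feature window given to the LSTM.
# time_delay: offset between the end of that window and the video frame it is
# paired with (the "target delay" linked in the README FAQs).
# n_videos: how many of the common audio/video keys are used to build the dataset.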
21 |
22 | #########################################################################################
23 |
24 | # Load the files
25 | with open('data/audio_kp/audio_kp1467_mel.pickle', 'rb') as pkl_file:
26 | audio_kp = pkl.load(pkl_file)
27 | with open('data/pca/pkp1467.pickle', 'rb') as pkl_file:
28 | video_kp = pkl.load(pkl_file)
29 | with open('data/pca/pca1467.pickle', 'rb') as pkl_file:
30 | pca = pkl.load(pkl_file)
31 |
32 |
33 | # Get the data
34 |
35 | X, y = [], [] # Create the empty lists
36 | # Get the common keys
37 | keys_audio = audio_kp.keys()
38 | keys_video = video_kp.keys()
39 | keys = sorted(list(set(keys_audio).intersection(set(keys_video))))
40 | # print('Length of common keys:', len(keys), 'First common key:', keys[0])
41 |
42 | # X = np.array(X).reshape((-1, 26))
43 | # y = np.array(y).reshape((-1, 8))
44 |
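# Build the training pairs: for each window position i, X gets `look_back`
# consecutive audio feature vectors and y gets the PCA-reduced mouth keypoints
# of the frame `time_delay` steps before the end of that window.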
45 | for key in tqdm(keys[0:n_videos]):
46 | audio = audio_kp[key]
47 | video = video_kp[key]
48 | if (len(audio) > len(video)):
49 | audio = audio[0:len(video)]
50 | else:
51 | video = video[0:len(audio)]
52 | start = (time_delay-look_back) if (time_delay-look_back > 0) else 0
53 | for i in range(start, len(audio)-look_back):
54 | a = np.array(audio[i:i+look_back])
55 | v = np.array(video[i+look_back-time_delay]).reshape((1, -1))
56 | X.append(a)
57 | y.append(v)
58 |
59 | X = np.array(X)
60 | y = np.array(y)
61 | shapeX = X.shape
62 | shapey = y.shape
63 | print('Shapes:', X.shape, y.shape)
64 | X = X.reshape(-1, X.shape[2])
65 | y = y.reshape(-1, y.shape[2])
66 | print('Shapes:', X.shape, y.shape)
67 |
68 | scalerX = MinMaxScaler(feature_range=(0, 1))
69 | scalery = MinMaxScaler(feature_range=(0, 1))
70 |
71 | X = scalerX.fit_transform(X)
72 | y = scalery.fit_transform(y)
73 |
74 |
75 | X = X.reshape(shapeX)
76 | y = y.reshape(shapey[0], shapey[2])
77 |
78 | print('Shapes:', X.shape, y.shape)
79 | print('X mean:', np.mean(X), 'X var:', np.var(X))
80 | print('y mean:', np.mean(y), 'y var:', np.var(y))
81 |
82 | split1 = int(0.8*X.shape[0])
83 | split2 = int(0.9*X.shape[0])
84 |
85 | train_X = X[0:split1]
86 | train_y = y[0:split1]
87 | val_X = X[split1:split2]
88 | val_y = y[split1:split2]
89 | test_X = X[split2:]
90 | test_y = y[split2:]
91 |
92 |
93 |
94 |
95 | # Initialize the model
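# A single LSTM over the (look_back x 26) log filterbank window regresses the
# 8 PCA coefficients that describe the mouth keypoints for one video frame.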
96 |
97 | model = Sequential()
98 | model.add(LSTM(25, input_shape=(look_back, 26)))
99 | model.add(Dropout(0.25))
100 | model.add(Dense(8))
101 | model.compile(loss='mean_squared_error', optimizer='adam')
102 | print(model.summary())
103 | # model = load_model('my_model.h5')
104 |
105 | # train LSTM with validation data
106 | for i in tqdm(range(n_epoch)):
107 | print('Epoch', (i+1), '/', n_epoch, ' - ', int(100*(i+1)/n_epoch))
108 | model.fit(train_X, train_y, epochs=1, batch_size=1,
109 | verbose=1, shuffle=True, callbacks=[tbCallback], validation_data=(val_X, val_y))
110 | # model.reset_states()
111 | test_error = np.mean(np.square(test_y - model.predict(test_X)))
112 | # model.reset_states()
113 | print('Test Error: ', test_error)
114 |
115 | # Save the model
116 | model.save('my_model.h5')
117 | model.save_weights('my_model_weights.h5')
118 | print('Saved Model.')
119 |
120 |
121 |
122 | # X, num = audioToPrediction('audios/' + key_audio + '.wav')
123 | # y = model.predict(X, batch_size=n_batch)
124 | # y = y.reshape(y.shape[0]*y.shape[1], y.shape[2])
125 | # print('y:', y[0:num].shape)
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------