├── .gitignore
├── LICENSE
├── README.md
├── pix2pix.py
├── requirements.txt
├── run.py
├── run.sh
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
checkpoints/
data/
testing_output_images/
test_output/
input.mp4
output.mp4
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Karan Vivek Bhargava

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# ObamaNet: Lip Sync from Audio

### List of Contents

* [Requirements](https://github.com/karanvivekbhargava/obamanet#requirements)
* [Data Extraction](https://github.com/karanvivekbhargava/obamanet#data_extraction)
* [Data Preprocessing](https://github.com/karanvivekbhargava/obamanet#data_preprocessing)
* [Training Different Models](https://github.com/karanvivekbhargava/obamanet#training_different_models)
* [Pretrained Model](https://github.com/karanvivekbhargava/obamanet#pretrained_model)
* [How to run an example](https://github.com/karanvivekbhargava/obamanet#how_to_run_an_example)
* [Citations](https://github.com/karanvivekbhargava/obamanet#citations)
* [FAQs](https://github.com/karanvivekbhargava/obamanet#faqs)


### Dependencies

- [miniconda](https://docs.conda.io/en/latest/miniconda.html)
  - Python 3.7 bash installer
- ffmpeg (with x264 enabled)
  - Debian/Ubuntu: `sudo apt-get install ffmpeg`
  - macOS (Homebrew): `brew install -i ffmpeg`
  - Building from source: `./configure --enable-gpl --enable-libx264`

### Requirements

You may install the requirements by running the following commands:
```
conda init

conda deactivate
conda env remove -n obamanet
conda create -n obamanet
conda activate obamanet
conda install python=3.7 pip
conda install numpy scikit-learn scipy tqdm cmake
conda run pip install -r requirements.txt
```

The project is built for Python 3.5 and above. The other libraries are listed below:
* OpenCV (`sudo pip3 install opencv-contrib-python`)
* Dlib (`sudo pip3 install dlib`), with the [shape predictor](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) file unzipped into the `data` folder
* Python Speech Features (`sudo pip3 install python-speech-features`)

For a complete list, refer to the `requirements.txt` file.
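The Dlib landmark model linked above is distributed as a `.bz2` archive. If the `bzip2` command-line tool is not available, the following standard-library sketch can decompress the archive into `data/`. `extract_bz2` is a hypothetical helper and is not part of this repository. The paths are assumptions based on the folder layout described in this README.

```python
import bz2
import shutil
from pathlib import Path


def extract_bz2(archive: Path, dest_dir: Path) -> Path:
    """Decompress a .bz2 file (e.g. the Dlib landmark model) into dest_dir."""
    dest_dir.mkdir(parents=True, exist_ok=True)
    # Drop only the trailing ".bz2": "x.dat.bz2" -> "x.dat"
    target = dest_dir / archive.with_suffix("").name
    with bz2.open(archive, "rb") as src, open(target, "wb") as dst:
        # Stream in chunks instead of reading the whole model into memory
        shutil.copyfileobj(src, dst)
    return target
```

For example, `extract_bz2(Path("shape_predictor_68_face_landmarks.dat.bz2"), Path("data"))` should produce `data/shape_predictor_68_face_landmarks.dat`. That is the file name shown in the folder layout under "Pretrained Model".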

I used the tools below to extract and manipulate the data:
* [YouTube-dl](https://github.com/rg3/youtube-dl#video-selection)

### Data Extraction
---
I extracted the data from YouTube using youtube-dl, which is perhaps the best YouTube downloader on Linux. The commands for extracting each stream are given below.

* Subtitle Extraction
```
youtube-dl --sub-lang en --skip-download --write-sub --output '~/obamanet/data/captions/%(autonumber)s.%(ext)s' --batch-file ~/obamanet/data/obama_addresses.txt --ignore-config
```
* Video Extraction
```
youtube-dl --batch-file ~/obamanet/data/obama_addresses.txt -o '~/obamanet/data/videos/%(autonumber)s.%(ext)s' -f "best[height=720]" --autonumber-start 1
```
(Videos not available in 720p: 165)
* Video to Audio Conversion
```
python3 vid2wav.py
```
* Video to Images
```
ffmpeg -i 00001.mp4 -r 1/5 -vf scale=-1:720 images/00001-$filename%05d.bmp
```

To convert from BMP to JPG, run the following in the image directory:
```
mogrify -format jpg *.bmp
rm -rf *.bmp
```

Copy the patched images into folder `a` and the cropped images into folder `b`. Then combine them into side-by-side pairs and split the pairs into training and validation sets, using the helper scripts from the pix2pix-tensorflow repository:
```
python3 tools/process.py --input_dir a --b_dir b --operation combine --output_dir c
python3 tools/split.py --dir c
```

You may use [this](https://drive.google.com/open?id=1zKip_rlNY2Dk14fzzOHQm03HJNvJTjGT) pretrained model or train pix2pix from scratch using [this](https://drive.google.com/open?id=1sJBp5bYe3XSyE7ys5i7ABORFZctWEQhW) dataset. Unzip the dataset into the [pix2pix](https://github.com/affinelayer/pix2pix-tensorflow) main directory.
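Each training file must hold both images side by side. `load_examples` in `pix2pix.py` slices every image at `width // 2` and treats the left half as A and the right half as B. With `--which_direction AtoB`, A is the input and B is the target.

The NumPy sketch below shows that layout. It assumes both halves are already the same size. `combine_pair` and `split_pair` are hypothetical helpers, not the code in `tools/process.py`.

```python
import numpy as np


def combine_pair(img_a: np.ndarray, img_b: np.ndarray) -> np.ndarray:
    """Join A (left) and B (right) into a single [H, 2W, C] image."""
    if img_a.shape != img_b.shape:
        raise ValueError(f"shape mismatch: {img_a.shape} vs {img_b.shape}")
    # Concatenate along the width axis: [H, W, C] + [H, W, C] -> [H, 2W, C]
    return np.concatenate([img_a, img_b], axis=1)


def split_pair(img_ab: np.ndarray):
    """Undo combine_pair using the same slicing as load_examples."""
    width = img_ab.shape[1]  # [height, width, channels]
    return img_ab[:, : width // 2, :], img_ab[:, width // 2 :, :]


if __name__ == "__main__":
    patched = np.zeros((256, 256, 3), dtype=np.uint8)      # stand-in for a patched frame (A)
    cropped = np.full((256, 256, 3), 255, dtype=np.uint8)  # stand-in for a cropped frame (B)
    pair = combine_pair(patched, cropped)
    print(pair.shape)  # (256, 512, 3)
    left, right = split_pair(pair)
    print(np.array_equal(left, patched), np.array_equal(right, cropped))  # True True
```

Save each combined pair as `.jpg` or `.png`. These are the only extensions `load_examples` globs for.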

To train pix2pix from scratch:
```
python3 pix2pix.py --mode train --output_dir output --max_epochs 200 --input_dir c/train/ --which_direction AtoB
```

To run the trained pix2pix model:
```
python3 pix2pix.py --mode test --output_dir test_out/ --input_dir c_test/ --checkpoint output/
```

To convert the output images into a video:
```
ffmpeg -r 30 -f image2 -s 256x256 -i %d-targets.png -vcodec libx264 -crf 25 ../targets.mp4
```


### Pretrained Model

The pretrained model and a subset of the data are available here: [Link](https://drive.google.com/drive/folders/1QDRCWmqr87E3LWmYztqE7cpBZ3fALo-x?usp=sharing)

You also need the [landmarks](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) file, unzipped into the data folder.

Download and extract the checkpoints and data folders into the repository. The file structure should look as shown below.

```
obamanet
|
└─ data
|   | audios
|   | a2key_data
|   | shape_predictor_68_face_landmarks.dat
|   ...
|
└─ checkpoints
|   | output
|   | my_model.h5
|   ...
└─ train.py
└─ run.py
└─ run.sh
...
```


### Running sample wav file

Run the following command:
```
conda run ./run.sh
```
Example:
```
conda run ./run.sh data/audios/karan.wav
```
Feel free to experiment with different voices. However, the result depends on how close your voice is to that of the subject the model was trained on.


### Citation
---
If you use this code for your research, please cite the paper this code is based on, [ObamaNet: Photo-realistic lip-sync from text](https://arxiv.org/pdf/1801.01442), as well as the amazing pix2pix repository by affinelayer.
```
Cite as arXiv:1801.01442v1 [cs.CV]
```


### FAQs

* [What is target delay for RNN/LSTM?](https://stats.stackexchange.com/questions/154814/what-is-target-delay-in-the-context-of-rnn-lstm)
* [Keras implementation of time delayed LSTM?](https://github.com/keras-team/keras/issues/6063)
* [Another link for the above](https://github.com/joncox123/Cortexsys/issues/4)
--------------------------------------------------------------------------------
/pix2pix.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", help="path to folder containing images")
parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
parser.add_argument("--output_dir", required=True, help="where to put output files")
parser.add_argument("--seed", type=int)
parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")

parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
parser.add_argument("--max_epochs", type=int, help="number of training epochs")
parser.add_argument("--summary_freq", type=int, default=100, help="update summaries every summary_freq steps")
parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps")
parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")

parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
parser.add_argument("--lab_colorization", action="store_true", help="split input image into brightness (A) and color (B)")
parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256")
parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
parser.set_defaults(flip=True)
parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")

# export options
parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
a = parser.parse_args()

EPS = 1e-12
CROP_SIZE = 256

Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")


def preprocess(image):
    with tf.name_scope("preprocess"):
        # [0, 1] => [-1, 1]
        return image * 2 - 1


def deprocess(image):
    with tf.name_scope("deprocess"):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2


def preprocess_lab(lab):
    with tf.name_scope("preprocess_lab"):
        L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
        # L_chan: black and white with input range [0, 100]
        # a_chan/b_chan: color channels with input range ~[-110, 110], not exact
        # [0, 100] => [-1, 1],  ~[-110, 110] => [-1, 1]
        return [L_chan / 50 - 1, a_chan / 110, b_chan / 110]


def deprocess_lab(L_chan, a_chan, b_chan):
    with tf.name_scope("deprocess_lab"):
        # this is axis=3 instead of axis=2 because we process individual images but deprocess batches
        return tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)


def augment(image, brightness):
    # (a, b) color channels, combine with L channel and convert to rgb
    a_chan, b_chan = tf.unstack(image, axis=3)
    L_chan = tf.squeeze(brightness, axis=3)
    lab = deprocess_lab(L_chan, a_chan, b_chan)
    rgb = lab_to_rgb(lab)
    return rgb


def discrim_conv(batch_input, out_channels, stride):
    padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
    return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02))


def gen_conv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if a.separable_conv:
        return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def gen_deconv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if a.separable_conv:
        _b, h, w, _c = batch_input.shape
        resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def lrelu(x, a):
    with tf.name_scope("lrelu"):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2

        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)


def batchnorm(inputs):
    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))


def check_image(image):
    assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
    with tf.control_dependencies([assertion]):
        image = tf.identity(image)

    if image.get_shape().ndims not in (3, 4):
        raise ValueError("image must be either 3 or 4 dimensions")

    # make the last dimension 3 so that you can unstack the colors
    shape = list(image.get_shape())
    shape[-1] = 3
    image.set_shape(shape)
    return image

# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
def rgb_to_lab(srgb):
    with tf.name_scope("rgb_to_lab"):
        srgb = check_image(srgb)
        srgb_pixels = tf.reshape(srgb, [-1, 3])

        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X        Y          Z
                [0.412453, 0.212671, 0.019334], # R
                [0.357580, 0.715160, 0.119193], # G
                [0.180423, 0.072169, 0.950227], # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)

            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

            epsilon = 6/29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask

            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #  l       a       b
                [  0.0,  500.0,    0.0], # fx
                [116.0, -500.0,  200.0], # fy
                [  0.0,    0.0, -200.0], # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

        return tf.reshape(lab_pixels, tf.shape(srgb))


def lab_to_rgb(lab):
    with tf.name_scope("lab_to_rgb"):
        lab = check_image(lab)
        lab_pixels = tf.reshape(lab, [-1, 3])

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #   fx      fy        fz
                [1/116.0, 1/116.0,  1/116.0], # l
                [1/500.0,     0.0,      0.0], # a
                [    0.0,     0.0, -1/200.0], # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

            # convert to xyz
            epsilon = 6/29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [ 3.2404542, -0.9692660,  0.0556434], # x
                [-1.5371385,  1.8760108, -0.2040259], # y
                [-0.4985314,  0.0415560,  1.0572252], # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask

        return tf.reshape(srgb_pixels, tf.shape(lab))


def load_examples():
    if a.input_dir is None or not os.path.exists(a.input_dir):
        raise Exception("input_dir does not exist")

    input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
    decode = tf.image.decode_jpeg
    if len(input_paths) == 0:
        input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
        decode = tf.image.decode_png

    if len(input_paths) == 0:
        raise Exception("input_dir contains no image files")

    def get_name(path):
        name, _ = os.path.splitext(os.path.basename(path))
        return name

    # if the image names are numbers, sort by the value rather than asciibetically
    # having sorted inputs means that the outputs are sorted in test mode
    if all(get_name(path).isdigit() for path in input_paths):
        input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
    else:
        input_paths = sorted(input_paths)

    with tf.name_scope("load_images"):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input = decode(contents)
        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)

        assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")
        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(raw_input)

        raw_input.set_shape([None, None, 3])

        if a.lab_colorization:
            # load color and brightness from image, no B image exists here
            lab = rgb_to_lab(raw_input)
            L_chan, a_chan, b_chan = preprocess_lab(lab)
            a_images = tf.expand_dims(L_chan, axis=2)
            b_images = tf.stack([a_chan, b_chan], axis=2)
        else:
            # break apart image pair and move to range [-1, 1]
            width = tf.shape(raw_input)[1] # [height, width, channels]
            a_images = preprocess(raw_input[:,:width//2,:])
            b_images = preprocess(raw_input[:,width//2:,:])

    if a.which_direction == "AtoB":
        inputs, targets = [a_images, b_images]
    elif a.which_direction == "BtoA":
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    seed = random.randint(0, 2**31 - 1)
    def transform(image):
        r = image
        if a.flip:
            r = tf.image.random_flip_left_right(r, seed=seed)

        # area produces a nice downscaling, but does nearest neighbor for upscaling
        # assume we're going to be doing downscaling here
        r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)

        offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
        if a.scale_size > CROP_SIZE:
            r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
        elif a.scale_size < CROP_SIZE:
            raise Exception("scale size cannot be less than crop size")
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

    return Examples(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )


def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    with tf.variable_scope("encoder_1"):
        output = gen_conv(generator_inputs, a.ngf)
        layers.append(output)

    layer_specs = [
        a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (a.ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (a.ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (a.ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (a.ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (a.ngf * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (a.ngf * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (a.ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = gen_deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
    with tf.variable_scope("decoder_1"):
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        output = gen_deconv(rectified, generator_outputs_channels)
        output = tf.tanh(output)
        layers.append(output)

    return layers[-1]


def create_model(inputs, targets):
    def create_discriminator(discrim_inputs, discrim_targets):
        n_layers = 3
        layers = []

        # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
        input = tf.concat([discrim_inputs, discrim_targets], axis=3)

        # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
        with tf.variable_scope("layer_1"):
            convolved = discrim_conv(input, a.ndf, stride=2)
            rectified = lrelu(convolved, 0.2)
            layers.append(rectified)

        # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
        # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
        # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
        for i in range(n_layers):
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                out_channels = a.ndf * min(2**(i+1), 8)
                stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                convolved = discrim_conv(layers[-1], out_channels, stride=stride)
                normalized = batchnorm(convolved)
                rectified = lrelu(normalized, 0.2)
                layers.append(rectified)

        # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            convolved = discrim_conv(rectified, out_channels=1, stride=1)
            output = tf.sigmoid(convolved)
            layers.append(output)

        return layers[-1]

    with tf.variable_scope("generator"):
        out_channels = int(targets.get_shape()[-1])
        outputs = create_generator(inputs, out_channels)

    # create two copies of discriminator, one for real pairs and one for fake pairs
    # they share the same underlying variables
    with tf.name_scope("real_discriminator"):
        with tf.variable_scope("discriminator"):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_real = create_discriminator(inputs, targets)

    with tf.name_scope("fake_discriminator"):
        with tf.variable_scope("discriminator", reuse=True):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_fake = create_discriminator(inputs, outputs)

    with tf.name_scope("discriminator_loss"):
        # minimizing -tf.log will try to get inputs to 1
        # predict_real => 1
        # predict_fake => 0
        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))

    with tf.name_scope("generator_loss"):
        # predict_fake => 1
        # abs(targets - outputs) => 0
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))  # tf.nn.l2_loss(targets - outputs)
        gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

    with tf.name_scope("discriminator_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
        discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        with tf.control_dependencies([discrim_train]):
            gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
            gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step+1)

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=tf.group(update_losses, incr_global_step, gen_train),
    )


def save_images(fetches, step=None):
    image_dir = os.path.join(a.output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    filesets = []
    for i, in_path in enumerate(fetches["paths"]):
        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
        fileset = {"name": name, "step": step}
        for kind in ["inputs", "outputs", "targets"]:
            filename = name + "-" + kind + ".png"
            if step is not None:
                filename = "%08d-%s" % (step, filename)
            fileset[kind] = filename
            out_path = os.path.join(image_dir, filename)
            contents = fetches[kind][i]
            with open(out_path, "wb") as f:
                f.write(contents)
        filesets.append(fileset)
    return filesets


def append_index(filesets, step=False):
    index_path = os.path.join(a.output_dir, "index.html")
    if os.path.exists(index_path):
        index = open(index_path, "a")
    else:
        index = open(index_path, "w")
        index.write("<html><body><table><tr>")
        if step:
            index.write("<th>step</th>")
        index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")

    for fileset in filesets:
        index.write("<tr>")

        if step:
            index.write("<td>%d</td>" % fileset["step"])
        index.write("<td>%s</td>" % fileset["name"])

        for kind in ["inputs", "outputs", "targets"]:
            index.write("<td><img src='images/%s'></td>" % fileset[kind])

        index.write("</tr>")
    return index_path


def main():
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    if a.mode == "test" or a.mode == "export":
        if a.checkpoint is None:
            raise Exception("checkpoint required for test mode")

        # load some options from the checkpoint
        options = {"which_direction", "ngf", "ndf", "lab_colorization"}
        with open(os.path.join(a.checkpoint, "options.json")) as f:
            for key, val in json.loads(f.read()).items():
                if key in options:
                    print("loaded", key, "=", val)
                    setattr(a, key, val)
        # disable these features in test mode
        a.scale_size = CROP_SIZE
        a.flip = False

    for k, v in a._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    if a.mode == "export":
        # export the generator to a meta graph that can be imported later for standalone generation
        if a.lab_colorization:
            raise Exception("export not supported for lab_colorization")

        input = tf.placeholder(tf.string, shape=[1])
        input_data = tf.decode_base64(input[0])
        input_image = tf.image.decode_png(input_data)

        # remove alpha channel if present
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:,:,:3], lambda: input_image)
        # convert grayscale to RGB
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image)

        input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
        input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
        batch_input = tf.expand_dims(input_image, axis=0)

        with tf.variable_scope("generator"):
            batch_output = deprocess(create_generator(preprocess(batch_input), 3))

        output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
        if a.output_filetype == "png":
            output_data = tf.image.encode_png(output_image)
        elif a.output_filetype == "jpeg":
            output_data = tf.image.encode_jpeg(output_image, quality=80)
        else:
            raise Exception("invalid filetype")
        output = tf.convert_to_tensor([tf.encode_base64(output_data)])

        key = tf.placeholder(tf.string, shape=[1])
        inputs = {
            "key": key.name,
            "input": input.name
        }
        tf.add_to_collection("inputs", json.dumps(inputs))
        outputs = {
            "key": tf.identity(key).name,
            "output": output.name,
        }
        tf.add_to_collection("outputs", json.dumps(outputs))

        init_op = tf.global_variables_initializer()
        restore_saver = tf.train.Saver()
        export_saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init_op)
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            restore_saver.restore(sess, checkpoint)
            print("exporting model")
            export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
            export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=False)

        return

    examples = load_examples()
    print("examples count = %d" % examples.count)

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(examples.inputs, examples.targets)

    # undo colorization splitting on images that we use for display/output
    if a.lab_colorization:
        if a.which_direction == "AtoB":
            # inputs is brightness, this will be handled fine as a grayscale image
            # need to augment targets and outputs with brightness
            targets = augment(examples.targets, examples.inputs)
            outputs = augment(model.outputs, examples.inputs)
            # inputs can be deprocessed normally and handled as if they are single channel
            # grayscale images
            inputs = deprocess(examples.inputs)
        elif a.which_direction == "BtoA":
            # inputs will be color channels only, get brightness from targets
            inputs = augment(examples.inputs, examples.targets)
            targets = deprocess(examples.targets)
            outputs = deprocess(model.outputs)
        else:
            raise Exception("invalid direction")
    else:
        inputs = deprocess(examples.inputs)
        targets = deprocess(examples.targets)
        outputs = deprocess(model.outputs)

    def convert(image):
        if a.aspect_ratio != 1.0:
            # upscale to correct aspect ratio
            size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
            image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)

        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
converted_outputs = convert(outputs) 670 | 671 | with tf.name_scope("encode_images"): 672 | display_fetches = { 673 | "paths": examples.paths, 674 | "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"), 675 | "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"), 676 | "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"), 677 | } 678 | 679 | # summaries 680 | with tf.name_scope("inputs_summary"): 681 | tf.summary.image("inputs", converted_inputs) 682 | 683 | with tf.name_scope("targets_summary"): 684 | tf.summary.image("targets", converted_targets) 685 | 686 | with tf.name_scope("outputs_summary"): 687 | tf.summary.image("outputs", converted_outputs) 688 | 689 | with tf.name_scope("predict_real_summary"): 690 | tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8)) 691 | 692 | with tf.name_scope("predict_fake_summary"): 693 | tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8)) 694 | 695 | tf.summary.scalar("discriminator_loss", model.discrim_loss) 696 | tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN) 697 | tf.summary.scalar("generator_loss_L1", model.gen_loss_L1) 698 | 699 | for var in tf.trainable_variables(): 700 | tf.summary.histogram(var.op.name + "/values", var) 701 | 702 | for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars: 703 | tf.summary.histogram(var.op.name + "/gradients", grad) 704 | 705 | with tf.name_scope("parameter_count"): 706 | parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]) 707 | 708 | saver = tf.train.Saver(max_to_keep=1) 709 | 710 | logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None 711 | sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None) 712 | with sv.managed_session() as sess: 713 | 
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            saver.restore(sess, checkpoint)

        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = examples.steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        if a.mode == "test":
            # testing
            # at most, process the test data once
            start = time.time()
            max_steps = min(examples.steps_per_epoch, max_steps)
            for step in range(max_steps):
                results = sess.run(display_fetches)
                filesets = save_images(results)
                for i, f in enumerate(filesets):
                    print("evaluated image", f["name"])
                index_path = append_index(filesets)
            print("wrote index at", index_path)
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            for step in range(max_steps):
                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

                options = None
                run_metadata = None
                if should(a.trace_freq):
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq):
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1

                if should(a.summary_freq):
                    fetches["summary"] = sv.summary_op

                if should(a.display_freq):
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata)

                if should(a.summary_freq):
                    print("recording summary")
                    sv.summary_writer.add_summary(results["summary"], results["global_step"])

                if should(a.display_freq):
                    print("saving display images")
                    filesets = save_images(results["display"], step=results["global_step"])
                    append_index(filesets, step=True)

                if should(a.trace_freq):
                    print("recording trace")
                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
                    train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * a.batch_size / rate
                    print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])

                if should(a.save_freq):
                    print("saving model")
                    saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)

                if sv.should_stop():
                    break


main()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
keras==2.1.3
webvtt-py==0.4.1
tensorflow==1.15.0
tensorboard==1.15.0
python-speech-features==0.6
dlib==19.9.0
opencv-contrib-python==3.4.2.16
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------

from keras.models import Sequential
from keras.layers import Dense, LSTM, Dropout, Embedding, Lambda, TimeDistributed
import keras.backend as K
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
import keras
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from tqdm import tqdm
import pickle as pkl
from keras.callbacks import TensorBoard
from time import time
import cv2
import scipy.io.wavfile as wav
from python_speech_features import logfbank
import subprocess
import argparse

#########################################################################################

parser = argparse.ArgumentParser()
parser.add_argument("--sf", help="path to wav file")
a = parser.parse_args()

key_audio = a.sf # '00003' # '00001-003' # 'karan' # '00002-002' # '00002-007' #
time_delay = 20
look_back = 50
n_epoch = 50
outputFolder = 'testing_output_images/'

#########################################################################################

cmd = 'rm -rf '+outputFolder + '&& mkdir ' + outputFolder
subprocess.call(cmd ,shell=True)

#########################################################################################

model = load_model('checkpoints/my_model.h5')

#########################################################################################

def subsample(y, fps_from = 100.0, fps_to = 29.97):
    factor = int(np.ceil(fps_from/fps_to))
    # Subsample the points
    new_y = np.zeros((int(y.shape[0]/factor), 20, 2)) #(timesteps, 20) = (500, 20x2)
    for idx in range(new_y.shape[0]):
        if not (idx*factor > y.shape[0]-1):
            # Get into (x, y) format
            new_y[idx, :, 0] = y[idx*factor, 0:20]
            new_y[idx, :, 1] = y[idx*factor, 20:]
        else:
            break
    # print('Subsampled y:', new_y.shape)
    new_y = [np.array(each) for each in new_y.tolist()]
    # print(len(new_y))
    return new_y

def drawLips(keypoints, new_img, c = (255, 255, 255), th = 1, show = False):

    keypoints = np.float32(keypoints)

    for i in range(48, 59):
        cv2.line(new_img, tuple(keypoints[i]), tuple(keypoints[i+1]), color=c, thickness=th)
    cv2.line(new_img, tuple(keypoints[48]), tuple(keypoints[59]), color=c, thickness=th)
    cv2.line(new_img, tuple(keypoints[48]), tuple(keypoints[60]), color=c, thickness=th)
    cv2.line(new_img, tuple(keypoints[54]), tuple(keypoints[64]), color=c, thickness=th)
    cv2.line(new_img, tuple(keypoints[67]), tuple(keypoints[60]), color=c, thickness=th)
    for i in range(60, 67):
        cv2.line(new_img, tuple(keypoints[i]), tuple(keypoints[i+1]), color=c, thickness=th)

    if (show == True):
        cv2.imshow('lol', new_img)
        cv2.waitKey(10000)

def getOriginalKeypoints(kp_features_mouth, N, tilt, mean):
    # Denormalize the points
    kp_dn = N * kp_features_mouth
    # Add the tilt
    x, y = kp_dn[:, 0], kp_dn[:, 1]
    c, s = np.cos(tilt), np.sin(tilt)
    x_dash, y_dash = x*c + y*s, -x*s + y*c
    kp_tilt = np.hstack((x_dash.reshape((-1,1)), y_dash.reshape((-1,1))))
    # Shift to the mean
    kp = kp_tilt + mean
    return kp

#########################################################################################

# Load the files
# with open('data/audio_kp/audio_kp1467_mel.pickle', 'rb') as pkl_file:
# audio_kp = pkl.load(pkl_file)
with open('data/pca/pkp1467.pickle', 'rb') as pkl_file:
    video_kp = pkl.load(pkl_file)
with open('data/pca/pca1467.pickle', 'rb') as pkl_file:
    pca = pkl.load(pkl_file)
# Get the original keypoints file
with open('data/a2key_data/kp_test.pickle', 'rb') as pkl_file:
    kp = pkl.load(pkl_file)

# Get the data
X, y = [], [] # Create the empty lists
# audio = audio_kp[key_audio]
video = video_kp['00001-000']


# Get audio features
(rate, sig) = wav.read(key_audio)
audio = logfbank(sig,rate)


# if (len(audio) > len(video)):
# audio = audio[0:len(video)]
# else:
# video = video[0:len(audio)]
start = (time_delay-look_back) if (time_delay-look_back > 0) else 0
for i in range(start, len(audio)-look_back):
    a = np.array(audio[i:i+look_back])
    # v = np.array(video[i+look_back-time_delay]).reshape((1, -1))
    X.append(a)
    # y.append(v)

for i in range(start, len(video)-look_back):
    v = np.array(video[i+look_back-time_delay]).reshape((1, -1))
    y.append(v)

X = np.array(X)
y = np.array(y)
shapeX = X.shape
shapey = y.shape
print('Shapes:', X.shape)
X = X.reshape(-1, X.shape[2])
y = y.reshape(-1, y.shape[2])
print('Shapes:', X.shape)

scalerX = MinMaxScaler(feature_range=(0, 1))
scalery = MinMaxScaler(feature_range=(0, 1))

X = scalerX.fit_transform(X)
y = scalery.fit_transform(y)


X = X.reshape(shapeX)
# y = y.reshape(shapey[0], shapey[2])

# print('Shapes:', X.shape, y.shape)
# print('X mean:', np.mean(X), 'X var:', np.var(X))
# print('y mean:', np.mean(y), 'y var:', np.var(y))

y_pred = model.predict(X)

# Scale it up
y_pred = scalery.inverse_transform(y_pred)
# y = scalery.inverse_transform(y)

y_pred = pca.inverse_transform(y_pred)
# y = pca.inverse_transform(y)

print('Upsampled number:', len(y_pred))

y_pred = subsample(y_pred, 100, 34)

# y = subsample(y, 100, 100)

# error = np.mean(np.square(np.array(y_pred) - np.array(y)))

# print('Error:', error)

print('Subsampled number:', len(y_pred))

# Visualization
# Cut the other stream according to whichever is smaller
if (len(kp) < len(y_pred)):
    n = len(kp)
    y_pred = y_pred[:n]
else:
    n = len(y_pred)
    kp = kp[:n]


for idx, (x, k) in enumerate(zip(y_pred, kp)):

    unit_mouth_kp, N, tilt, mean, unit_kp, keypoints = k[0], k[1], k[2], k[3], k[4], k[5]
    kps = getOriginalKeypoints(x, N, tilt, mean)
    keypoints[48:68] = kps

    imgfile = 'data/a2key_data/images/' + str(idx+1).rjust(5, '0') + '.png'
    im = cv2.imread(imgfile)
    drawLips(keypoints, im, c = (255, 255, 255), th = 1, show = False)

    # make it pix2pix style
    im_out = np.zeros_like(im)
    im1 = np.hstack((im, im_out))
    # print('Shape: ', im1.shape)
    cv2.imwrite(outputFolder + str(idx) + '.png', im1)

print('Done writing', n, 'images')

# cmd = 'rm -rf input0.mp4 && rm -rf output.mp4'
# subprocess.call(cmd ,shell=True)

# cmd = 'ffmpeg -r 30 -f image2 -s 256x256 -i output_images/%d.png -vcodec libx264 -crf 25 input0.mp4 && ffmpeg -i input0.mp4 -i data/audios/00003.wav -c:v copy -c:a aac -strict experimental output.mp4 && rm -rf output_images/*.png'
# subprocess.call(cmd ,shell=True)

# cmd = 'rm -rf input0.mp4'
# subprocess.call(cmd ,shell=True)

# cmd = 'rm -rf output_images'
# subprocess.call(cmd ,shell=True)
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
SF="$1"

rm -rf input.mp4
rm -rf output.mp4

python3 run.py --sf "$SF" &&
python3 pix2pix.py --mode test --output_dir test_output/ --input_dir testing_output_images/ --checkpoint checkpoints/output/ &&

ffmpeg -r 30 -f image2 -s 256x256 -i test_output/images/%d-outputs.png -vcodec libx264 -crf 25 output0.mp4 &&
ffmpeg -r 30 -f image2 -s 256x256 -i test_output/images/%d-inputs.png -vcodec libx264 -crf 25 input0.mp4 &&
# ffmpeg -i output0.mp4 -i audio_testing.wav -c:v copy -c:a aac -strict experimental output.mp4
# audios/00001-000.wav
ffmpeg -i "$SF" output_audio_trim.wav &&
ffmpeg -i output0.mp4 -i output_audio_trim.wav -c:v copy -c:a aac output.mp4 &&
ffmpeg -i input0.mp4 -i output_audio_trim.wav -c:v copy -c:a aac input.mp4 &&

rm -rf testing_output_images
rm -rf test_output
rm -rf output0.mp4
rm -rf input0.mp4
rm -rf output_audio_trim.wav
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from keras.models import Sequential
from keras.layers import Dense, LSTM, Dropout, Embedding, Lambda, TimeDistributed
import keras.backend as K
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
import keras
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from tqdm import tqdm
import pickle as pkl
from keras.callbacks import TensorBoard
from time import time

#########################################################################################

time_delay = 20 #0
look_back = 50
n_epoch = 50
n_videos = 50
tbCallback = TensorBoard(log_dir="logs/{}".format(time())) # TensorBoard(log_dir='./Graph', histogram_freq=0, batch_size=n_batch, write_graph=True, write_images=True)

#########################################################################################

# Load the files
with open('data/audio_kp/audio_kp1467_mel.pickle', 'rb') as pkl_file:
    audio_kp = pkl.load(pkl_file)
with open('data/pca/pkp1467.pickle', 'rb') as pkl_file:
    video_kp = pkl.load(pkl_file)
with open('data/pca/pca1467.pickle', 'rb') as pkl_file:
    pca = pkl.load(pkl_file)


# Get the data

X, y = [], [] # Create the empty lists
# Get the common keys
keys_audio = audio_kp.keys()
keys_video = video_kp.keys()
keys = sorted(list(set(keys_audio).intersection(set(keys_video))))
# print('Length of common keys:', len(keys), 'First common key:', keys[0])

# X = np.array(X).reshape((-1, 26))
# y = np.array(y).reshape((-1, 8))

for key in tqdm(keys[0:n_videos]):
    audio = audio_kp[key]
    video = video_kp[key]
    if (len(audio) > len(video)):
        audio = audio[0:len(video)]
    else:
        video = video[0:len(audio)]
    start = (time_delay-look_back) if (time_delay-look_back > 0) else 0
    for i in range(start, len(audio)-look_back):
        a = np.array(audio[i:i+look_back])
        v = np.array(video[i+look_back-time_delay]).reshape((1, -1))
        X.append(a)
        y.append(v)

X = np.array(X)
y = np.array(y)
shapeX = X.shape
shapey = y.shape
print('Shapes:', X.shape, y.shape)
X = X.reshape(-1, X.shape[2])
y = y.reshape(-1, y.shape[2])
print('Shapes:', X.shape, y.shape)

scalerX = MinMaxScaler(feature_range=(0, 1))
scalery = MinMaxScaler(feature_range=(0, 1))

X = scalerX.fit_transform(X)
y = scalery.fit_transform(y)


X = X.reshape(shapeX)
y = y.reshape(shapey[0], shapey[2])

print('Shapes:', X.shape, y.shape)
print('X mean:', np.mean(X), 'X var:', np.var(X))
print('y mean:', np.mean(y), 'y var:', np.var(y))

split1 = int(0.8*X.shape[0])
split2 = int(0.9*X.shape[0])

train_X = X[0:split1]
train_y = y[0:split1]
val_X = X[split1:split2]
val_y = y[split1:split2]
test_X = X[split2:]
test_y = y[split2:]


# Initialize the model

model = Sequential()
model.add(LSTM(25, input_shape=(look_back, 26)))
model.add(Dropout(0.25))
model.add(Dense(8))
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.summary())
# model = load_model('my_model.h5')

# train LSTM with validation data
for i in tqdm(range(n_epoch)):
    print('Epoch', (i+1), '/', n_epoch, ' - ', int(100*(i+1)/n_epoch))
    model.fit(train_X, train_y, epochs=1, batch_size=1,
        verbose=1, shuffle=True, callbacks=[tbCallback], validation_data=(val_X, val_y))
    # model.reset_states()
    test_error = np.mean(np.square(test_y - model.predict(test_X)))
    # model.reset_states()
    print('Test Error: ', test_error)

# Save the model
model.save('my_model.h5')
model.save_weights('my_model_weights.h5')
print('Saved Model.')


# X, num = audioToPrediction('audios/' + key_audio + '.wav')
# y = model.predict(X, batch_size=n_batch)
# y = y.reshape(y.shape[0]*y.shape[1], y.shape[2])
# print('y:', y[0:num].shape)
--------------------------------------------------------------------------------
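The sliding-window pairing that `train.py` uses to build its training set can be sketched as a standalone NumPy function. This is a minimal sketch, assuming `audio` is a `(frames, 26)` array of log filterbank features and `video` a `(frames, 8)` array of PCA keypoint coefficients, which matches the `LSTM(25, input_shape=(look_back, 26))` input and `Dense(8)` output layers above; `make_windows` is a hypothetical helper name, not a function in the repository.

```python
import numpy as np


def make_windows(audio, video, look_back=50, time_delay=20):
    """Pair windows of audio features with single keypoint frames.

    Hypothetical helper mirroring the loop in train.py: the window that
    starts at frame i covers audio[i:i + look_back] and is paired with
    video[i + look_back - time_delay].
    """
    # Trim both streams to the shorter one, as train.py does per video.
    n = min(len(audio), len(video))
    audio = np.asarray(audio)[:n]
    video = np.asarray(video)[:n]

    # Skip leading frames when the delay would otherwise index before 0.
    start = max(time_delay - look_back, 0)

    X, y = [], []
    for i in range(start, n - look_back):
        X.append(audio[i:i + look_back])
        y.append(video[i + look_back - time_delay])
    return np.array(X), np.array(y)


# Example with dummy data: 26 filterbank features, 8 PCA coefficients.
audio = np.random.rand(100, 26)
video = np.random.rand(90, 8)
X, y = make_windows(audio, video)
print(X.shape, y.shape)  # (40, 50, 26) (40, 8)
```

Unlike the original loop, which appends `(1, 8)` arrays and reshapes them afterwards, the sketch returns `y` as `(windows, 8)` directly. With `look_back=50` and `time_delay=20`, each audio window ends 19 frames after the keypoint frame it is paired with, so the LSTM gets some look-ahead into the upcoming audio.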
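`getOriginalKeypoints()` in `run.py` undoes a per-frame normalization of the mouth keypoints: it scales them by `N`, rotates them with `x' = x*cos(tilt) + y*sin(tilt)` and `y' = -x*sin(tilt) + y*cos(tilt)`, then shifts them by `mean`. The same transform can be written as a single matrix product, which also makes its inverse easy to state. The `normalize()` inverse below is our own derivation for illustration; the preprocessing script that produced `N`, `tilt` and `mean` is not part of this listing, so treat it as an assumption about how those values were computed.

```python
import numpy as np


def rotation(tilt):
    # Row-vector form: [x, y] @ rotation(t) == [x*cos + y*sin, -x*sin + y*cos],
    # the same expressions getOriginalKeypoints() uses in run.py.
    c, s = np.cos(tilt), np.sin(tilt)
    return np.array([[c, -s],
                     [s, c]])


def denormalize(kp_unit, N, tilt, mean):
    """Vectorised form of getOriginalKeypoints(): scale, rotate, shift."""
    return (N * np.asarray(kp_unit, dtype=float)) @ rotation(tilt) + mean


def normalize(kp, N, tilt, mean):
    """Assumed inverse of denormalize(); not taken from the repository."""
    # The rotation matrix is orthogonal, so its inverse is its transpose.
    return ((np.asarray(kp, dtype=float) - mean) @ rotation(tilt).T) / N


# Round trip on 20 random mouth points (indices 48-67 in dlib's 68-point layout).
unit = np.random.RandomState(0).normal(size=(20, 2))
mean = np.array([128.0, 160.0])
pixels = denormalize(unit, N=40.0, tilt=0.1, mean=mean)
print(np.allclose(normalize(pixels, 40.0, 0.1, mean), unit))  # True
```

The `(20, 2)` output shape matches the `keypoints[48:68]` slice that `run.py` overwrites before drawing the lips.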