├── LICENSE
├── README.md
├── create_data.py
├── dancegen.ipynb
├── demo.gif
├── demo2.gif
├── floyd.yml
├── gen_lv.py
├── mdn.py
├── model.py
└── video_from_lv.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Jaison Saji

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DanceNet - Dance generator using a Variational Autoencoder, LSTM and Mixture Density Network (Keras)

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/jsn5/dancenet/blob/master/LICENSE) [![Run on FloydHub](https://static.floydhub.com/button/button-small.svg)](https://floydhub.com/run)
[![DOI](https://zenodo.org/badge/143685321.svg)](https://zenodo.org/badge/latestdoi/143685321)

![](https://github.com/jsn5/dancenet/blob/master/demo.gif) ![](https://github.com/jsn5/dancenet/blob/master/demo2.gif)

## Main components:

* Variational autoencoder
* LSTM + Mixture Density Layer

## Requirements:

* Python version = 3.5.2

### Packages
* keras==2.2.0
* scikit-learn==0.19.1
* numpy==1.14.3
* opencv-python==3.4.1

## Dataset

https://www.youtube.com/watch?v=NdSqAAT28v0
This is the video used for training.

## How to run locally

* Download the trained weights from [here](https://drive.google.com/file/d/1LWtERyPAzYeZjL816gBoLyQdC2MDK961/view?usp=sharing) and extract them to the dancenet directory.
* Run `dancegen.ipynb`.

## How to run in your browser

[![Run on FloydHub](https://static.floydhub.com/button/button-small.svg)](https://floydhub.com/run)

* Click the button above to open this code in a FloydHub workspace (the [trained weights dataset](https://www.floydhub.com/whatrocks/datasets/dancenet-weights) will be automatically attached to the environment).
* Run `dancegen.ipynb`.

## Training from scratch

* Place dance-sequence frames named `1.jpg`, `2.jpg`, ... in the `imgs/` folder (`create_data.py` can extract them from a video).
* Run `model.py` to train the VAE.
* Run `gen_lv.py` to encode the images into latent vectors.
* Run `video_from_lv.py` to test the decoded video.
* Run the Jupyter notebook `dancegen.ipynb` to train DanceNet and generate new video (the generation loop is sketched below).
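
## How generation works

Generation proceeds frame by frame: the current 128-dimensional latent vector is fed to the LSTM + MDN model, the next latent vector is sampled from the predicted mixture, and the VAE decoder turns it back into an image. A minimal sketch of that loop (mirroring `dancegen.ipynb`, assuming the trained `model`, `decoder`, encoded `data` and fitted `scaler` are already loaded as in the notebook):

```python
import numpy as np
import mdn  # this repo's MDN layer and sampling helpers

lv_in = data[0]  # seed with the first encoded (and scaled) frame
for _ in range(500):
    params = model.predict(np.array(lv_in).reshape(1, 128))      # mixture parameters for the next latent
    lv_out = mdn.sample_from_output(np.array(params).reshape(-1), 128, numComponents, temp=0.01)
    lv_out = scaler.inverse_transform(lv_out)                    # back to the original latent scale
    frame = decoder.predict(np.array(lv_out).reshape(1, 128))    # VAE decoder: latent -> 120x208 image
    lv_in = lv_out                                               # feed the sample back in
```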

## References

* [Does my AI have better dance moves than me?](https://www.youtube.com/watch?v=Sc7RiNgHHaE&t=9s) by Cary Huang
* [Generative Choreography using Deep Learning (Chor-RNN)](https://arxiv.org/abs/1605.06921)
* [Building Autoencoders in Keras](https://blog.keras.io/building-autoencoders-in-keras.html) by [Francois Chollet](https://twitter.com/fchollet)
* [Time Series Prediction with LSTM Recurrent Neural Networks in Python with Keras](https://machinelearningmastery.com/time-series-prediction-lstm-recurrent-neural-networks-python-keras/)
* [Mixture Density Networks](http://blog.otoro.net/2015/06/14/mixture-density-networks/) by [David Ha](https://twitter.com/hardmaru)
* [Mixture Density Layer for Keras](https://github.com/cpmpercussion/keras-mdn-layer) by [Charles Martin](https://github.com/cpmpercussion/)
--------------------------------------------------------------------------------
/create_data.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np

VIDEO_PATH = 'data.mkv'
cap = cv2.VideoCapture(VIDEO_PATH)

count = 1  # next output frame index (imgs/1.jpg, imgs/2.jpg, ...)
limit = 0  # frame-skip counter: keep roughly one frame in three
while cap.isOpened():
    ret, image_np = cap.read()
    if not ret:
        break
    if limit == 3:
        limit = 0
        image_np = cv2.resize(image_np, (208, 120))
        # White out everything except the darkest pixels (the dancer's silhouette).
        bg_index = np.where(np.greater(image_np, 20))
        image_np[bg_index] = 255
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY)
        cv2.imwrite("imgs/{}.jpg".format(count), image_np)
        print("{}.jpg".format(count))
        count += 1
    limit += 1
cap.release()
--------------------------------------------------------------------------------
/dancegen.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from model import vae,decoder\n",
    "from keras.layers import Dropout\n",
    "from keras.layers import LSTM\n",
    "from keras.layers import Dense\n",
    "from keras.layers import Reshape\n",
    "from keras.layers import Input\n",
    "from keras.models import Model\n",
    "from keras.optimizers import adam\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "import cv2\n",
    "import numpy as np\n",
    "import mdn\n",
    "from sklearn.preprocessing import MinMaxScaler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ENCODED_DATA_PATH = '/floyd/input/dancenetweights/data/lv.npy'\n",
    "VAE_PATH = '/floyd/input/dancenetweights/weights/vae_cnn.h5'\n",
    "DANCENET_PATH = '/floyd/input/dancenetweights/weights/gendance.h5'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load encoded data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data = np.load(ENCODED_DATA_PATH)\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Normalize data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.array(data).reshape(-1,128)\n",
    "scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "scaler = scaler.fit(data)\n",
    "data = scaler.transform(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numComponents = 24\n",
    "outputDim = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM + MDN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "inputs = Input(shape=(128,))\n",
    "x = Reshape((1,128))(inputs)\n",
    "x = LSTM(512, return_sequences=True)(x)\n",
    "x = Dropout(0.40)(x)\n",
    "x = LSTM(512, return_sequences=True)(x)\n",
    "x = Dropout(0.40)(x)\n",
    "x = LSTM(512)(x)\n",
    "x = Dropout(0.40)(x)\n",
    "x = Dense(1000,activation='relu')(x)\n",
    "outputs = mdn.MDN(outputDim, numComponents)(x)\n",
    "model = Model(inputs=inputs,outputs=outputs)\n",
    "print(model.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = adam(lr=0.0005)\n",
    "model.compile(loss=mdn.get_mixture_loss_func(outputDim,numComponents),optimizer=opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train = False  # change to True to train from scratch\n",
    "\n",
    "if train:\n",
    "    X = data[0:len(data)-1]\n",
    "    Y = data[1:len(data)]\n",
    "    checkpoint = ModelCheckpoint(DANCENET_PATH, monitor='loss', verbose=1, save_best_only=True, mode='auto')\n",
    "    callbacks_list = [checkpoint]\n",
    "    model.fit(X,Y,batch_size=1024, verbose=1, shuffle=False, validation_split=0.20, epochs=10000, callbacks=callbacks_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.load_weights(VAE_PATH)\n",
    "model.load_weights(DANCENET_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
    "video = cv2.VideoWriter(\"out.mp4\", fourcc, 30.0, (208, 120))\n",
    "lv_in = data[0]\n",
    "\n",
    "for i in range(500):\n",
    "    lv_input = np.array(lv_in).reshape(1,128)\n",
    "    lv_out = model.predict(lv_input)\n",
    "    shape = np.array(lv_out).shape[1]\n",
    "    lv_out = np.array(lv_out).reshape(shape)\n",
    "    lv_out = mdn.sample_from_output(lv_out,128,numComponents,temp=0.01)\n",
    "    lv_out = scaler.inverse_transform(lv_out)\n",
    "    img = decoder.predict(np.array(lv_out).reshape(1,128))\n",
    "    img = np.array(img).reshape(120,208,1)\n",
    "    img = img * 255\n",
    "    img = np.array(img).astype(\"uint8\")\n",
    "    img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)\n",
    "    lv_in = lv_out\n",
    "    video.write(img)\n",
    "video.release()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsn5/dancenet/fbf0b0aee93cc9c7401767eac4195714376a710f/demo.gif
--------------------------------------------------------------------------------
/demo2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsn5/dancenet/fbf0b0aee93cc9c7401767eac4195714376a710f/demo2.gif
--------------------------------------------------------------------------------
/floyd.yml:
--------------------------------------------------------------------------------
env: pytorch-0.4
machine: cpu
input:
  - source: whatrocks/datasets/dancenet-weights/1
    destination: dancenetweights
--------------------------------------------------------------------------------
/gen_lv.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import cv2
from model import encoder, vae

vae.load_weights('vae_cnn.h5')
lv_array = []
limit = len(os.listdir('imgs'))  # frames are named 1.jpg ... limit.jpg

for i in range(1, limit + 1):
    img = cv2.imread('imgs/{}.jpg'.format(i), cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (208, 120))
    data_np = np.array(img) / 255
    data_np = data_np.reshape(1, 120, 208, 1)
    lv = encoder.predict(data_np)[2]  # encoder returns [z_mean, z_log_var, z]; keep the sampled z
    lv_array.append(lv)

print(np.array(lv_array).shape)
np.save("lv.npy", np.array(lv_array))
--------------------------------------------------------------------------------
/mdn.py:
--------------------------------------------------------------------------------
"""
A Mixture Density Layer for Keras
cpmpercussion: Charles Martin (University of Oslo) 2018
https://github.com/cpmpercussion/keras-mdn-layer

Hat tip to [Omimo's Keras MDN layer](https://github.com/omimo/Keras-MDN) for a starting point for this code.
"""
import keras
from keras import backend as K
from keras.layers import Dense
from keras.engine.topology import Layer
import numpy as np
from tensorflow.contrib.distributions import Categorical, Mixture, MultivariateNormalDiag
import tensorflow as tf


def elu_plus_one_plus_epsilon(x):
    """ELU activation with a very small addition to help prevent NaN in loss."""
    return (K.elu(x) + 1 + 1e-8)

class MDN(Layer):
    """A Mixture Density Network Layer for Keras.

    This layer has a few tricks to avoid NaNs in the loss function when training:
    - Activation for variances is ELU + 1 + 1e-8 (to avoid very small values)
    - Mixture weights (pi) are trained as logits, not in the softmax space.

    A loss function needs to be constructed with the same output dimension and number of mixtures.
    A sampling function is also provided to sample from the distribution parametrised by the MDN outputs.
    """

    def __init__(self, output_dimension, num_mixtures, **kwargs):
        self.output_dim = output_dimension
        self.num_mix = num_mixtures
        with tf.name_scope('MDN'):
            self.mdn_mus = Dense(self.num_mix * self.output_dim, name='mdn_mus')  # mix*output vals, no activation
            self.mdn_sigmas = Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas')  # mix*output vals, ELU+1 activation
            self.mdn_pi = Dense(self.num_mix, name='mdn_pi')  # mix vals, logits
        super(MDN, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mdn_mus.build(input_shape)
        self.mdn_sigmas.build(input_shape)
        self.mdn_pi.build(input_shape)
        self.trainable_weights = self.mdn_mus.trainable_weights + self.mdn_sigmas.trainable_weights + self.mdn_pi.trainable_weights
        self.non_trainable_weights = self.mdn_mus.non_trainable_weights + self.mdn_sigmas.non_trainable_weights + self.mdn_pi.non_trainable_weights
        super(MDN, self).build(input_shape)

    def call(self, x, mask=None):
        with tf.name_scope('MDN'):
            mdn_out = keras.layers.concatenate([self.mdn_mus(x),
                                                self.mdn_sigmas(x),
                                                self.mdn_pi(x)],
                                               name='mdn_outputs')
        return mdn_out

    def compute_output_shape(self, input_shape):
        # The layer emits mus and sigmas for every mixture plus one logit per mixture.
        return (input_shape[0], 2 * self.output_dim * self.num_mix + self.num_mix)

    def get_config(self):
        config = {
            "output_dimension": self.output_dim,
            "num_mixtures": self.num_mix
        }
        base_config = super(MDN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def get_mixture_loss_func(output_dim, num_mixes):
    """Construct a loss function for the MDN layer parametrised by number of mixtures."""
    # Construct a loss function with the right number of mixtures and outputs
    def loss_func(y_true, y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        loss = mixture.log_prob(y_true)
        loss = tf.negative(loss)
        loss = tf.reduce_mean(loss)
        return loss

    # Actually return the loss function
    with tf.name_scope('MDN'):
        return loss_func
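
# Example usage (a minimal sketch, not part of the original library): wire the
# MDN layer into a Keras model and compile it with a matching mixture loss.
# `OUTPUT_DIM` and `N_MIX` are placeholder names; the same two numbers must be
# passed to both the layer and get_mixture_loss_func.
#
#     from keras.models import Sequential
#     OUTPUT_DIM, N_MIX = 128, 24
#     m = Sequential()
#     m.add(Dense(512, activation='relu', input_shape=(OUTPUT_DIM,)))
#     m.add(MDN(OUTPUT_DIM, N_MIX))
#     m.compile(loss=get_mixture_loss_func(OUTPUT_DIM, N_MIX), optimizer='adam')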

def get_mixture_sampling_fun(output_dim, num_mixes):
    """Construct a sampling function for the MDN layer parametrised by mixtures and output dimension."""
    # Construct a sampling function with the right number of mixtures and outputs
    def sampling_func(y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        samp = mixture.sample()
        # Todo: temperature adjustment for sampling function.
        return samp

    # Actually return the sampling function
    with tf.name_scope('MDNLayer'):
        return sampling_func


def get_mixture_mse_accuracy(output_dim, num_mixes):
    """Construct an MSE accuracy function for the MDN layer
    that takes one sample and compares to the true value."""
    # Construct an MSE function with the right number of mixtures and outputs
    def mse_func(y_true, y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        samp = mixture.sample()
        mse = tf.reduce_mean(tf.square(samp - y_true), axis=-1)
        # Todo: temperature adjustment for sampling function.
        return mse

    # Actually return the MSE function
    with tf.name_scope('MDNLayer'):
        return mse_func


def split_mixture_params(params, output_dim, num_mixes):
    """Splits up an array of mixture parameters into mus, sigmas, and pis
    depending on the number of mixtures and output dimension."""
    mus = params[:num_mixes*output_dim]
    sigs = params[num_mixes*output_dim:2*num_mixes*output_dim]
    pi_logits = params[-num_mixes:]
    return mus, sigs, pi_logits


def softmax(w, t=1.0):
    """Softmax function for a list or numpy array of logits. Also adjusts temperature."""
    e = np.array(w) / t  # adjust temperature
    e -= e.max()  # subtract max to protect from exploding exp values.
    e = np.exp(e)
    dist = e / np.sum(e)
    return dist

def sample_from_categorical(dist):
    """Samples from a categorical model PDF."""
    r = np.random.rand(1)  # uniform random number in [0, 1]
    accumulate = 0
    for i in range(0, dist.size):
        accumulate += dist[i]
        if accumulate >= r:
            return i
    tf.logging.info('Error sampling mixture model.')
    return -1


def sample_from_output(params, output_dim, num_mixes, temp=1.0):
    """Sample from an MDN output with temperature adjustment."""
    mus = params[:num_mixes*output_dim]
    sigs = params[num_mixes*output_dim:2*num_mixes*output_dim]
    pis = softmax(params[-num_mixes:], t=temp)
    m = sample_from_categorical(pis)
    # Alternative way to sample from categorical:
    # m = np.random.choice(range(len(pis)), p=pis)
    mus_vector = mus[m*output_dim:(m+1)*output_dim]
    sig_vector = sigs[m*output_dim:(m+1)*output_dim] * temp  # adjust for temperature
    cov_matrix = np.identity(output_dim) * sig_vector
    sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
    return sample
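
# Example (a minimal sketch, assuming a trained MDN model `m` as in the sketch
# above): `params` is the flat parameter vector the layer emits for one input,
# and lower temperatures make the sampling more conservative.
#
#     params = m.predict(x_in)[0]
#     y_next = sample_from_output(params, OUTPUT_DIM, N_MIX, temp=0.5)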
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import Reshape
from keras.layers import Input
from keras.layers import UpSampling2D
from keras.layers import Lambda
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.callbacks import ModelCheckpoint
from keras import backend as K
import numpy as np
import cv2
import os

latent_dim = 128

def sampling(args):
    """Reparameterization trick: sample z = mean + sigma * epsilon."""
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


# Encoder: 120x208 grayscale frame -> 128-d latent distribution
input_img = Input(shape=(120, 208, 1))
x = Conv2D(filters=128, kernel_size=3, activation='relu', padding='same')(input_img)
x = MaxPooling2D(pool_size=2)(x)
x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=2)(x)
x = Conv2D(filters=32, kernel_size=3, activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=2)(x)
shape = K.int_shape(x)
x = Flatten()(x)
x = Dense(128, kernel_initializer='glorot_uniform')(x)

z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)
z = Lambda(sampling, output_shape=(latent_dim,), name="z")([z_mean, z_log_var])

encoder = Model(input_img, [z_mean, z_log_var, z], name="encoder")
encoder.summary()


# Decoder: 128-d latent vector -> reconstructed 120x208 frame
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], kernel_initializer='glorot_uniform', activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
x = Dense(128, kernel_initializer='glorot_uniform')(x)
x = Conv2D(filters=32, kernel_size=3, activation='relu', padding='same')(x)
x = UpSampling2D(size=(2, 2))(x)
x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(x)
x = UpSampling2D(size=(2, 2))(x)
x = Conv2D(filters=128, kernel_size=3, activation='relu', padding='same')(x)
x = UpSampling2D(size=(2, 2))(x)
x = Conv2D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(x)

decoder = Model(latent_inputs, x, name='decoder')
decoder.summary()


outputs = decoder(encoder(input_img)[2])
print(outputs.shape)
vae = Model(input_img, outputs, name="vae")

def data_generator(batch_size, limit):
    batch = []
    counter = 1
    while 1:
        for i in range(1, limit + 1):
            if counter > limit:  # wrap around after the last frame
                counter = 1
            img = cv2.imread("imgs/{}.jpg".format(counter), cv2.IMREAD_GRAYSCALE)
            img = img.reshape(120, 208, 1)
            batch.append(img)
            if len(batch) == batch_size:
                batch_np = np.array(batch) / 255
                batch = []
                yield (batch_np, None)
            counter += 1

if __name__ == '__main__':

    # VAE loss = pixel-wise reconstruction loss + KL divergence
    reconstruction_loss = binary_crossentropy(K.flatten(input_img), K.flatten(outputs))
    reconstruction_loss *= 120 * 208
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='rmsprop', loss=None)
    vae.summary()
    checkpoint = ModelCheckpoint('./weights/vae_cnn.h5', verbose=1, monitor='loss', save_best_only=True, mode='auto', period=1)
    callbacks = [checkpoint]
    batch_size = 100
    limit = len(os.listdir('imgs'))
    spe = int(limit / batch_size)
    print(limit, spe)
    vae.fit_generator(data_generator(batch_size, limit), epochs=50, steps_per_epoch=spe, callbacks=callbacks)
    vae.save_weights('vae_cnn.h5')
--------------------------------------------------------------------------------
/video_from_lv.py:
--------------------------------------------------------------------------------
import numpy as np
from model import decoder, vae
import cv2

vae.load_weights("vae_cnn.h5")
lv = np.load("lv.npy")
fourcc = cv2.VideoWriter_fourcc(*'XVID')
video = cv2.VideoWriter("output.avi", fourcc, 30.0, (208, 120))

for i in range(len(lv)):  # decode every stored latent vector
    data = lv[i].reshape(1, 128)
    img = decoder.predict(data)
    img = np.array(img).reshape(120, 208, 1)
    img = img * 255
    img = np.array(img).astype("uint8")
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    video.write(img)
video.release()
--------------------------------------------------------------------------------