├── LICENSE
├── MNIST_data
│   ├── t10k-images-idx3-ubyte.gz
│   ├── t10k-labels-idx1-ubyte.gz
│   ├── train-images-idx3-ubyte.gz
│   └── train-labels-idx1-ubyte.gz
├── README.md
├── example
│   └── mnist.gif
├── figures
│   ├── iter10.gif
│   ├── iter30.gif
│   ├── iter400.gif
│   ├── iter50.gif
│   └── model.png
└── gan.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2017 Jaesik Yoon
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/MNIST_data/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/MNIST_data/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/MNIST_data/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/MNIST_data/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/MNIST_data/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/MNIST_data/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/MNIST_data/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/MNIST_data/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## Sequential Data GAN
 2 | -----
 3 | 
 4 | This module generates sequential data with a GAN built from LSTMs.
 5 | 
 6 | The basic structure follows ckmarkoh's GAN-tensorflow repository: https://github.com/ckmarkoh/GAN-tensorflow
 7 | 
 8 | ### Used Data
 9 | To build sequential data, MNIST digits 0-3 are used, arranged as 0->1->2->3 sequences.
10 | 
11 | ### Modeling
12 | Each module, the generator and the discriminator, consists of a 2-layer LSTM followed by a 1-layer fully connected network.
13 | The generator is a one-to-many model: it takes one random vector as input and generates a sequence of images.
14 | The discriminator is a many-to-one model: it takes a sequence of images and decides whether the sequence is real or fake.
15 | 
16 | ![alt tag](https://github.com/jaesik817/SequentialData-GAN/blob/master/figures/model.png)
17 | 
18 | ### Results
19 | Epoch 10
20 | ![alt tag](https://github.com/jaesik817/SequentialData-GAN/blob/master/figures/iter10.gif)
21 | 
22 | Epoch 30
23 | ![alt tag](https://github.com/jaesik817/SequentialData-GAN/blob/master/figures/iter30.gif)
24 | 
25 | Epoch 50
26 | ![alt tag](https://github.com/jaesik817/SequentialData-GAN/blob/master/figures/iter50.gif)
27 | 
28 | Epoch 400
29 | ![alt tag](https://github.com/jaesik817/SequentialData-GAN/blob/master/figures/iter400.gif)
30 | 
--------------------------------------------------------------------------------
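A minimal sketch (not part of this repository) of the one-to-many generator and many-to-one discriminator described in the README's Modeling section, assuming the same TensorFlow 1.x tf.contrib.rnn API that gan.py below uses; all names and sizes here (seq_len, z_dim, gen_sketch, dis_sketch, ...) are placeholders chosen for illustration:

```python
# Editor's sketch, not repository code: stacked-LSTM generator/discriminator wiring.
import tensorflow as tf

seq_len, batch, z_dim, img_dim, hidden = 4, 32, 100, 784, 300

def generator(z_seq):
    # z_seq: [batch, seq_len, z_dim] -> list of seq_len image frames, each [batch, img_dim]
    steps = tf.unstack(z_seq, seq_len, axis=1)
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(hidden) for _ in range(2)])
    with tf.variable_scope("gen_sketch"):
        outputs, _ = tf.contrib.rnn.static_rnn(cell, steps, dtype=tf.float32)
        w = tf.get_variable("w", [hidden, img_dim])
        b = tf.get_variable("b", [img_dim])
        # one generated frame per LSTM step (one-to-many)
        return [tf.nn.tanh(tf.matmul(o, w) + b) for o in outputs]

def discriminator(frames):
    # frames: list of seq_len tensors [batch, img_dim] -> one real/fake score per sequence
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(hidden) for _ in range(2)])
    with tf.variable_scope("dis_sketch"):
        outputs, _ = tf.contrib.rnn.static_rnn(cell, frames, dtype=tf.float32)
        w = tf.get_variable("w", [hidden, 1])
        b = tf.get_variable("b", [1])
        # only the last LSTM output is scored (many-to-one)
        return tf.nn.sigmoid(tf.matmul(outputs[-1], w) + b)

z = tf.placeholder(tf.float32, [batch, seq_len, z_dim])
fake_frames = generator(z)
score = discriminator(fake_frames)
```

The generator emits one image per LSTM step, while the discriminator reads the whole sequence and scores only its final LSTM output; gan.py below implements the same wiring with dropout wrappers and TensorBoard summaries added.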
/example/mnist.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/example/mnist.gif
--------------------------------------------------------------------------------
/figures/iter10.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/figures/iter10.gif
--------------------------------------------------------------------------------
/figures/iter30.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/figures/iter30.gif
--------------------------------------------------------------------------------
/figures/iter400.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/figures/iter400.gif
--------------------------------------------------------------------------------
/figures/iter50.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/figures/iter50.gif
--------------------------------------------------------------------------------
/figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/SequentialData-GAN/68f48c6b60d50cd3e50cb66b28d18ccf5e96284f/figures/model.png
--------------------------------------------------------------------------------
/gan.py:
--------------------------------------------------------------------------------
  1 | import tensorflow as tf
  2 | from tensorflow.examples.tutorials.mnist import input_data
  3 | import numpy as np
  4 | from skimage.io import imsave
  5 | import os
  6 | import shutil
  7 | import time
  8 | 
  9 | img_height = 28
 10 | img_width = 28
 11 | img_size = img_height * img_width
 12 | 
 13 | to_train = True
 14 | to_restore = False
 15 | output_path = "output"
 16 | 
 17 | max_epoch = 1000
 18 | 
 19 | h1_size = 150
 20 | h2_size = 300
 21 | z_size = 100
 22 | batch_size = 256
 23 | seq_size=4
 24 | n_hidden=300
 25 | tr_data_num=60000;
 26 | g_num_layers=2;
 27 | d_num_layers=2;
 28 | 
 29 | log_dir="/tmp/gan_seq/"+str(int(time.time()))
 30 | 
 31 | 
 32 | def variable_summaries(var):
 33 |   """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
 34 |   with tf.name_scope('summaries'):
 35 |     mean = tf.reduce_mean(var)
 36 |     tf.summary.scalar('mean', mean)
 37 |     with tf.name_scope('stddev'):
 38 |       stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
 39 |     tf.summary.scalar('stddev', stddev)
 40 |     tf.summary.scalar('max', tf.reduce_max(var))
 41 |     tf.summary.scalar('min', tf.reduce_min(var))
 42 |     #tf.summary.histogram('histogram', var)
 43 | 
 44 | def build_generator(z_prior,keep_prob):
 45 |   z_prior=tf.unstack(z_prior,seq_size,1);
 46 |   lstm_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(n_hidden), output_keep_prob=keep_prob) for _ in range(g_num_layers)]);
 47 |   with tf.variable_scope("gen") as gen:
 48 |     res, states = tf.contrib.rnn.static_rnn(lstm_cell, z_prior,dtype=tf.float32);
 49 |     weights=tf.Variable(tf.random_normal([n_hidden, img_size]));
 50 |     biases=tf.Variable(tf.random_normal([img_size]));
 51 |   for i in range(len(res)):
 52 |     res[i]=tf.nn.tanh(tf.matmul(res[i], weights) + biases);
 53 |   g_params=[v for v in tf.global_variables() if v.name.startswith(gen.name)];
 54 |   with tf.name_scope("gen_params"):
 55 |     for param in g_params:
 56 |       variable_summaries(param);
 57 |   return res,g_params;
 58 | 
 59 | def build_discriminator(x_data, x_generated, keep_prob):
 60 |   x_data=tf.unstack(x_data,seq_size,1);
 61 |   x_generated=list(x_generated);
 62 |   x_in = tf.concat([x_data, x_generated],1);
 63 |   x_in=tf.unstack(x_in,seq_size,0);
 64 |   lstm_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(n_hidden), output_keep_prob=keep_prob) for _ in range(d_num_layers)]);
 65 |   with tf.variable_scope("dis") as dis:
 66 |     weights=tf.Variable(tf.random_normal([n_hidden, 1]));
 67 |     biases=tf.Variable(tf.random_normal([1]));
 68 |     outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x_in, dtype=tf.float32);
 69 |     res=tf.matmul(outputs[-1], weights) + biases;
 70 |   y_data = tf.nn.sigmoid(tf.slice(res, [0, 0], [batch_size, -1], name=None));
 71 |   y_generated = tf.nn.sigmoid(tf.slice(res, [batch_size, 0], [-1, -1], name=None));
 72 |   d_params=[v for v in tf.global_variables() if v.name.startswith(dis.name)];
 73 |   with tf.name_scope("desc_params"):
 74 |     for param in d_params:
 75 |       variable_summaries(param);
 76 |   return y_data, y_generated, d_params;
 77 | 
 78 | def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
 79 |   batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
 80 |   img_h, img_w = batch_res.shape[1], batch_res.shape[2]
 81 |   grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
 82 |   grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
 83 |   img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
 84 |   for i, res in enumerate(batch_res):
 85 |     if i >= grid_size[0] * grid_size[1]:
 86 |       break
 87 |     img = (res) * 255
 88 |     img = img.astype(np.uint8)
 89 |     row = (i // grid_size[0]) * (img_h + grid_pad)
 90 |     col = (i % grid_size[1]) * (img_w + grid_pad)
 91 |     img_grid[row:row + img_h, col:col + img_w] = img
 92 |   imsave(fname, img_grid)
 93 | 
 94 | 
 95 | def train():
 96 |   mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
 97 |   total_tr_data, total_tr_label = mnist.train.next_batch(mnist.train._num_examples);
 98 |   total_tr_data=np.array(total_tr_data,dtype=float);
 99 |   total_tr_label=np.array(total_tr_label,dtype=float);
100 | 
101 |   tr_data=np.zeros((tr_data_num,seq_size,img_size),dtype=object);
102 |   for i in range(seq_size):
103 |     total_idx=np.where(total_tr_label[:,i]==1.0)[0];
104 |     while(len(total_idx)
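The excerpt of gan.py ends here, partway through train(). As a self-contained illustration of the data preparation the README describes (grouping MNIST digits 0-3 into 0->1->2->3 sequences), and not the repository's own code, the same idea can be sketched with NumPy alone; every name below (make_digit_sequences, num_seqs, the stand-in arrays) is assumed for the example:

```python
# Editor's sketch, not repository code: build 0->1->2->3 image sequences
# from MNIST-style arrays using NumPy only.
import numpy as np

def make_digit_sequences(images, labels, num_seqs=1000, seq_len=4):
    """images: [N, 784] floats; labels: [N, 10] one-hot. Returns [num_seqs, seq_len, 784]."""
    seqs = np.zeros((num_seqs, seq_len, images.shape[1]), dtype=np.float32)
    for digit in range(seq_len):                      # digits 0..3, one per time step
        idx = np.where(labels[:, digit] == 1.0)[0]    # indices of all images of this digit
        idx = np.resize(idx, num_seqs)                # cycle indices if there are too few
        seqs[:, digit, :] = images[idx]
    return seqs

# Usage with random stand-in data of MNIST shape:
rng = np.random.default_rng(0)
fake_images = rng.random((600, 784)).astype(np.float32)
fake_labels = np.eye(10, dtype=np.float32)[rng.integers(0, 10, 600)]
train_seqs = make_digit_sequences(fake_images, fake_labels, num_seqs=256)
print(train_seqs.shape)  # (256, 4, 784)
```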