├── .DS_Store ├── ._.DS_Store ├── .gitignore ├── README.md ├── architecture.jpeg ├── dictionary ├── id2Word.npy ├── vocab.npy └── word2Id.npy ├── model.py ├── train.py ├── train_samples └── train_799.png └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/.DS_Store -------------------------------------------------------------------------------- /._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/._.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.npy 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # text-to-image 2 | Implementation for the [Kaggle contest - Reverse Image Caption](https://www.kaggle.com/c/datalabcup-reverse-image-caption-ver2/leaderboard) 3 | ## Architecture 4 | 5 | 6 | Uses the `GAN-CLS` algorithm from the paper [Generative Adversarial Text-to-Image Synthesis](http://arxiv.org/abs/1605.05396) and `stackGAN-stage1` from [StackGAN - Github](https://github.com/hanzhanggit/StackGAN) 7 | 8 | ## Prepare Data 9 | Download the image files and captions from [Google Drive](https://drive.google.com/drive/folders/1aUJrBoIN3l9U5p5pNXT0NeNzlyBWF54u?usp=sharing) and put them into the `./text-to-image` directory 10 | 11 | ## Result (after 800 epochs) 12 | * the flower shown has yellow anther red pistil and bright red petals. 
13 | * this flower has petals that are yellow, white and purple and has dark lines 14 | * the petals on this flower are white with a yellow center 15 | * this flower has a lot of small round pink petals. 16 | * this flower is orange in color, and has petals that are ruffled and rounded. 17 | * the flower has yellow petals and the center of it is brown 18 | * this flower has petals that are blue and white. 19 | * these white flowers have petals that start off white in color and end in a white towards the tips. 20 | 21 | -------------------------------------------------------------------------------- /architecture.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/architecture.jpeg -------------------------------------------------------------------------------- /dictionary/id2Word.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/dictionary/id2Word.npy -------------------------------------------------------------------------------- /dictionary/vocab.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/dictionary/vocab.npy -------------------------------------------------------------------------------- /dictionary/word2Id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/dictionary/word2Id.npy -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def fc(inputs, 
num_out, name, activation_fn=None, biased=True): 5 | w_init = tf.random_normal_initializer(stddev=0.02) 6 | return tf.layers.dense(inputs=inputs, units=num_out, activation=activation_fn, kernel_initializer=w_init, use_bias=biased, name=name) 7 | 8 | 9 | def concat(inputs, axis, name): 10 | return tf.concat(values=inputs, axis=axis, name=name) 11 | 12 | def batch_normalization(inputs, is_training, name, activation_fn=None): 13 | output = tf.layers.batch_normalization( 14 | inputs, 15 | momentum=0.95, 16 | epsilon=1e-5, 17 | training=is_training, 18 | name=name 19 | ) 20 | 21 | if activation_fn is not None: 22 | output = activation_fn(output) 23 | 24 | return output 25 | 26 | def reshape(inputs, shape, name): 27 | return tf.reshape(inputs, shape, name) 28 | 29 | def Conv2d(input, k_h, k_w, c_o, s_h, s_w, name, activation_fn=None, padding='VALID', biased=False): 30 | c_i = input.get_shape()[-1] 31 | w_init = tf.random_normal_initializer(stddev=0.02) 32 | 33 | convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding) 34 | with tf.variable_scope(name) as scope: 35 | kernel = tf.get_variable(name='weights', shape=[k_h, k_w, c_i, c_o], initializer=w_init) 36 | output = convolve(input, kernel) 37 | 38 | if biased: 39 | biases = tf.get_variable(name='biases', shape=[c_o]) 40 | output = tf.nn.bias_add(output, biases) 41 | if activation_fn is not None: 42 | output = activation_fn(output, name=scope.name) 43 | 44 | return output 45 | 46 | def add(inputs, name): 47 | return tf.add_n(inputs, name=name) 48 | 49 | def UpSample(inputs, size, method, align_corners, name): 50 | return tf.image.resize_images(inputs, size, method, align_corners) 51 | 52 | def flatten(input, name): 53 | input_shape = input.get_shape() 54 | dim = 1 55 | for d in input_shape[1:].as_list(): 56 | dim *= d 57 | input = tf.reshape(input, [-1, dim]) 58 | 59 | return input 60 | 61 | class Generator: 62 | def __init__(self, input_z, input_rnn, is_training, reuse): 63 | self.input_z = input_z 
64 | self.input_rnn = input_rnn 65 | self.is_training = is_training 66 | self.reuse = reuse 67 | self.t_dim = 128 68 | self.gf_dim = 128 69 | self.image_size = 64 70 | self.c_dim = 3 71 | self._build_model() 72 | 73 | def _build_model(self): 74 | s = self.image_size 75 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) 76 | 77 | gf_dim = self.gf_dim 78 | t_dim = self.t_dim 79 | c_dim = self.c_dim 80 | 81 | with tf.variable_scope("generator", reuse=self.reuse): 82 | net_txt = fc(inputs=self.input_rnn, num_out=t_dim, activation_fn=tf.nn.leaky_relu, name='rnn_fc') 83 | net_in = concat([self.input_z, net_txt], axis=1, name='concat_z_txt') 84 | 85 | net_h0 = fc(inputs=net_in, num_out=gf_dim*8*s16*s16, name='g_h0/fc', biased=False) 86 | net_h0 = batch_normalization(net_h0, activation_fn=None, is_training=self.is_training, name='g_h0/batch_norm') 87 | net_h0 = reshape(net_h0, [-1, s16, s16, gf_dim*8], name='g_h0/reshape') 88 | 89 | net = Conv2d(net_h0, 1, 1, gf_dim*2, 1, 1, name='g_h1_res/conv2d') 90 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h1_res/batch_norm') 91 | net = Conv2d(net, 3, 3, gf_dim*2, 1, 1, name='g_h1_res/conv2d2', padding='SAME') 92 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h1_res/batch_norm2') 93 | net = Conv2d(net, 3, 3, gf_dim*8, 1, 1, name='g_h1_res/conv2d3', padding='SAME') 94 | net = batch_normalization(net, activation_fn=None, is_training=self.is_training, name='g_h1_res/batch_norm3') 95 | 96 | net_h1 = add([net_h0, net], name='g_h1_res/add') 97 | net_h1_output = tf.nn.relu(net_h1) 98 | 99 | net_h2 = UpSample(net_h1_output, size=[s8, s8], method=1, align_corners=False, name='g_h2/upsample2d') 100 | net_h2 = Conv2d(net_h2, 3, 3, gf_dim*4, 1, 1, name='g_h2/conv2d', padding='SAME') 101 | net_h2 = batch_normalization(net_h2, activation_fn=None, is_training=self.is_training, name='g_h2/batch_norm') 102 | 103 | net = Conv2d(net_h2, 1, 1, 
gf_dim, 1, 1, name='g_h3_res/conv2d') 104 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h3_res/batch_norm') 105 | net = Conv2d(net, 3, 3, gf_dim, 1, 1, name='g_h3_res/conv2d2', padding='SAME') 106 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h3_res/batch_norm2') 107 | net = Conv2d(net, 3, 3, gf_dim*4, 1, 1, name='g_h3_res/conv2d3', padding='SAME') 108 | net = batch_normalization(net, activation_fn=None, is_training=self.is_training, name='g_h3_res/batch_norm3') 109 | 110 | net_h3 = add([net_h2, net], name='g_h3/add') 111 | net_h3_outputs = tf.nn.relu(net_h3) 112 | 113 | net_h4 = UpSample(net_h3_outputs, size=[s4, s4], method=1, align_corners=False, name='g_h4/upsample2d') 114 | net_h4 = Conv2d(net_h4, 3, 3, gf_dim*2, 1, 1, name='g_h4/conv2d', padding='SAME') 115 | net_h4 = batch_normalization(net_h4, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h4/batch_norm') 116 | 117 | net_h5 = UpSample(net_h4, size=[s2, s2], method=1, align_corners=False, name='g_h5/upsample2d') 118 | net_h5 = Conv2d(net_h5, 3, 3, gf_dim, 1, 1, name='g_h5/conv2d', padding='SAME') 119 | net_h5 = batch_normalization(net_h5, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h5/batch_norm') 120 | 121 | net_ho = UpSample(net_h5, size=[s, s], method=1, align_corners=False, name='g_ho/upsample2d') 122 | net_ho = Conv2d(net_ho, 3, 3, c_dim, 1, 1, name='g_ho/conv2d', padding='SAME', biased=True) ## biased = True 123 | 124 | self.outputs = tf.nn.tanh(net_ho) 125 | self.logits = net_ho 126 | 127 | class Discriminator: 128 | def __init__(self, input_image, input_rnn, is_training, reuse): 129 | self.input_image = input_image 130 | self.input_rnn = input_rnn 131 | self.is_training = is_training 132 | self.reuse = reuse 133 | self.df_dim = 64 134 | self.t_dim = 128 135 | self.image_size = 64 136 | self._build_model() 137 | 138 | def _build_model(self): 139 | s = 
self.image_size 140 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) 141 | 142 | df_dim = self.df_dim 143 | t_dim = self.t_dim 144 | 145 | with tf.variable_scope("discriminator", reuse=self.reuse): 146 | net_h0 = Conv2d(self.input_image, 4, 4, df_dim, 2, 2, name='d_h0/conv2d', activation_fn=tf.nn.leaky_relu, padding='SAME', biased=True) 147 | 148 | net_h1 = Conv2d(net_h0, 4, 4, df_dim*2, 2, 2, name='d_h1/conv2d', padding='SAME') 149 | net_h1 = batch_normalization(net_h1, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h1/batchnorm') 150 | 151 | net_h2 = Conv2d(net_h1, 4, 4, df_dim*4, 2, 2, name='d_h2/conv2d', padding='SAME') 152 | net_h2 = batch_normalization(net_h2, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h2/batchnorm') 153 | 154 | net_h3 = Conv2d(net_h2, 4, 4, df_dim*8, 2, 2, name='d_h3/conv2d', padding='SAME') 155 | net_h3 = batch_normalization(net_h3, activation_fn=None, is_training=self.is_training, name='d_h3/batchnorm') 156 | 157 | net = Conv2d(net_h3, 1, 1, df_dim*2, 1, 1, name='d_h4_res/conv2d') 158 | net = batch_normalization(net, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h4_res/batchnorm') 159 | net = Conv2d(net, 3, 3, df_dim*2, 1, 1, name='d_h4_res/conv2d2', padding='SAME') 160 | net = batch_normalization(net, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h4_res/batchnorm2') 161 | net = Conv2d(net, 3, 3, df_dim*8, 1, 1, name='d_h4_res/conv2d3', padding='SAME') 162 | net = batch_normalization(net, activation_fn=None, is_training=self.is_training, name='d_h4_res/batchnorm3') 163 | 164 | net_h4 = add([net_h3, net], name='d_h4/add') 165 | net_h4_outputs = tf.nn.leaky_relu(net_h4) 166 | 167 | net_txt = fc(self.input_rnn, num_out=t_dim, activation_fn=tf.nn.leaky_relu, name='d_reduce_txt/dense') 168 | net_txt = tf.expand_dims(net_txt, axis=1, name='d_txt/expanddim1') 169 | net_txt = tf.expand_dims(net_txt, axis=1, name='d_txt/expanddim2') 170 | 
net_txt = tf.tile(net_txt, [1, 4, 4, 1], name='d_txt/tile') 171 | 172 | net_h4_concat = concat([net_h4_outputs, net_txt], axis=3, name='d_h3_concat') 173 | 174 | net_h4 = Conv2d(net_h4_concat, 1, 1, df_dim*8, 1, 1, name='d_h3/conv2d_2') 175 | net_h4 = batch_normalization(net_h4, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h3/batch_norm_2') 176 | 177 | net_ho = Conv2d(net_h4, s16, s16, 1, s16, s16, name='d_ho/conv2d', biased=True) # biased = True 178 | 179 | self.outputs = tf.nn.sigmoid(net_ho) 180 | self.logits = net_ho 181 | 182 | class rnn_encoder: 183 | def __init__(self, input_seqs, is_training, reuse): 184 | self.input_seqs = input_seqs 185 | self.is_training = is_training 186 | self.reuse = reuse 187 | self.t_dim = 128 188 | self.rnn_hidden_size = 128 189 | self.vocab_size = 8000 190 | self.word_embedding_size = 256 191 | self.keep_prob = 1.0 192 | self.batch_size = 64 193 | self._build_model() 194 | 195 | def _build_model(self): 196 | w_init = tf.random_normal_initializer(stddev=0.02) 197 | LSTMCell = tf.contrib.rnn.BasicLSTMCell 198 | 199 | with tf.variable_scope("rnnftxt", reuse=self.reuse): 200 | word_embed_matrix = tf.get_variable('rnn/wordembed', 201 | shape=(self.vocab_size, self.word_embedding_size), 202 | initializer=tf.random_normal_initializer(stddev=0.02), 203 | dtype=tf.float32) 204 | embedded_word_ids = tf.nn.embedding_lookup(word_embed_matrix, self.input_seqs) 205 | 206 | # RNN encoder 207 | LSTMCell = tf.contrib.rnn.BasicLSTMCell(self.t_dim, reuse=self.reuse) 208 | initial_state = LSTMCell.zero_state(self.batch_size, dtype=tf.float32) 209 | 210 | rnn_net = tf.nn.dynamic_rnn(cell=LSTMCell, 211 | inputs=embedded_word_ids, 212 | initial_state=initial_state, 213 | dtype=np.float32, 214 | time_major=False, 215 | scope='rnn/dynamic') 216 | 217 | self.rnn_net = rnn_net 218 | self.outputs = rnn_net[0][:, -1, :] 219 | 220 | class cnn_encoder: 221 | def __init__(self, inputs, is_training=True, reuse=False): 222 | self.inputs = 
inputs 223 | self.is_training = is_training 224 | self.reuse = reuse 225 | self.df_dim = 64 226 | self.t_dim = 128 227 | self._build_model() 228 | 229 | def _build_model(self): 230 | df_dim = self.df_dim 231 | 232 | with tf.variable_scope('cnnftxt', reuse=self.reuse): 233 | net_h0 = Conv2d(self.inputs, 4, 4, df_dim, 2, 2, name='cnnf/h0/conv2d', activation_fn=tf.nn.leaky_relu, padding='SAME', biased=True) 234 | net_h1 = Conv2d(net_h0, 4, 4, df_dim*2, 2, 2, name='cnnf/h1/conv2d', padding='SAME') 235 | net_h1 = batch_normalization(net_h1, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='cnnf/h1/batch_norm') 236 | 237 | net_h2 = Conv2d(net_h1, 4, 4, df_dim*4, 2, 2, name='cnnf/h2/conv2d', padding='SAME') 238 | net_h2 = batch_normalization(net_h2, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='cnnf/h2/batch_norm') 239 | 240 | net_h3 = Conv2d(net_h2, 4, 4, df_dim*8, 2, 2, name='cnnf/h3/conv2d', padding='SAME') 241 | net_h3 = batch_normalization(net_h3, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='cnnf/h3/batch_norm') 242 | 243 | net_h4 = flatten(net_h3, name='cnnf/h4/flatten') 244 | net_h4 = fc(net_h4, num_out=self.t_dim, name='cnnf/h4/embed', biased=False) 245 | 246 | self.outputs = net_h4 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from model import * 4 | 5 | import pandas as pd 6 | import os 7 | import scipy 8 | from scipy.io import loadmat 9 | import re 10 | import string 11 | from utils import * 12 | import random 13 | import time 14 | import argparse 15 | 16 | import warnings 17 | warnings.filterwarnings('ignore') 18 | 19 | dictionary_path = './dictionary' 20 | vocab = np.load(dictionary_path + '/vocab.npy') 21 | print('there are {} vocabularies in total'.format(len(vocab))) 22 | 23 | word2Id_dict = 
dict(np.load(dictionary_path + '/word2Id.npy')) 24 | id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy')) 25 | 26 | train_images = np.load('train_images.npy', encoding='latin1') 27 | train_captions = np.load('train_captions.npy', encoding='latin1') 28 | 29 | assert len(train_images) == len(train_captions) 30 | 31 | print('----example of captions[0]--------') 32 | for caption in train_captions[0]: 33 | print(IdList2sent(caption)) 34 | 35 | captions_list = [] 36 | for captions in train_captions: 37 | assert len(captions) >= 5 38 | captions_list.append(captions[:5]) 39 | 40 | train_captions = np.concatenate(captions_list, axis=0) 41 | 42 | n_captions_train = len(train_captions) 43 | n_captions_per_image = 5 44 | n_images_train = len(train_images) 45 | 46 | print('Total captions: ', n_captions_train) 47 | print('----example of captions[0] (modified)--------') 48 | for caption in train_captions[:5]: 49 | print(IdList2sent(caption)) 50 | 51 | lr = 0.0002 52 | lr_decay = 0.5 53 | decay_every = 100 54 | beta1 = 0.5 55 | checkpoint_dir = './checkpoint' 56 | 57 | z_dim = 512 # Noise dimension 58 | image_size = 64 # 64 x 64 59 | c_dim = 3 # for rgb 60 | batch_size = 64 61 | ni = int(np.ceil(np.sqrt(batch_size))) 62 | 63 | ### Testing setting 64 | sample_size = batch_size 65 | sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32) 66 | 67 | sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/ni) + \ 68 | ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/ni) + \ 69 | ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + \ 70 | ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + \ 71 | ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + \ 72 | ["the flower has yellow petals and the center of it is brown."] * 
int(sample_size/ni) + \ 73 | ["this flower has petals that are blue and white."] * int(sample_size/ni) +\ 74 | ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni) 75 | for i, sent in enumerate(sample_sentence): 76 | sample_sentence[i] = sent2IdList(sent) 77 | 78 | print(sample_sentence[0]) 79 | def save(saver, sess, logdir, step): 80 | model_name = 'model.ckpt' 81 | checkpoint_path = os.path.join(logdir, model_name) 82 | 83 | if not os.path.exists(logdir): 84 | os.makedirs(logdir) 85 | saver.save(sess, checkpoint_path, global_step=step) 86 | print('The checkpoint has been created.') 87 | 88 | def load(saver, sess, ckpt_path): 89 | saver.restore(sess, ckpt_path) 90 | print("Restored model parameters from {}".format(ckpt_path)) 91 | 92 | def train(): 93 | t_real_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name = 'real_image') 94 | t_wrong_image = tf.placeholder('float32', [batch_size ,image_size, image_size, 3], name = 'wrong_image') 95 | t_real_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='real_caption_input') 96 | t_wrong_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='wrong_caption_input') 97 | t_z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise') 98 | 99 | ### Training Phase - CNN - RNN mapping 100 | net_cnn = cnn_encoder(t_real_image, is_training=True, reuse=False) 101 | x = net_cnn.outputs 102 | v = rnn_encoder(t_real_caption, is_training=True, reuse=False).outputs 103 | x_w = cnn_encoder(t_wrong_image, is_training=True, reuse=True).outputs 104 | v_w = rnn_encoder(t_wrong_caption, is_training=True, reuse=True).outputs 105 | 106 | alpha = 0.2 # margin alpha 107 | rnn_loss = tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x, v_w))) + \ 108 | tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x_w, v))) 109 | 110 | ### Training 
Phase - GAN 111 | net_rnn = rnn_encoder(t_real_caption, is_training=False, reuse=True) 112 | net_fake_image = Generator(t_z, net_rnn.outputs, is_training=True, reuse=False) 113 | 114 | net_disc_fake = Discriminator(net_fake_image.outputs, net_rnn.outputs, is_training=True, reuse=False) 115 | disc_fake_logits = net_disc_fake.logits 116 | 117 | net_disc_real = Discriminator(t_real_image, net_rnn.outputs, is_training=True, reuse=True) 118 | disc_real_logits = net_disc_real.logits 119 | 120 | net_disc_mismatch = Discriminator(t_real_image, 121 | rnn_encoder(t_wrong_caption, is_training=False, reuse=True).outputs, 122 | is_training=True, reuse=True) 123 | disc_mismatch_logits = net_disc_mismatch.logits 124 | 125 | d_loss1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real_logits, labels=tf.ones_like(disc_real_logits), name='d1')) 126 | d_loss2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_mismatch_logits, labels=tf.zeros_like(disc_mismatch_logits), name='d2')) 127 | d_loss3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_logits, labels=tf.zeros_like(disc_fake_logits), name='d3')) 128 | d_loss = d_loss1 + (d_loss2 + d_loss3) * 0.5 129 | 130 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_logits, labels=tf.ones_like(disc_fake_logits), name='g')) 131 | 132 | ### Testing Phase 133 | net_g = Generator(t_z, 134 | rnn_encoder(t_real_caption, is_training=False, reuse=True).outputs, 135 | is_training=False, reuse=True) 136 | 137 | rnn_vars = [var for var in tf.trainable_variables() if 'rnn' in var.name] 138 | g_vars = [var for var in tf.trainable_variables() if 'generator' in var.name] 139 | d_vars = [var for var in tf.trainable_variables() if 'discrim' in var.name] 140 | cnn_vars = [var for var in tf.trainable_variables() if 'cnn' in var.name] 141 | 142 | update_ops_D = [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'discrim' in var.name] 143 | update_ops_G 
= [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'generator' in var.name] 144 | update_ops_CNN = [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'cnn' in var.name] 145 | 146 | print('----------Update_ops_D--------') 147 | for var in update_ops_D: 148 | print(var.name) 149 | print('----------Update_ops_G--------') 150 | for var in update_ops_G: 151 | print(var.name) 152 | print('----------Update_ops_CNN--------') 153 | for var in update_ops_CNN: 154 | print(var.name) 155 | 156 | with tf.variable_scope('learning_rate'): 157 | lr_v = tf.Variable(lr, trainable=False) 158 | 159 | with tf.control_dependencies(update_ops_D): 160 | d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars) 161 | 162 | with tf.control_dependencies(update_ops_G): 163 | g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars) 164 | 165 | with tf.control_dependencies(update_ops_CNN): 166 | grads, _ = tf.clip_by_global_norm(tf.gradients(rnn_loss, rnn_vars + cnn_vars), 10) 167 | optimizer = tf.train.AdamOptimizer(lr_v, beta1=beta1) 168 | rnn_optim = optimizer.apply_gradients(zip(grads, rnn_vars + cnn_vars)) 169 | 170 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 171 | init = tf.global_variables_initializer() 172 | sess.run(init) 173 | 174 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5) 175 | 176 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 177 | if ckpt and ckpt.model_checkpoint_path: 178 | loader = tf.train.Saver(var_list=tf.global_variables()) 179 | load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) 180 | load(loader, sess, ckpt.model_checkpoint_path) 181 | else: 182 | print('no checkpoints find.') 183 | 184 | n_epoch = 600 185 | n_batch_epoch = int(n_images_train / batch_size) 186 | for epoch in range(n_epoch): 187 | start_time = time.time() 188 | 189 | if epoch !=0 and (epoch % decay_every == 0): 190 | new_lr_decay = lr_decay 
** (epoch // decay_every) 191 | sess.run(tf.assign(lr_v, lr * new_lr_decay)) 192 | log = " ** new learning rate: %f" % (lr * new_lr_decay) 193 | print(log) 194 | 195 | elif epoch == 0: 196 | log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay) 197 | print(log) 198 | 199 | for step in range(n_batch_epoch): 200 | step_time = time.time() 201 | 202 | ## get matched text & image 203 | idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size) 204 | b_real_caption = train_captions[idexs] 205 | b_real_images = train_images[np.floor(np.asarray(idexs).astype('float')/n_captions_per_image).astype('int')] 206 | 207 | """ check for loading right images 208 | save_images(b_real_images, [ni, ni], 'train_samples/train_00.png') 209 | for caption in b_real_caption[:8]: 210 | print(IdList2sent(caption)) 211 | exit() 212 | """ 213 | 214 | ## get wrong caption & wrong image 215 | idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size) 216 | b_wrong_caption = train_captions[idexs] 217 | idexs2 = get_random_int(min=0, max=n_images_train-1, number=batch_size) 218 | b_wrong_images = train_images[idexs2] 219 | 220 | ## get noise 221 | b_z = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, z_dim)).astype(np.float32) 222 | 223 | b_real_images = threading_data(b_real_images, prepro_img, mode='train') # [0, 255] --> [-1, 1] + augmentation 224 | b_wrong_images = threading_data(b_wrong_images, prepro_img, mode='train') 225 | 226 | ## update RNN 227 | if epoch < 80: 228 | errRNN, _ = sess.run([rnn_loss, rnn_optim], feed_dict={ 229 | t_real_image : b_real_images, 230 | t_wrong_image : b_wrong_images, 231 | t_real_caption : b_real_caption, 232 | t_wrong_caption : b_wrong_caption}) 233 | else: 234 | errRNN = 0 235 | 236 | ## updates D 237 | errD, _ = sess.run([d_loss, d_optim], feed_dict={ 238 | t_real_image : b_real_images, 239 | t_wrong_caption : b_wrong_caption, 240 | t_real_caption : b_real_caption, 241 | t_z : b_z}) 242 | 
## updates G 243 | errG, _ = sess.run([g_loss, g_optim], feed_dict={ 244 | t_real_caption : b_real_caption, 245 | t_z : b_z}) 246 | 247 | print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.8f, g_loss: %.8f, rnn_loss: %.8f" \ 248 | % (epoch, n_epoch, step, n_batch_epoch, time.time() - step_time, errD, errG, errRNN)) 249 | 250 | if (epoch + 1) % 1 == 0: 251 | print(" ** Epoch %d took %fs" % (epoch, time.time()-start_time)) 252 | img_gen, rnn_out = sess.run([net_g.outputs, net_rnn.outputs], feed_dict={ 253 | t_real_caption : sample_sentence, 254 | t_z : sample_seed}) 255 | 256 | save_images(img_gen, [ni, ni], 'train_samples/train_{:02d}.png'.format(epoch)) 257 | 258 | if (epoch != 0) and (epoch % 10) == 0: 259 | save(saver, sess, checkpoint_dir, epoch) 260 | print("[*] Save checkpoints SUCCESS!") 261 | 262 | testData = os.path.join('dataset', 'testData.pkl') 263 | def test(): 264 | data = pd.read_pickle(testData) 265 | captions = data['Captions'].values 266 | caption = [] 267 | for i in range(len(captions)): 268 | caption.append(captions[i]) 269 | caption = np.asarray(caption) 270 | index = data['ID'].values 271 | index = np.asarray(index) 272 | 273 | t_real_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='real_caption_input') 274 | t_z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise') 275 | 276 | net_g = Generator(t_z, rnn_encoder(t_real_caption, is_training=False, reuse=False).outputs, 277 | is_training=False, reuse=False) 278 | 279 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 280 | init = tf.global_variables_initializer() 281 | sess.run(init) 282 | 283 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 284 | if ckpt and ckpt.model_checkpoint_path: 285 | loader = tf.train.Saver(var_list=tf.global_variables()) 286 | load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) 287 | load(loader, sess, ckpt.model_checkpoint_path) 288 | else: 289 | print('no checkpoints find.') 
290 | 291 | n_caption_test = len(caption) 292 | n_batch_epoch = int(n_caption_test / batch_size) + 1 293 | 294 | ## repeat 295 | caption = np.tile(caption, (2, 1)) 296 | index = np.tile(index, 2) 297 | 298 | assert index[0] == index[n_caption_test] 299 | 300 | for i in range(n_batch_epoch): 301 | test_cap = caption[i*batch_size: (i+1)*batch_size] 302 | 303 | z = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, z_dim)).astype(np.float32) 304 | gen = sess.run(net_g.outputs, feed_dict={t_real_caption: test_cap, t_z: z}) 305 | for j in range(batch_size): 306 | save_images(np.expand_dims(gen[j], axis=0), [1, 1], 'inference/inference_{:04d}.png'.format(index[i*batch_size + j])) 307 | 308 | if __name__ == '__main__': 309 | parser = argparse.ArgumentParser(description="Text-to-image") 310 | parser.add_argument("--mode", type=str, default='train', 311 | help="train/test") 312 | 313 | args = parser.parse_args() 314 | if args.mode == 'train': 315 | print('In training mode.') 316 | train() 317 | elif args.mode == 'test': 318 | print('In testing mode.') 319 | test() 320 | -------------------------------------------------------------------------------- /train_samples/train_799.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/train_samples/train_799.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import random 4 | import scipy 5 | import scipy.misc 6 | import numpy as np 7 | import re 8 | import string 9 | import threading 10 | import scipy.ndimage as ndi 11 | from skimage import transform 12 | from skimage import exposure 13 | import skimage 14 | 15 | dictionary_path = './dictionary' 16 | word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy')) 17 | 
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))


def sent2IdList(line, MAX_SEQ_LENGTH=20):
    """Convert a caption string into a list of word ids.

    Punctuation is replaced by spaces, the text is split on single spaces,
    and the token list is right-padded up to ``MAX_SEQ_LENGTH`` entries.
    Captions with more tokens than ``MAX_SEQ_LENGTH`` are returned unpadded
    and untruncated, so the result can be longer than ``MAX_SEQ_LENGTH``.

    Args:
        line: caption text.
        MAX_SEQ_LENGTH: minimum length of the returned id list.

    Returns:
        A list of ids looked up in ``word2Id_dict``.
    """
    # NOTE(review): the pad token and the fallback key are both the empty
    # string here. That looks like markup-style tokens (a pad / rare-word
    # marker) may have been lost from this listing -- confirm against the
    # keys actually stored in dictionary/word2Id.npy.
    pad_token = ''

    # Every punctuation character (including '-' and '.') becomes a space, so
    # no further per-character replacement is needed afterwards.
    cleaned = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    tokens = [tok for tok in cleaned.split(' ') if tok]

    # A negative pad count (caption already long enough) adds nothing.
    tokens.extend([pad_token] * (MAX_SEQ_LENGTH - len(tokens)))

    # Words missing from the dictionary fall back to the '' entry.
    return [
        word2Id_dict[tok] if tok in word2Id_dict else word2Id_dict['']
        for tok in tokens
    ]


def IdList2sent(caption):
    """Map a sequence of word ids back to words, dropping padding ids.

    Args:
        caption: iterable of ids as stored in ``id2word_dict``.

    Returns:
        A list of words (not a joined string).
    """
    pad_id = word2Id_dict['']  # looked up once instead of once per id
    return [id2word_dict[ID] for ID in caption if ID != pad_id]


def get_random_int(min=0, max=10, number=5):
    """Return ``number`` random integers drawn from ``[min, max]`` inclusive.

    Examples
    --------
    >>> r = get_random_int(min=0, max=10, number=5)
    ... [10, 2, 3, 3, 7]
    """
    return [random.randint(min, max) for _ in range(number)]


## Save images
def merge(images, size):
    """Tile a batch of images into one grid image.

    Args:
        images: array indexed as ``[image, row, col, ...]``; each image is
            written into a 3-channel canvas.
        size: ``(rows, cols)`` of the grid.

    Returns:
        A float array of shape ``(h * size[0], w * size[1], 3)``.
    """
    h, w = images.shape[1], images.shape[2]
    canvas = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        # Images fill the grid row by row.
        row, col = divmod(idx, size[1])
        canvas[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return canvas


def imsave(images, size, path):
    """Merge ``images`` into a ``size`` grid and write it to ``path``.

    The parent directory of ``path`` is created when it does not exist, so
    callers can write into output folders that are not part of the checkout.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): scipy.misc.imsave is not available in recent SciPy
    # releases -- confirm the pinned SciPy version or port to another writer.
    return scipy.misc.imsave(path, merge(images, size))


def save_images(images, size, image_path):
    """Alias of :func:`imsave` kept for callers that use this name."""
    return imsave(images, size, image_path)


# Data Augmentation reference: https://github.com/tensorlayer/tensorlayer/tree/master/tensorlayer
def threading_data(data=None, fn=None, **kwargs):
    """Apply ``fn(item, **kwargs)`` to every item of ``data``, one thread each.

    Args:
        data: indexable collection with a length.
        fn: callable applied to each element.
        **kwargs: forwarded unchanged to every ``fn`` call.

    Returns:
        ``np.asarray`` of the per-item results, in input order.
    """
    def apply_fn(results, i, data, kwargs):
        results[i] = fn(data, **kwargs)

    results = [None] * len(data)  # preallocated so each thread writes its own slot
    threads = []
    for i in range(len(data)):
        worker = threading.Thread(
            name='threading_and_return',
            target=apply_fn,
            args=(results, i, data[i], kwargs)
        )
        worker.start()
        threads.append(worker)

    # Wait for every worker before assembling the output array.
    for worker in threads:
        worker.join()

    return np.asarray(results)


def apply_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', cval=0., order=1):
    """Apply a 3x3 affine ``transform_matrix`` to each channel of image ``x``.

    Args:
        x: image array whose channel axis is ``channel_index``.
        transform_matrix: 3x3 matrix; the upper-left 2x2 part is used as the
            affine matrix and the first two entries of the last column as
            the offset.
        channel_index: axis of ``x`` holding the channels.
        fill_mode: ``mode`` passed to ``affine_transform``.
        cval: fill value passed to ``affine_transform``.
        order: interpolation order passed to ``affine_transform``.

    Returns:
        The transformed image with the channel axis back at ``channel_index``.
    """
    x = np.rollaxis(x, channel_index, 0)  # channels first so we can iterate over them
    affine_matrix = transform_matrix[:2, :2]
    offset = transform_matrix[:2, 2]
    # Use the public scipy.ndimage entry point rather than the deprecated
    # scipy.ndimage.interpolation namespace.
    channel_images = [
        ndi.affine_transform(x_channel, affine_matrix, offset,
                             order=order, mode=fill_mode, cval=cval)
        for x_channel in x
    ]
    x = np.stack(channel_images, axis=0)
    x = np.rollaxis(x, 0, channel_index + 1)  # restore the original axis order
    return x


def transform_matrix_offset_center(matrix, x, y):
    """Re-centre a 3x3 transform so it acts about ``(x/2 + 0.5, y/2 + 0.5)``.

    Returns ``offset @ matrix @ reset`` where ``offset`` translates by the
    centre and ``reset`` translates back by its negative.
    """
    o_x = float(x) / 2 + 0.5
    o_y = float(y) / 2 + 0.5
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    return np.dot(np.dot(offset_matrix, matrix), reset_matrix)


def rotation(x, rg=20, is_random=False, row_index=0, col_index=1, channel_index=2,
             fill_mode='nearest', cval=0.):
    """Rotate image ``x`` about its centre.

    Args:
        x: image array.
        rg: angle in degrees; with ``is_random`` the angle is drawn
            uniformly from ``[-rg, rg]``.
        is_random: whether to sample the angle.
        row_index, col_index, channel_index: axis positions within ``x``.
        fill_mode, cval: forwarded to :func:`apply_transform`.

    Returns:
        The rotated image.
    """
    degrees = np.random.uniform(-rg, rg) if is_random else rg
    theta = np.pi / 180 * degrees
    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                [np.sin(theta), np.cos(theta), 0],
                                [0, 0, 1]])

    h, w = x.shape[row_index], x.shape[col_index]
    transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w)
    return apply_transform(x, transform_matrix, channel_index, fill_mode, cval)


def crop(x, wrg, hrg, is_random=False, row_index=0, col_index=1, channel_index=2):
    """Crop ``x`` to ``hrg`` rows by ``wrg`` columns.

    Args:
        x: image array.
        wrg: output width.
        hrg: output height.
        is_random: random crop position if True, central crop otherwise.
        row_index, col_index: axis positions of rows / columns in ``x``.
        channel_index: unused; kept for signature compatibility.

    Returns:
        The cropped view of ``x``.

    Raises:
        AssertionError: if the crop is not strictly smaller than the image.
    """
    h, w = x.shape[row_index], x.shape[col_index]
    assert (h > hrg) and (w > wrg), "The size of cropping should smaller than the original image"
    if is_random:
        # NOTE(review): the "- 1" shifts the sampled offset down by one before
        # truncation; kept as-is to preserve the sampling distribution.
        h_offset = int(np.random.uniform(0, h - hrg) - 1)
        w_offset = int(np.random.uniform(0, w - wrg) - 1)
    else:  # central crop
        h_offset = int(np.floor((h - hrg) / 2.))
        w_offset = int(np.floor((w - wrg) / 2.))
    return x[h_offset: h_offset + hrg, w_offset: w_offset + wrg]


def flip_axis(x, axis, is_random=False):
    """Reverse ``x`` along ``axis``.

    With ``is_random`` the flip happens only when a uniform draw from
    ``[-1, 1)`` is positive; otherwise ``x`` is returned unchanged.
    """
    if is_random and not np.random.uniform(-1, 1) > 0:
        return x
    flipped = np.asarray(x).swapaxes(axis, 0)
    flipped = flipped[::-1, ...]
    return flipped.swapaxes(0, axis)


def imresize(x, size=[100, 100], interp='bilinear', mode=None):
    """Resize a 1- or 3-channel image with ``scipy.misc.imresize``.

    Args:
        x: image whose last axis has length 1 (greyscale) or 3.
        size: target size forwarded to ``scipy.misc.imresize`` (not mutated).
        interp: interpolation name forwarded to ``scipy.misc.imresize``.
        mode: mode forwarded to ``scipy.misc.imresize``.

    Returns:
        The resized image; greyscale input keeps its trailing channel axis.

    Raises:
        Exception: if the last axis is neither 1 nor 3.
    """
    # NOTE(review): scipy.misc.imresize is not available in recent SciPy
    # releases -- confirm the pinned SciPy version or port to another resizer.
    channels = x.shape[-1]
    if channels == 1:
        # greyscale: resize the 2-D plane, then restore the channel axis
        resized = scipy.misc.imresize(x[:, :, 0], size, interp=interp, mode=mode)
        return resized[:, :, np.newaxis]
    if channels == 3:
        # rgb, bgr ..
        return scipy.misc.imresize(x, size, interp=interp, mode=mode)
    raise Exception("Unsupported channel %d" % x.shape[-1])


def prepro_img(x, mode=None):
    """Augment and rescale one image for training.

    With ``mode == 'train'`` the image is randomly flipped along axis 1,
    randomly rotated by up to 16 degrees, resized to 79x79, randomly cropped
    to 64x64 and rescaled from [0, 255] to [-1, 1]. Any other ``mode``
    returns ``x`` untouched.
    """
    if mode != 'train':
        return x

    x = flip_axis(x, axis=1, is_random=True)
    x = rotation(x, rg=16, is_random=True, fill_mode='nearest')
    x = imresize(x, size=[64 + 15, 64 + 15], interp='bilinear', mode=None)
    x = crop(x, wrg=64, hrg=64, is_random=True)
    # [0, 255] --> [-1, 1]
    x = x / (255. / 2.)
    x = x - 1.
    return x


def cosine_similarity(v1, v2):
    """Return the cosine similarity of ``v1`` and ``v2`` along axis 1.

    No epsilon is added to the denominator, so an all-zero row yields a
    division by zero.
    """
    dot = tf.reduce_sum(tf.multiply(v1, v2), 1)
    norm_v1 = tf.sqrt(tf.reduce_sum(tf.multiply(v1, v1), 1))
    norm_v2 = tf.sqrt(tf.reduce_sum(tf.multiply(v2, v2), 1))
    return dot / (norm_v1 * norm_v2)


def combine_and_save_image_sets(image_sets, directory):
    """Save, for each index, the images of all sets side by side.

    For index ``i`` the ``i``-th image of every set is concatenated along
    axis 1, each followed by a 5-column zero separator, and written to
    ``<directory>/combined_<i>.jpg``. ``directory`` is created if missing.

    Args:
        image_sets: sequence of equally long image collections.
        directory: output folder.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)

    for i in range(len(image_sets[0])):
        pieces = []
        for image_set in image_sets:
            image = image_set[i]
            pieces.append(image)
            # black spacer with the same height as the image
            pieces.append(np.zeros((image.shape[0], 5, 3)))
        combined_image = np.concatenate(pieces, axis=1)

        # NOTE(review): scipy.misc.imsave is not available in recent SciPy
        # releases -- confirm the pinned SciPy version.
        scipy.misc.imsave(os.path.join(directory, 'combined_{}.jpg'.format(i)), combined_image)