├── .gitignore
├── README.md
├── architecture.jpeg
├── dictionary
│   ├── id2Word.npy
│   ├── vocab.npy
│   └── word2Id.npy
├── model.py
├── train.py
├── train_samples
│   └── train_799.png
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.npy
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # text-to-image
2 | Implementation for the [Kaggle contest "Reverse Image Caption"](https://www.kaggle.com/c/datalabcup-reverse-image-caption-ver2/leaderboard).
3 | ## Architecture
4 | ![architecture](architecture.jpeg)
5 |
6 | This implementation uses the `GAN-CLS` algorithm from the paper [Generative Adversarial Text-to-Image Synthesis](http://arxiv.org/abs/1605.05396) and the stage-1 model of `StackGAN` from [StackGAN - Github](https://github.com/hanzhanggit/StackGAN).
7 |
8 | ## Prepare Data
9 | Download the image files and captions from [Google Drive](https://drive.google.com/drive/folders/1aUJrBoIN3l9U5p5pNXT0NeNzlyBWF54u?usp=sharing) and put them into the `./text-to-image` directory.
10 |
11 | ## Results (after 800 epochs)
12 | * the flower shown has yellow anther red pistil and bright red petals.
13 | * this flower has petals that are yellow, white and purple and has dark lines
14 | * the petals on this flower are white with a yellow center
15 | * this flower has a lot of small round pink petals.
16 | * this flower is orange in color, and has petals that are ruffled and rounded.
17 | * the flower has yellow petals and the center of it is brown
18 | * this flower has petals that are blue and white.
19 | * these white flowers have petals that start off white in color and end in a white towards the tips.
20 |
21 |
--------------------------------------------------------------------------------
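For reference, the matching-aware `GAN-CLS` objective mentioned in the README is wired up in `train.py` below: the discriminator scores three pairs, and only (real image, matching text) is labelled real. The following is a minimal sketch of that loss composition (TF 1.x), with the three logit tensors assumed to come from the `Discriminator` class in `model.py`:

```python
# Sketch: GAN-CLS loss composition (TF 1.x), mirroring train.py.
import tensorflow as tf

def gan_cls_losses(d_real, d_mismatch, d_fake):
    """d_*: discriminator logits for (real img, right txt),
    (real img, wrong txt) and (fake img, right txt)."""
    bce = tf.nn.sigmoid_cross_entropy_with_logits
    d_loss = tf.reduce_mean(bce(logits=d_real, labels=tf.ones_like(d_real))) + \
             0.5 * (tf.reduce_mean(bce(logits=d_mismatch, labels=tf.zeros_like(d_mismatch))) +
                    tf.reduce_mean(bce(logits=d_fake, labels=tf.zeros_like(d_fake))))
    # The generator is rewarded when the fake pair fools the discriminator.
    g_loss = tf.reduce_mean(bce(logits=d_fake, labels=tf.ones_like(d_fake)))
    return d_loss, g_loss
```

The 0.5 weight averages the two "fake" terms, so the real pair and the two negative pairs contribute equally to the discriminator loss.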
/architecture.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/architecture.jpeg
--------------------------------------------------------------------------------
/dictionary/id2Word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/dictionary/id2Word.npy
--------------------------------------------------------------------------------
/dictionary/vocab.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/dictionary/vocab.npy
--------------------------------------------------------------------------------
/dictionary/word2Id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/dictionary/word2Id.npy
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def fc(inputs, num_out, name, activation_fn=None, biased=True):
5 | w_init = tf.random_normal_initializer(stddev=0.02)
6 | return tf.layers.dense(inputs=inputs, units=num_out, activation=activation_fn, kernel_initializer=w_init, use_bias=biased, name=name)
7 |
8 |
9 | def concat(inputs, axis, name):
10 | return tf.concat(values=inputs, axis=axis, name=name)
11 |
12 | def batch_normalization(inputs, is_training, name, activation_fn=None):
13 | output = tf.layers.batch_normalization(
14 | inputs,
15 | momentum=0.95,
16 | epsilon=1e-5,
17 | training=is_training,
18 | name=name
19 | )
20 |
21 | if activation_fn is not None:
22 | output = activation_fn(output)
23 |
24 | return output
25 |
26 | def reshape(inputs, shape, name):
27 | return tf.reshape(inputs, shape, name)
28 |
29 | def Conv2d(input, k_h, k_w, c_o, s_h, s_w, name, activation_fn=None, padding='VALID', biased=False):
30 | c_i = input.get_shape()[-1]
31 | w_init = tf.random_normal_initializer(stddev=0.02)
32 |
33 | convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
34 | with tf.variable_scope(name) as scope:
35 | kernel = tf.get_variable(name='weights', shape=[k_h, k_w, c_i, c_o], initializer=w_init)
36 | output = convolve(input, kernel)
37 |
38 | if biased:
39 | biases = tf.get_variable(name='biases', shape=[c_o])
40 | output = tf.nn.bias_add(output, biases)
41 | if activation_fn is not None:
42 | output = activation_fn(output, name=scope.name)
43 |
44 | return output
45 |
46 | def add(inputs, name):
47 | return tf.add_n(inputs, name=name)
48 |
49 | def UpSample(inputs, size, method, align_corners, name):
50 |     return tf.image.resize_images(inputs, size, method=method, align_corners=align_corners)  # method follows tf.image.ResizeMethod (1 = nearest neighbor)
51 |
52 | def flatten(input, name):
53 | input_shape = input.get_shape()
54 | dim = 1
55 | for d in input_shape[1:].as_list():
56 | dim *= d
57 | input = tf.reshape(input, [-1, dim])
58 |
59 | return input
60 |
61 | class Generator:
62 | def __init__(self, input_z, input_rnn, is_training, reuse):
63 | self.input_z = input_z
64 | self.input_rnn = input_rnn
65 | self.is_training = is_training
66 | self.reuse = reuse
67 | self.t_dim = 128
68 | self.gf_dim = 128
69 | self.image_size = 64
70 | self.c_dim = 3
71 | self._build_model()
72 |
73 | def _build_model(self):
74 | s = self.image_size
75 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
76 |
77 | gf_dim = self.gf_dim
78 | t_dim = self.t_dim
79 | c_dim = self.c_dim
80 |
81 | with tf.variable_scope("generator", reuse=self.reuse):
82 | net_txt = fc(inputs=self.input_rnn, num_out=t_dim, activation_fn=tf.nn.leaky_relu, name='rnn_fc')
83 | net_in = concat([self.input_z, net_txt], axis=1, name='concat_z_txt')
84 |
85 | net_h0 = fc(inputs=net_in, num_out=gf_dim*8*s16*s16, name='g_h0/fc', biased=False)
86 | net_h0 = batch_normalization(net_h0, activation_fn=None, is_training=self.is_training, name='g_h0/batch_norm')
87 | net_h0 = reshape(net_h0, [-1, s16, s16, gf_dim*8], name='g_h0/reshape')
88 |
89 | net = Conv2d(net_h0, 1, 1, gf_dim*2, 1, 1, name='g_h1_res/conv2d')
90 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h1_res/batch_norm')
91 | net = Conv2d(net, 3, 3, gf_dim*2, 1, 1, name='g_h1_res/conv2d2', padding='SAME')
92 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h1_res/batch_norm2')
93 | net = Conv2d(net, 3, 3, gf_dim*8, 1, 1, name='g_h1_res/conv2d3', padding='SAME')
94 | net = batch_normalization(net, activation_fn=None, is_training=self.is_training, name='g_h1_res/batch_norm3')
95 |
96 | net_h1 = add([net_h0, net], name='g_h1_res/add')
97 | net_h1_output = tf.nn.relu(net_h1)
98 |
99 | net_h2 = UpSample(net_h1_output, size=[s8, s8], method=1, align_corners=False, name='g_h2/upsample2d')
100 | net_h2 = Conv2d(net_h2, 3, 3, gf_dim*4, 1, 1, name='g_h2/conv2d', padding='SAME')
101 | net_h2 = batch_normalization(net_h2, activation_fn=None, is_training=self.is_training, name='g_h2/batch_norm')
102 |
103 | net = Conv2d(net_h2, 1, 1, gf_dim, 1, 1, name='g_h3_res/conv2d')
104 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h3_res/batch_norm')
105 | net = Conv2d(net, 3, 3, gf_dim, 1, 1, name='g_h3_res/conv2d2', padding='SAME')
106 | net = batch_normalization(net, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h3_res/batch_norm2')
107 | net = Conv2d(net, 3, 3, gf_dim*4, 1, 1, name='g_h3_res/conv2d3', padding='SAME')
108 | net = batch_normalization(net, activation_fn=None, is_training=self.is_training, name='g_h3_res/batch_norm3')
109 |
110 | net_h3 = add([net_h2, net], name='g_h3/add')
111 | net_h3_outputs = tf.nn.relu(net_h3)
112 |
113 | net_h4 = UpSample(net_h3_outputs, size=[s4, s4], method=1, align_corners=False, name='g_h4/upsample2d')
114 | net_h4 = Conv2d(net_h4, 3, 3, gf_dim*2, 1, 1, name='g_h4/conv2d', padding='SAME')
115 | net_h4 = batch_normalization(net_h4, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h4/batch_norm')
116 |
117 | net_h5 = UpSample(net_h4, size=[s2, s2], method=1, align_corners=False, name='g_h5/upsample2d')
118 | net_h5 = Conv2d(net_h5, 3, 3, gf_dim, 1, 1, name='g_h5/conv2d', padding='SAME')
119 | net_h5 = batch_normalization(net_h5, activation_fn=tf.nn.relu, is_training=self.is_training, name='g_h5/batch_norm')
120 |
121 | net_ho = UpSample(net_h5, size=[s, s], method=1, align_corners=False, name='g_ho/upsample2d')
122 |             net_ho = Conv2d(net_ho, 3, 3, c_dim, 1, 1, name='g_ho/conv2d', padding='SAME', biased=True)  # bias enabled: no batch norm follows this layer
123 |
124 | self.outputs = tf.nn.tanh(net_ho)
125 | self.logits = net_ho
126 |
127 | class Discriminator:
128 | def __init__(self, input_image, input_rnn, is_training, reuse):
129 | self.input_image = input_image
130 | self.input_rnn = input_rnn
131 | self.is_training = is_training
132 | self.reuse = reuse
133 | self.df_dim = 64
134 | self.t_dim = 128
135 | self.image_size = 64
136 | self._build_model()
137 |
138 | def _build_model(self):
139 | s = self.image_size
140 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
141 |
142 | df_dim = self.df_dim
143 | t_dim = self.t_dim
144 |
145 | with tf.variable_scope("discriminator", reuse=self.reuse):
146 | net_h0 = Conv2d(self.input_image, 4, 4, df_dim, 2, 2, name='d_h0/conv2d', activation_fn=tf.nn.leaky_relu, padding='SAME', biased=True)
147 |
148 | net_h1 = Conv2d(net_h0, 4, 4, df_dim*2, 2, 2, name='d_h1/conv2d', padding='SAME')
149 | net_h1 = batch_normalization(net_h1, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h1/batchnorm')
150 |
151 | net_h2 = Conv2d(net_h1, 4, 4, df_dim*4, 2, 2, name='d_h2/conv2d', padding='SAME')
152 | net_h2 = batch_normalization(net_h2, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h2/batchnorm')
153 |
154 | net_h3 = Conv2d(net_h2, 4, 4, df_dim*8, 2, 2, name='d_h3/conv2d', padding='SAME')
155 | net_h3 = batch_normalization(net_h3, activation_fn=None, is_training=self.is_training, name='d_h3/batchnorm')
156 |
157 | net = Conv2d(net_h3, 1, 1, df_dim*2, 1, 1, name='d_h4_res/conv2d')
158 | net = batch_normalization(net, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h4_res/batchnorm')
159 | net = Conv2d(net, 3, 3, df_dim*2, 1, 1, name='d_h4_res/conv2d2', padding='SAME')
160 | net = batch_normalization(net, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h4_res/batchnorm2')
161 | net = Conv2d(net, 3, 3, df_dim*8, 1, 1, name='d_h4_res/conv2d3', padding='SAME')
162 | net = batch_normalization(net, activation_fn=None, is_training=self.is_training, name='d_h4_res/batchnorm3')
163 |
164 | net_h4 = add([net_h3, net], name='d_h4/add')
165 | net_h4_outputs = tf.nn.leaky_relu(net_h4)
166 |
167 | net_txt = fc(self.input_rnn, num_out=t_dim, activation_fn=tf.nn.leaky_relu, name='d_reduce_txt/dense')
168 | net_txt = tf.expand_dims(net_txt, axis=1, name='d_txt/expanddim1')
169 | net_txt = tf.expand_dims(net_txt, axis=1, name='d_txt/expanddim2')
170 | net_txt = tf.tile(net_txt, [1, 4, 4, 1], name='d_txt/tile')
171 |
172 | net_h4_concat = concat([net_h4_outputs, net_txt], axis=3, name='d_h3_concat')
173 |
174 | net_h4 = Conv2d(net_h4_concat, 1, 1, df_dim*8, 1, 1, name='d_h3/conv2d_2')
175 | net_h4 = batch_normalization(net_h4, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='d_h3/batch_norm_2')
176 |
177 |             net_ho = Conv2d(net_h4, s16, s16, 1, s16, s16, name='d_ho/conv2d', biased=True)  # s16 x s16 conv over the s16 x s16 map -> one logit per image; bias enabled
178 |
179 | self.outputs = tf.nn.sigmoid(net_ho)
180 | self.logits = net_ho
181 |
182 | class rnn_encoder:
183 | def __init__(self, input_seqs, is_training, reuse):
184 | self.input_seqs = input_seqs
185 | self.is_training = is_training
186 | self.reuse = reuse
187 | self.t_dim = 128
188 | self.rnn_hidden_size = 128
189 | self.vocab_size = 8000
190 | self.word_embedding_size = 256
191 | self.keep_prob = 1.0
192 | self.batch_size = 64
193 | self._build_model()
194 |
195 | def _build_model(self):
196 | w_init = tf.random_normal_initializer(stddev=0.02)
197 | LSTMCell = tf.contrib.rnn.BasicLSTMCell
198 |
199 | with tf.variable_scope("rnnftxt", reuse=self.reuse):
200 |             word_embed_matrix = tf.get_variable('rnn/wordembed',
201 |                                                 shape=(self.vocab_size, self.word_embedding_size),
202 |                                                 initializer=w_init,
203 |                                                 dtype=tf.float32)
204 | embedded_word_ids = tf.nn.embedding_lookup(word_embed_matrix, self.input_seqs)
205 |
206 |             # RNN encoder: a single-layer LSTM over the embedded word ids
207 |             lstm_cell = LSTMCell(self.t_dim, reuse=self.reuse)
208 |             initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
209 |
210 |             rnn_net = tf.nn.dynamic_rnn(cell=lstm_cell,
211 |                                         inputs=embedded_word_ids,
212 |                                         initial_state=initial_state,
213 |                                         dtype=tf.float32,
214 |                                         time_major=False,
215 |                                         scope='rnn/dynamic')
216 |
217 |             self.rnn_net = rnn_net
218 |             self.outputs = rnn_net[0][:, -1, :]  # output at the last time step serves as the sentence embedding
219 |
220 | class cnn_encoder:
221 | def __init__(self, inputs, is_training=True, reuse=False):
222 | self.inputs = inputs
223 | self.is_training = is_training
224 | self.reuse = reuse
225 | self.df_dim = 64
226 | self.t_dim = 128
227 | self._build_model()
228 |
229 | def _build_model(self):
230 | df_dim = self.df_dim
231 |
232 | with tf.variable_scope('cnnftxt', reuse=self.reuse):
233 | net_h0 = Conv2d(self.inputs, 4, 4, df_dim, 2, 2, name='cnnf/h0/conv2d', activation_fn=tf.nn.leaky_relu, padding='SAME', biased=True)
234 | net_h1 = Conv2d(net_h0, 4, 4, df_dim*2, 2, 2, name='cnnf/h1/conv2d', padding='SAME')
235 | net_h1 = batch_normalization(net_h1, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='cnnf/h1/batch_norm')
236 |
237 | net_h2 = Conv2d(net_h1, 4, 4, df_dim*4, 2, 2, name='cnnf/h2/conv2d', padding='SAME')
238 | net_h2 = batch_normalization(net_h2, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='cnnf/h2/batch_norm')
239 |
240 | net_h3 = Conv2d(net_h2, 4, 4, df_dim*8, 2, 2, name='cnnf/h3/conv2d', padding='SAME')
241 | net_h3 = batch_normalization(net_h3, activation_fn=tf.nn.leaky_relu, is_training=self.is_training, name='cnnf/h3/batch_norm')
242 |
243 | net_h4 = flatten(net_h3, name='cnnf/h4/flatten')
244 | net_h4 = fc(net_h4, num_out=self.t_dim, name='cnnf/h4/embed', biased=False)
245 |
246 | self.outputs = net_h4
--------------------------------------------------------------------------------
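As a sanity check on shapes, here is a minimal, hypothetical smoke test for the modules above (TF 1.x; the batch size of 64 and `z_dim` of 512 follow `train.py`):

```python
# Hypothetical shape check for model.py (TF 1.x).
import tensorflow as tf
from model import Generator, Discriminator, rnn_encoder

t_caption = tf.placeholder(tf.int64, [64, None], name='caption_ids')
t_z = tf.placeholder(tf.float32, [64, 512], name='z_noise')

txt = rnn_encoder(t_caption, is_training=False, reuse=False).outputs     # (64, 128) sentence embedding
fake = Generator(t_z, txt, is_training=False, reuse=False).outputs       # (64, 64, 64, 3), tanh range
logit = Discriminator(fake, txt, is_training=False, reuse=False).logits  # (64, 1, 1, 1) per-image score
print(fake.get_shape(), logit.get_shape())
```

Note that in `train.py` the second `Generator` and the later `Discriminator` calls pass `reuse=True`, so all instantiations share one set of weights per variable scope.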
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from model import *
4 |
5 | import pandas as pd
6 | import os
7 | import scipy
8 | from scipy.io import loadmat
9 | import re
10 | import string
11 | from utils import *
12 | import random
13 | import time
14 | import argparse
15 |
16 | import warnings
17 | warnings.filterwarnings('ignore')
18 |
19 | dictionary_path = './dictionary'
20 | vocab = np.load(dictionary_path + '/vocab.npy')
21 | print('there are {} words in the vocabulary'.format(len(vocab)))
22 |
23 | word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
24 | id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
25 |
26 | train_images = np.load('train_images.npy', encoding='latin1')
27 | train_captions = np.load('train_captions.npy', encoding='latin1')
28 |
29 | assert len(train_images) == len(train_captions)
30 |
31 | print('----example of captions[0]--------')
32 | for caption in train_captions[0]:
33 | print(IdList2sent(caption))
34 |
35 | captions_list = []
36 | for captions in train_captions:
37 | assert len(captions) >= 5
38 | captions_list.append(captions[:5])
39 |
40 | train_captions = np.concatenate(captions_list, axis=0)
41 |
42 | n_captions_train = len(train_captions)
43 | n_captions_per_image = 5
44 | n_images_train = len(train_images)
45 |
46 | print('Total captions: ', n_captions_train)
47 | print('----example of captions[0] (modified)--------')
48 | for caption in train_captions[:5]:
49 | print(IdList2sent(caption))
50 |
51 | lr = 0.0002
52 | lr_decay = 0.5
53 | decay_every = 100
54 | beta1 = 0.5
55 | checkpoint_dir = './checkpoint'
56 |
57 | z_dim = 512 # Noise dimension
58 | image_size = 64 # 64 x 64
59 | c_dim = 3 # for rgb
60 | batch_size = 64
61 | ni = int(np.ceil(np.sqrt(batch_size)))
62 |
63 | ### Testing setting
64 | sample_size = batch_size
65 | sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
66 |
67 | sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/ni) + \
68 | ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/ni) + \
69 | ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + \
70 | ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + \
71 | ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + \
72 | ["the flower has yellow petals and the center of it is brown."] * int(sample_size/ni) + \
73 | ["this flower has petals that are blue and white."] * int(sample_size/ni) +\
74 | ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni)
75 | for i, sent in enumerate(sample_sentence):
76 | sample_sentence[i] = sent2IdList(sent)
77 |
78 | print(sample_sentence[0])
79 | def save(saver, sess, logdir, step):
80 | model_name = 'model.ckpt'
81 | checkpoint_path = os.path.join(logdir, model_name)
82 |
83 | if not os.path.exists(logdir):
84 | os.makedirs(logdir)
85 | saver.save(sess, checkpoint_path, global_step=step)
86 | print('The checkpoint has been created.')
87 |
88 | def load(saver, sess, ckpt_path):
89 | saver.restore(sess, ckpt_path)
90 | print("Restored model parameters from {}".format(ckpt_path))
91 |
92 | def train():
93 |     t_real_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name='real_image')
94 |     t_wrong_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name='wrong_image')
95 |     t_real_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='real_caption_input')
96 |     t_wrong_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='wrong_caption_input')
97 |     t_z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise')
98 |
99 | ### Training Phase - CNN - RNN mapping
100 | net_cnn = cnn_encoder(t_real_image, is_training=True, reuse=False)
101 | x = net_cnn.outputs
102 | v = rnn_encoder(t_real_caption, is_training=True, reuse=False).outputs
103 | x_w = cnn_encoder(t_wrong_image, is_training=True, reuse=True).outputs
104 | v_w = rnn_encoder(t_wrong_caption, is_training=True, reuse=True).outputs
105 |
106 |     alpha = 0.2  # margin for the pairwise ranking loss between image and text embeddings
107 | rnn_loss = tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x, v_w))) + \
108 | tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x_w, v)))
109 |
110 | ### Training Phase - GAN
111 | net_rnn = rnn_encoder(t_real_caption, is_training=False, reuse=True)
112 | net_fake_image = Generator(t_z, net_rnn.outputs, is_training=True, reuse=False)
113 |
114 | net_disc_fake = Discriminator(net_fake_image.outputs, net_rnn.outputs, is_training=True, reuse=False)
115 | disc_fake_logits = net_disc_fake.logits
116 |
117 | net_disc_real = Discriminator(t_real_image, net_rnn.outputs, is_training=True, reuse=True)
118 | disc_real_logits = net_disc_real.logits
119 |
120 | net_disc_mismatch = Discriminator(t_real_image,
121 | rnn_encoder(t_wrong_caption, is_training=False, reuse=True).outputs,
122 | is_training=True, reuse=True)
123 | disc_mismatch_logits = net_disc_mismatch.logits
124 |
125 | d_loss1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real_logits, labels=tf.ones_like(disc_real_logits), name='d1'))
126 | d_loss2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_mismatch_logits, labels=tf.zeros_like(disc_mismatch_logits), name='d2'))
127 | d_loss3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_logits, labels=tf.zeros_like(disc_fake_logits), name='d3'))
128 | d_loss = d_loss1 + (d_loss2 + d_loss3) * 0.5
129 |
130 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_logits, labels=tf.ones_like(disc_fake_logits), name='g'))
131 |
132 | ### Testing Phase
133 | net_g = Generator(t_z,
134 | rnn_encoder(t_real_caption, is_training=False, reuse=True).outputs,
135 | is_training=False, reuse=True)
136 |
137 | rnn_vars = [var for var in tf.trainable_variables() if 'rnn' in var.name]
138 | g_vars = [var for var in tf.trainable_variables() if 'generator' in var.name]
139 | d_vars = [var for var in tf.trainable_variables() if 'discrim' in var.name]
140 | cnn_vars = [var for var in tf.trainable_variables() if 'cnn' in var.name]
141 |
142 | update_ops_D = [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'discrim' in var.name]
143 | update_ops_G = [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'generator' in var.name]
144 | update_ops_CNN = [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'cnn' in var.name]
145 |
146 | print('----------Update_ops_D--------')
147 | for var in update_ops_D:
148 | print(var.name)
149 | print('----------Update_ops_G--------')
150 | for var in update_ops_G:
151 | print(var.name)
152 | print('----------Update_ops_CNN--------')
153 | for var in update_ops_CNN:
154 | print(var.name)
155 |
156 | with tf.variable_scope('learning_rate'):
157 | lr_v = tf.Variable(lr, trainable=False)
158 |
159 | with tf.control_dependencies(update_ops_D):
160 | d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)
161 |
162 | with tf.control_dependencies(update_ops_G):
163 | g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
164 |
165 | with tf.control_dependencies(update_ops_CNN):
166 | grads, _ = tf.clip_by_global_norm(tf.gradients(rnn_loss, rnn_vars + cnn_vars), 10)
167 | optimizer = tf.train.AdamOptimizer(lr_v, beta1=beta1)
168 | rnn_optim = optimizer.apply_gradients(zip(grads, rnn_vars + cnn_vars))
169 |
170 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
171 | init = tf.global_variables_initializer()
172 | sess.run(init)
173 |
174 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)
175 |
176 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
177 | if ckpt and ckpt.model_checkpoint_path:
178 | loader = tf.train.Saver(var_list=tf.global_variables())
179 | load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
180 | load(loader, sess, ckpt.model_checkpoint_path)
181 |     else:
182 |         print('No checkpoint found; training from scratch.')
183 |
184 | n_epoch = 600
185 | n_batch_epoch = int(n_images_train / batch_size)
186 | for epoch in range(n_epoch):
187 | start_time = time.time()
188 |
189 |         if epoch != 0 and (epoch % decay_every == 0):
190 | new_lr_decay = lr_decay ** (epoch // decay_every)
191 | sess.run(tf.assign(lr_v, lr * new_lr_decay))
192 | log = " ** new learning rate: %f" % (lr * new_lr_decay)
193 | print(log)
194 |
195 | elif epoch == 0:
196 | log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay)
197 | print(log)
198 |
199 | for step in range(n_batch_epoch):
200 | step_time = time.time()
201 |
202 |             ## get matched text & image pairs
203 |             idxs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
204 |             b_real_caption = train_captions[idxs]
205 |             b_real_images = train_images[np.floor(np.asarray(idxs).astype('float') / n_captions_per_image).astype('int')]  # caption i belongs to image i // 5
206 |
207 | """ check for loading right images
208 | save_images(b_real_images, [ni, ni], 'train_samples/train_00.png')
209 | for caption in b_real_caption[:8]:
210 | print(IdList2sent(caption))
211 | exit()
212 | """
213 |
214 |             ## get mismatched caption & image (for the GAN-CLS and ranking losses)
215 |             idxs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
216 |             b_wrong_caption = train_captions[idxs]
217 |             idxs2 = get_random_int(min=0, max=n_images_train-1, number=batch_size)
218 |             b_wrong_images = train_images[idxs2]
219 |
220 | ## get noise
221 | b_z = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, z_dim)).astype(np.float32)
222 |
223 | b_real_images = threading_data(b_real_images, prepro_img, mode='train') # [0, 255] --> [-1, 1] + augmentation
224 | b_wrong_images = threading_data(b_wrong_images, prepro_img, mode='train')
225 |
226 |             ## update RNN/CNN encoders: the joint text-image embedding is only trained for the first 80 epochs
227 |             if epoch < 80:
228 | errRNN, _ = sess.run([rnn_loss, rnn_optim], feed_dict={
229 | t_real_image : b_real_images,
230 | t_wrong_image : b_wrong_images,
231 | t_real_caption : b_real_caption,
232 | t_wrong_caption : b_wrong_caption})
233 | else:
234 | errRNN = 0
235 |
236 | ## updates D
237 | errD, _ = sess.run([d_loss, d_optim], feed_dict={
238 | t_real_image : b_real_images,
239 | t_wrong_caption : b_wrong_caption,
240 | t_real_caption : b_real_caption,
241 | t_z : b_z})
242 | ## updates G
243 | errG, _ = sess.run([g_loss, g_optim], feed_dict={
244 | t_real_caption : b_real_caption,
245 | t_z : b_z})
246 |
247 | print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.8f, g_loss: %.8f, rnn_loss: %.8f" \
248 | % (epoch, n_epoch, step, n_batch_epoch, time.time() - step_time, errD, errG, errRNN))
249 |
250 |         ## sample generated images at the end of every epoch
251 |         print(" ** Epoch %d took %fs" % (epoch, time.time() - start_time))
252 |         img_gen, rnn_out = sess.run([net_g.outputs, net_rnn.outputs], feed_dict={
253 |                                         t_real_caption : sample_sentence,
254 |                                         t_z : sample_seed})
255 |
256 |         save_images(img_gen, [ni, ni], 'train_samples/train_{:02d}.png'.format(epoch))
257 |
258 | if (epoch != 0) and (epoch % 10) == 0:
259 | save(saver, sess, checkpoint_dir, epoch)
260 | print("[*] Save checkpoints SUCCESS!")
261 |
262 | testData = os.path.join('dataset', 'testData.pkl')
263 | def test():
264 | data = pd.read_pickle(testData)
265 | captions = data['Captions'].values
266 | caption = []
267 | for i in range(len(captions)):
268 | caption.append(captions[i])
269 | caption = np.asarray(caption)
270 | index = data['ID'].values
271 | index = np.asarray(index)
272 |
273 | t_real_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='real_caption_input')
274 | t_z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise')
275 |
276 | net_g = Generator(t_z, rnn_encoder(t_real_caption, is_training=False, reuse=False).outputs,
277 | is_training=False, reuse=False)
278 |
279 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
280 | init = tf.global_variables_initializer()
281 | sess.run(init)
282 |
283 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
284 | if ckpt and ckpt.model_checkpoint_path:
285 | loader = tf.train.Saver(var_list=tf.global_variables())
286 | load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
287 | load(loader, sess, ckpt.model_checkpoint_path)
288 |     else:
289 |         print('No checkpoint found.')
290 |
291 | n_caption_test = len(caption)
292 | n_batch_epoch = int(n_caption_test / batch_size) + 1
293 |
294 |     ## tile the captions so the final partial batch can still be filled
295 | caption = np.tile(caption, (2, 1))
296 | index = np.tile(index, 2)
297 |
298 | assert index[0] == index[n_caption_test]
299 |
300 | for i in range(n_batch_epoch):
301 | test_cap = caption[i*batch_size: (i+1)*batch_size]
302 |
303 | z = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, z_dim)).astype(np.float32)
304 | gen = sess.run(net_g.outputs, feed_dict={t_real_caption: test_cap, t_z: z})
305 | for j in range(batch_size):
306 | save_images(np.expand_dims(gen[j], axis=0), [1, 1], 'inference/inference_{:04d}.png'.format(index[i*batch_size + j]))
307 |
308 | if __name__ == '__main__':
309 | parser = argparse.ArgumentParser(description="Text-to-image")
310 | parser.add_argument("--mode", type=str, default='train',
311 | help="train/test")
312 |
313 | args = parser.parse_args()
314 | if args.mode == 'train':
315 | print('In training mode.')
316 | train()
317 | elif args.mode == 'test':
318 | print('In testing mode.')
319 | test()
320 |
--------------------------------------------------------------------------------
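One detail of `train()` worth spelling out: the captions were flattened five-per-image in order, so caption index i maps to image index i // 5. A small, hypothetical check of the indexing expression used when sampling real pairs:

```python
# Sketch: caption-index -> image-index mapping used in train().
import numpy as np

n_captions_per_image = 5
idxs = np.array([0, 4, 5, 12])           # example caption indices
img_idxs = np.floor(idxs.astype('float') / n_captions_per_image).astype('int')
print(img_idxs)                          # [0 0 1 2] -- same as idxs // 5
assert (img_idxs == idxs // n_captions_per_image).all()
```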
/train_samples/train_799.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hellochick/text-to-image/fe5d1385fd26ea17aa9ad41afc74075e13f8db85/train_samples/train_799.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import random
4 | import scipy
5 | import scipy.misc
6 | import numpy as np
7 | import re
8 | import string
9 | import threading
10 | import scipy.ndimage as ndi
11 | from skimage import transform
12 | from skimage import exposure
13 | import skimage
14 |
15 | dictionary_path = './dictionary'
16 | word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
17 | id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
18 |
19 | def sent2IdList(line, MAX_SEQ_LENGTH=20):
20 | MAX_SEQ_LIMIT = MAX_SEQ_LENGTH
21 | padding = 0
22 | prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
23 |     prep_line = prep_line.replace('-', ' ')
24 |     # collapse the double spaces left by punctuation removal, then drop periods
25 |     prep_line = prep_line.replace('  ', ' ')
26 |     prep_line = prep_line.replace('.', '')
27 | tokens = prep_line.split(' ')
28 | tokens = [
29 | tokens[i] for i in range(len(tokens))
30 | if tokens[i] != ' ' and tokens[i] != ''
31 | ]
32 | l = len(tokens)
33 | padding = MAX_SEQ_LIMIT - l
34 | for i in range(padding):
35 | tokens.append('')
36 |
37 | line = [
38 | word2Id_dict[tokens[k]]
39 | if tokens[k] in word2Id_dict else word2Id_dict['']
40 | for k in range(len(tokens))
41 | ]
42 |
43 | return line
44 |
45 | def IdList2sent(caption):
46 | sentence = []
47 | for ID in caption:
48 | if ID != word2Id_dict['']:
49 | sentence.append(id2word_dict[ID])
50 |
51 | return sentence
52 |
53 | def get_random_int(min=0, max=10, number=5):
54 |     """Return a list of `number` random integers drawn from [min, max].
55 |     Examples
56 |     ---------
57 |     >>> r = get_random_int(min=0, max=10, number=5)
58 |     ... [10, 2, 3, 3, 7]
59 |     """
60 |     return [random.randint(min, max) for _ in range(number)]
61 |
62 | ## Save images
63 | def merge(images, size):
64 | h, w = images.shape[1], images.shape[2]
65 | img = np.zeros((h * size[0], w * size[1], 3))
66 | for idx, image in enumerate(images):
67 | i = idx % size[1]
68 | j = idx // size[1]
69 | img[j*h:j*h+h, i*w:i*w+w, :] = image
70 | return img
71 |
72 | def imsave(images, size, path):
73 | return scipy.misc.imsave(path, merge(images, size))
74 |
75 | def save_images(images, size, image_path):
76 | return imsave(images, size, image_path)
77 |
78 | # Data Augmentation reference: https://github.com/tensorlayer/tensorlayer/tree/master/tensorlayer
79 | def threading_data(data=None, fn=None, **kwargs):
80 | def apply_fn(results, i, data, kwargs):
81 | results[i] = fn(data, **kwargs)
82 |
83 | ## start multi-threaded reading.
84 | results = [None] * len(data) ## preallocate result list
85 | threads = []
86 | for i in range(len(data)):
87 | t = threading.Thread(
88 | name='threading_and_return',
89 | target=apply_fn,
90 | args=(results, i, data[i], kwargs)
91 | )
92 | t.start()
93 | threads.append(t)
94 |
95 | for t in threads:
96 | t.join()
97 |
98 | return np.asarray(results)
99 |
100 | def apply_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', cval=0., order=1):
101 | x = np.rollaxis(x, channel_index, 0)
102 | final_affine_matrix = transform_matrix[:2, :2]
103 | final_offset = transform_matrix[:2, 2]
104 | channel_images = [ndi.interpolation.affine_transform(x_channel, final_affine_matrix,
105 | final_offset, order=order, mode=fill_mode, cval=cval) for x_channel in x]
106 | x = np.stack(channel_images, axis=0)
107 | x = np.rollaxis(x, 0, channel_index+1)
108 | return x
109 |
110 | def transform_matrix_offset_center(matrix, x, y):
111 | o_x = float(x) / 2 + 0.5
112 | o_y = float(y) / 2 + 0.5
113 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
114 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
115 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
116 | return transform_matrix
117 |
118 | def rotation(x, rg=20, is_random=False, row_index=0, col_index=1, channel_index=2,
119 | fill_mode='nearest', cval=0.):
120 | if is_random:
121 | theta = np.pi / 180 * np.random.uniform(-rg, rg)
122 | else:
123 | theta = np.pi /180 * rg
124 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
125 | [np.sin(theta), np.cos(theta), 0],
126 | [0, 0, 1]])
127 |
128 | h, w = x.shape[row_index], x.shape[col_index]
129 | transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w)
130 | x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
131 | return x
132 |
133 | def crop(x, wrg, hrg, is_random=False, row_index=0, col_index=1, channel_index=2):
134 | h, w = x.shape[row_index], x.shape[col_index]
135 |     assert (h > hrg) and (w > wrg), "The crop size should be smaller than the original image"
136 |     if is_random:
137 |         h_offset = int(np.random.uniform(0, h - hrg))  # keep offsets non-negative
138 |         w_offset = int(np.random.uniform(0, w - wrg))
139 |         return x[h_offset: hrg + h_offset, w_offset: wrg + w_offset]
140 | else: # central crop
141 | h_offset = int(np.floor((h - hrg)/2.))
142 | w_offset = int(np.floor((w - wrg)/2.))
143 | h_end = h_offset + hrg
144 | w_end = w_offset + wrg
145 | return x[h_offset: h_end, w_offset: w_end]
146 |
147 | def flip_axis(x, axis, is_random=False):
148 | if is_random:
149 | factor = np.random.uniform(-1, 1)
150 | if factor > 0:
151 | x = np.asarray(x).swapaxes(axis, 0)
152 | x = x[::-1, ...]
153 | x = x.swapaxes(0, axis)
154 | return x
155 | else:
156 | return x
157 | else:
158 | x = np.asarray(x).swapaxes(axis, 0)
159 | x = x[::-1, ...]
160 | x = x.swapaxes(0, axis)
161 | return x
162 |
163 | def imresize(x, size=[100, 100], interp='bilinear', mode=None):
164 | if x.shape[-1] == 1:
165 | # greyscale
166 | x = scipy.misc.imresize(x[:,:,0], size, interp=interp, mode=mode)
167 | return x[:, :, np.newaxis]
168 | elif x.shape[-1] == 3:
169 | # rgb, bgr ..
170 | return scipy.misc.imresize(x, size, interp=interp, mode=mode)
171 | else:
172 | raise Exception("Unsupported channel %d" % x.shape[-1])
173 |
174 | def prepro_img(x, mode=None):
175 |     # random flip, rotation, resize and crop, then rescale [0, 255] --> [-1, 1]
176 |
177 |     if mode == 'train':
178 | x = flip_axis(x, axis=1, is_random=True)
179 | x = rotation(x, rg=16, is_random=True, fill_mode='nearest')
180 | x = imresize(x, size=[64+15, 64+15], interp='bilinear', mode=None)
181 | x = crop(x, wrg=64, hrg=64, is_random=True)
182 | x = x / (255. / 2.)
183 | x = x - 1.
184 | # x = x * 0.9999
185 |
186 | return x
187 |
188 | def cosine_similarity(v1, v2):
189 |     cost = tf.reduce_sum(tf.multiply(v1, v2), 1) / (tf.norm(v1, axis=1) * tf.norm(v2, axis=1))
190 |     return cost
191 |
192 | def combine_and_save_image_sets(image_sets, directory):
193 |     for i in range(len(image_sets[0])):
194 |         combined_image = []
195 |         for set_no in range(len(image_sets)):
196 |             combined_image.append(image_sets[set_no][i])
197 |             combined_image.append(np.zeros((image_sets[set_no][i].shape[0], 5, 3)))  # 5-px spacer between images
198 |         combined_image = np.concatenate(combined_image, axis=1)
199 |
200 |         scipy.misc.imsave(os.path.join(directory, 'combined_{}.jpg'.format(i)), combined_image)
--------------------------------------------------------------------------------
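A quick, hypothetical round trip through the tokenizer above (importing `utils` requires the `./dictionary/*.npy` files to be present):

```python
# Round trip: sentence -> padded id list -> words (needs ./dictionary/*.npy).
from utils import sent2IdList, IdList2sent

ids = sent2IdList("the flower shown has yellow anther red pistil and bright red petals.")
print(ids)               # MAX_SEQ_LENGTH (20) ids, padded with the empty-token id
print(IdList2sent(ids))  # the recovered words, with the padding stripped
```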