├── .gitignore
├── Data
│   └── .gitignore
├── LICENSE
├── Python 3 Codes
│   ├── data_loader.py
│   ├── download_datasets.py
│   ├── generate_images.py
│   ├── generate_thought_vectors.py
│   ├── image_processing.py
│   ├── model.py
│   ├── ops.py
│   ├── skipthoughts.py
│   └── train.py
├── README.md
├── Utils
│   ├── __init__.py
│   ├── image_processing.py
│   └── ops.py
├── data_loader.py
├── download_datasets.py
├── generate_images.py
├── generate_thought_vectors.py
├── model.py
├── requirements.txt
├── skipthoughts.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea* 3 | *.pdf 4 | *.jpg 5 | *.png 6 | *.pyc 7 | *.py.bak 8 | sample.py 9 | vggtest.py 10 | *.pem 11 | amazon_ssh.sh 12 | awstransfer.sh 13 | localtrain.py 14 | vislstm.png 15 | sample_aws.sh 16 | eval_trec.py 17 | theanotest.py 18 | data_loader_old.py 19 | Utils/word_embeddings_old.py 20 | gen_backup.py 21 | data_loader_test.py 22 | downloadModels.sh -------------------------------------------------------------------------------- /Data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2014 Paarth Neekhara 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
-------------------------------------------------------------------------------- /Python 3 Codes/data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from os.path import join, isfile 4 | import re 5 | import numpy as np 6 | import pickle 7 | import argparse 8 | import skipthoughts 9 | import h5py 10 | # DID NOT TRAIN IT ON MS COCO YET 11 | def save_caption_vectors_ms_coco(data_dir, split, batch_size): 12 | meta_data = {} 13 | ic_file = join(data_dir, 'annotations/captions_{}2014.json'.format(split)) 14 | with open(ic_file) as f: 15 | ic_data = json.loads(f.read()) 16 | 17 | meta_data['data_length'] = len(ic_data['annotations']) 18 | with open(join(data_dir, 'meta_{}.pkl'.format(split)), 'wb') as f: 19 | pickle.dump(meta_data, f) 20 | 21 | model = skipthoughts.load_model() 22 | batch_no = 0 23 | print("Total Batches", len(ic_data['annotations'])/batch_size) 24 | 25 | while batch_no*batch_size < len(ic_data['annotations']): 26 | captions = [] 27 | image_ids = [] 28 | idx = batch_no 29 | for i in range(batch_no*batch_size, (batch_no+1)*batch_size): 30 | idx = i%len(ic_data['annotations']) 31 | captions.append(ic_data['annotations'][idx]['caption']) 32 | image_ids.append(ic_data['annotations'][idx]['image_id']) 33 | 34 | print(captions) 35 | print(image_ids) 36 | # Thought Vectors 37 | tv_batch = skipthoughts.encode(model, captions) 38 | h5f_tv_batch = h5py.File( join(data_dir, 'tvs/'+split + '_tvs_' + str(batch_no)), 'w') 39 | h5f_tv_batch.create_dataset('tv', data=tv_batch) 40 | h5f_tv_batch.close() 41 | 42 | h5f_tv_batch_image_ids = h5py.File( join(data_dir, 'tvs/'+split + '_tv_image_id_' + str(batch_no)), 'w') 43 | h5f_tv_batch_image_ids.create_dataset('tv', data=image_ids) 44 | h5f_tv_batch_image_ids.close() 45 | 46 | print("Batches Done", batch_no, len(ic_data['annotations'])/batch_size) 47 | batch_no += 1 48 | 49 | 50 | def save_caption_vectors_flowers(data_dir): 51 | import time 52 | 53 | img_dir = join(data_dir, 'flowers/jpg') 54 | image_files = [f for f in os.listdir(img_dir) if 'jpg' in f] 55 | print(image_files[300:400]) 56 | print(len(image_files)) 57 | image_captions = { img_file : [] for img_file in image_files } 58 | 59 | caption_dir = join(data_dir, 'flowers/text_c10') 60 | class_dirs = [] 61 | for i in range(1, 103): 62 | class_dir_name = 'class_%.5d'%(i) 63 | class_dirs.append( join(caption_dir, class_dir_name)) 64 | 65 | for class_dir in class_dirs: 66 | caption_files = [f for f in os.listdir(class_dir) if 'txt' in f] 67 | for cap_file in caption_files: 68 | with open(join(class_dir,cap_file)) as f: 69 | captions = f.read().split('\n') 70 | img_file = cap_file[0:11] + ".jpg" 71 | # 5 captions per image 72 | image_captions[img_file] += [cap for cap in captions if len(cap) > 0][0:5] 73 | 74 | print(len(image_captions)) 75 | 76 | model = skipthoughts.load_model() 77 | encoded_captions = {} 78 | 79 | 80 | for i, img in enumerate(image_captions): 81 | st = time.time() 82 | encoded_captions[img] = skipthoughts.encode(model, image_captions[img]) 83 | print(i, len(image_captions), img) 84 | print("Seconds", time.time() - st) 85 | 86 | 87 | h = h5py.File(join(data_dir, 'flower_tv.hdf5'), 'w') 88 | for key in encoded_captions: 89 | h.create_dataset(key, data=encoded_captions[key]) 90 | h.close() 91 | 92 | def main(): 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--split', type=str, default='train', 95 | help='train/val') 96 | parser.add_argument('--data_dir', type=str, default='Data', 97 |
help='Data directory') 98 | parser.add_argument('--batch_size', type=int, default=64, 99 | help='Batch Size') 100 | parser.add_argument('--data_set', type=str, default='flowers', 101 | help='Data Set : Flowers, MS-COCO') 102 | args = parser.parse_args() 103 | 104 | if args.data_set == 'flowers': 105 | save_caption_vectors_flowers(args.data_dir) 106 | else: 107 | save_caption_vectors_ms_coco(args.data_dir, args.split, args.batch_size) 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /Python 3 Codes/download_datasets.py: -------------------------------------------------------------------------------- 1 | # downloads/extracts datasets described in the README.md 2 | 3 | import os 4 | import sys 5 | import errno 6 | import tarfile 7 | 8 | if sys.version_info >= (3,): 9 | from urllib.request import urlretrieve 10 | else: 11 | from urllib import urlretrieve  # Python 2 location of urlretrieve 12 | 13 | DATA_DIR = 'Data' 14 | 15 | 16 | # http://stackoverflow.com/questions/273192/how-to-check-if-a-directory-exists-and-create-it-if-necessary 17 | def make_sure_path_exists(path): 18 | try: 19 | os.makedirs(path) 20 | except OSError as exception: 21 | if exception.errno != errno.EEXIST: 22 | raise 23 | 24 | 25 | def create_data_paths(): 26 | if not os.path.isdir(DATA_DIR): 27 | raise EnvironmentError('Needs to be run from project directory containing ' + DATA_DIR) 28 | needed_paths = [ 29 | os.path.join(DATA_DIR, 'samples'), 30 | os.path.join(DATA_DIR, 'val_samples'), 31 | os.path.join(DATA_DIR, 'Models'), 32 | ] 33 | for p in needed_paths: 34 | make_sure_path_exists(p) 35 | 36 | 37 | # adapted from http://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python 38 | def dl_progress_hook(count, blockSize, totalSize): 39 | percent = int(count * blockSize * 100 / totalSize) 40 | sys.stdout.write("\r" + "...%d%%" % percent) 41 | sys.stdout.flush() 42 | 43 | 44 | def download_dataset(data_name): 45 | if data_name == 'flowers': 46 | print('== Flowers dataset ==') 47 | flowers_dir = os.path.join(DATA_DIR, 'flowers') 48 | flowers_jpg_tgz = os.path.join(flowers_dir, '102flowers.tgz') 49 | make_sure_path_exists(flowers_dir) 50 | 51 | # the original google drive link at https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view 52 | # from https://github.com/reedscot/icml2016 is problematic to download automatically, so included 53 | # the text_c10 directory from that archive as a bzipped file in the repo 54 | captions_tbz = os.path.join(DATA_DIR, 'flowers_text_c10.tar.bz2') 55 | print(('Extracting ' + captions_tbz)) 56 | captions_tar = tarfile.open(captions_tbz, 'r:bz2') 57 | captions_tar.extractall(flowers_dir) 58 | 59 | flowers_url = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' 60 | print(('Downloading ' + flowers_jpg_tgz + ' from ' + flowers_url)) 61 | urlretrieve(flowers_url, flowers_jpg_tgz, 62 | reporthook=dl_progress_hook) 63 | print(('Extracting ' + flowers_jpg_tgz)) 64 | flowers_jpg_tar = tarfile.open(flowers_jpg_tgz, 'r:gz') 65 | flowers_jpg_tar.extractall(flowers_dir) # archive contains jpg/ folder 66 | 67 | elif data_name == 'skipthoughts': 68 | print('== Skipthoughts models ==') 69 | SKIPTHOUGHTS_DIR = os.path.join(DATA_DIR, 'skipthoughts') 70 | SKIPTHOUGHTS_BASE_URL = 'http://www.cs.toronto.edu/~rkiros/models/' 71 | make_sure_path_exists(SKIPTHOUGHTS_DIR) 72 | 73 | # following https://github.com/ryankiros/skip-thoughts#getting-started 74 | skipthoughts_files = [ 75 |
'dictionary.txt', 'utable.npy', 'btable.npy', 'uni_skip.npz', 'uni_skip.npz.pkl', 'bi_skip.npz', 76 | 'bi_skip.npz.pkl', 77 | ] 78 | for filename in skipthoughts_files: 79 | src_url = SKIPTHOUGHTS_BASE_URL + filename 80 | print(('Downloading ' + src_url)) 81 | urlretrieve(src_url, os.path.join(SKIPTHOUGHTS_DIR, filename), 82 | reporthook=dl_progress_hook) 83 | 84 | elif data_name == 'nltk_punkt': 85 | import nltk 86 | print('== NLTK pre-trained Punkt tokenizer for English ==') 87 | nltk.download('punkt') 88 | 89 | elif data_name == 'pretrained_model': 90 | print('== Pretrained model ==') 91 | MODEL_DIR = os.path.join(DATA_DIR, 'Models') 92 | pretrained_model_filename = 'latest_model_flowers_temp.ckpt' 93 | src_url = 'https://bitbucket.org/paarth_neekhara/texttomimagemodel/raw/74a4bbaeee26fe31e148a54c4f495694680e2c31/' + pretrained_model_filename 94 | print(('Downloading ' + src_url)) 95 | urlretrieve( 96 | src_url, 97 | os.path.join(MODEL_DIR, pretrained_model_filename), 98 | reporthook=dl_progress_hook, 99 | ) 100 | 101 | else: 102 | raise ValueError('Unknown dataset name: ' + data_name) 103 | 104 | 105 | def main(): 106 | create_data_paths() 107 | # TODO: make configurable via command-line 108 | download_dataset('flowers') 109 | download_dataset('skipthoughts') 110 | download_dataset('nltk_punkt') 111 | download_dataset('pretrained_model') 112 | print('Done') 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /Python 3 Codes/generate_images.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import model 4 | import argparse 5 | import pickle 6 | from os.path import join 7 | import h5py 8 | from Utils import image_processing 9 | import scipy.misc 10 | import random 11 | import json 12 | import os 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--z_dim', type=int, default=100, 18 | help='Noise Dimension') 19 | 20 | parser.add_argument('--t_dim', type=int, default=256, 21 | help='Text feature dimension') 22 | 23 | parser.add_argument('--image_size', type=int, default=64, 24 | help='Image Size') 25 | 26 | parser.add_argument('--gf_dim', type=int, default=64, 27 | help='Number of conv in the first layer gen.') 28 | 29 | parser.add_argument('--df_dim', type=int, default=64, 30 | help='Number of conv in the first layer discr.') 31 | 32 | parser.add_argument('--gfc_dim', type=int, default=1024, 33 | help='Dimension of gen units for fully connected layer 1024') 34 | 35 | parser.add_argument('--caption_vector_length', type=int, default=2400, 36 | help='Caption Vector Length') 37 | 38 | parser.add_argument('--data_dir', type=str, default="Data", 39 | help='Data Directory') 40 | 41 | parser.add_argument('--model_path', type=str, default='Data/Models/latest_model_flowers_temp.ckpt', 42 | help='Trained Model Path') 43 | 44 | parser.add_argument('--n_images', type=int, default=5, 45 | help='Number of Images per Caption') 46 | 47 | parser.add_argument('--caption_thought_vectors', type=str, default='Data/sample_caption_vectors.hdf5', 48 | help='Caption Thought Vector File') 49 | 50 | 51 | args = parser.parse_args() 52 | model_options = { 53 | 'z_dim' : args.z_dim, 54 | 't_dim' : args.t_dim, 55 | 'batch_size' : args.n_images, 56 | 'image_size' : args.image_size, 57 | 'gf_dim' : args.gf_dim, 58 | 'df_dim' : args.df_dim, 59 | 'gfc_dim' : args.gfc_dim, 60 | 'caption_vector_length'
: args.caption_vector_length 61 | } 62 | 63 | gan = model.GAN(model_options) 64 | _, _, _, _, _ = gan.build_model() 65 | sess = tf.InteractiveSession() 66 | saver = tf.train.Saver() 67 | saver.restore(sess, args.model_path) 68 | 69 | input_tensors, outputs = gan.build_generator() 70 | 71 | h = h5py.File( args.caption_thought_vectors, 'r' ) 72 | caption_vectors = np.array(h['vectors']) 73 | caption_image_dic = {} 74 | for cn, caption_vector in enumerate(caption_vectors): 75 | 76 | caption_images = [] 77 | z_noise = np.random.uniform(-1, 1, [args.n_images, args.z_dim]) 78 | caption = [ caption_vector[0:args.caption_vector_length] ] * args.n_images 79 | 80 | [ gen_image ] = sess.run( [ outputs['generator'] ], 81 | feed_dict = { 82 | input_tensors['t_real_caption'] : caption, 83 | input_tensors['t_z'] : z_noise, 84 | } ) 85 | 86 | caption_images = [gen_image[i,:,:,:] for i in range(0, args.n_images)] 87 | caption_image_dic[ cn ] = caption_images 88 | print("Generated", cn) 89 | 90 | for f in os.listdir( join(args.data_dir, 'val_samples')): 91 | if os.path.isfile(join(args.data_dir, 'val_samples/' + f)): 92 | os.unlink(join(args.data_dir, 'val_samples/' + f)) 93 | 94 | for cn in range(0, len(caption_vectors)): 95 | caption_images = [] 96 | for i, im in enumerate( caption_image_dic[ cn ] ): 97 | # im_name = "caption_{}_{}.jpg".format(cn, i) 98 | # scipy.misc.imsave( join(args.data_dir, 'val_samples/{}'.format(im_name)) , im) 99 | caption_images.append( im ) 100 | caption_images.append( np.zeros((64, 5, 3)) ) 101 | combined_image = np.concatenate( caption_images[0:-1], axis = 1 ) 102 | scipy.misc.imsave( join(args.data_dir, 'val_samples/combined_image_{}.jpg'.format(cn)) , combined_image) 103 | 104 | 105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- /Python 3 Codes/generate_thought_vectors.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, isfile 3 | import re 4 | import numpy as np 5 | import pickle 6 | import argparse 7 | import skipthoughts 8 | import h5py 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--caption_file', type=str, default='Data/sample_captions.txt', 13 | help='caption file') 14 | parser.add_argument('--data_dir', type=str, default='Data', 15 | help='Data Directory') 16 | 17 | args = parser.parse_args() 18 | with open( args.caption_file ) as f: 19 | captions = f.read().split('\n') 20 | 21 | captions = [cap for cap in captions if len(cap) > 0] 22 | print(captions) 23 | model = skipthoughts.load_model() 24 | caption_vectors = skipthoughts.encode(model, captions) 25 | 26 | if os.path.isfile(join(args.data_dir, 'sample_caption_vectors.hdf5')): 27 | os.remove(join(args.data_dir, 'sample_caption_vectors.hdf5')) 28 | h = h5py.File(join(args.data_dir, 'sample_caption_vectors.hdf5'), 'w') 29 | h.create_dataset('vectors', data=caption_vectors) 30 | h.close() 31 | 32 | if __name__ == '__main__': 33 | main() -------------------------------------------------------------------------------- /Python 3 Codes/image_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import misc 3 | import random 4 | import skimage 5 | import skimage.io 6 | import skimage.transform 7 | 8 | def load_image_array(image_file, image_size): 9 | img = skimage.io.imread(image_file) 10 | # GRAYSCALE 11 | if len(img.shape) == 2: 12 | img_new = np.ndarray( (img.shape[0], img.shape[1], 3), dtype
= 'uint8') 13 | img_new[:,:,0] = img 14 | img_new[:,:,1] = img 15 | img_new[:,:,2] = img 16 | img = img_new 17 | 18 | img_resized = skimage.transform.resize(img, (image_size, image_size)) 19 | 20 | # FLIP HORIZONTAL WITH A PROBABILITY 0.5 21 | if random.random() > 0.5: 22 | img_resized = np.fliplr(img_resized) 23 | 24 | 25 | return img_resized.astype('float32') 26 | 27 | if __name__ == '__main__': 28 | # TEST>>> 29 | arr = load_image_array('sample.jpg', 64) 30 | print(arr.mean()) 31 | # rev = np.fliplr(arr) 32 | misc.imsave( 'rev.jpg', arr) -------------------------------------------------------------------------------- /Python 3 Codes/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from Utils import ops 3 | 4 | class GAN: 5 | ''' 6 | OPTIONS 7 | z_dim : Noise dimension 100 8 | t_dim : Text feature dimension 256 9 | image_size : Image Dimension 64 10 | gf_dim : Number of conv in the first layer generator 64 11 | df_dim : Number of conv in the first layer discriminator 64 12 | gfc_dim : Dimension of gen units for fully connected layer 1024 13 | caption_vector_length : Caption Vector Length 2400 14 | batch_size : Batch Size 64 15 | ''' 16 | def __init__(self, options): 17 | self.options = options 18 | 19 | self.g_bn0 = ops.batch_norm(name='g_bn0') 20 | self.g_bn1 = ops.batch_norm(name='g_bn1') 21 | self.g_bn2 = ops.batch_norm(name='g_bn2') 22 | self.g_bn3 = ops.batch_norm(name='g_bn3') 23 | 24 | self.d_bn1 = ops.batch_norm(name='d_bn1') 25 | self.d_bn2 = ops.batch_norm(name='d_bn2') 26 | self.d_bn3 = ops.batch_norm(name='d_bn3') 27 | self.d_bn4 = ops.batch_norm(name='d_bn4') 28 | 29 | 30 | def build_model(self): 31 | img_size = self.options['image_size'] 32 | t_real_image = tf.placeholder('float32', [self.options['batch_size'],img_size, img_size, 3 ], name = 'real_image') 33 | t_wrong_image = tf.placeholder('float32', [self.options['batch_size'],img_size, img_size, 3 ], name = 'wrong_image') 34 | t_real_caption = tf.placeholder('float32', [self.options['batch_size'], self.options['caption_vector_length']], name = 'real_caption_input') 35 | t_z = tf.placeholder('float32', [self.options['batch_size'], self.options['z_dim']]) 36 | 37 | fake_image = self.generator(t_z, t_real_caption) 38 | 39 | disc_real_image, disc_real_image_logits = self.discriminator(t_real_image, t_real_caption) 40 | disc_wrong_image, disc_wrong_image_logits = self.discriminator(t_wrong_image, t_real_caption, reuse = True) 41 | disc_fake_image, disc_fake_image_logits = self.discriminator(fake_image, t_real_caption, reuse = True) 42 | 43 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_image_logits, tf.ones_like(disc_fake_image))) 44 | 45 | d_loss1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_real_image_logits, tf.ones_like(disc_real_image))) 46 | d_loss2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_wrong_image_logits, tf.zeros_like(disc_wrong_image))) 47 | d_loss3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_image_logits, tf.zeros_like(disc_fake_image))) 48 | 49 | d_loss = d_loss1 + d_loss2 + d_loss3 50 | 51 | t_vars = tf.trainable_variables() 52 | d_vars = [var for var in t_vars if 'd_' in var.name] 53 | g_vars = [var for var in t_vars if 'g_' in var.name] 54 | 55 | input_tensors = { 56 | 't_real_image' : t_real_image, 57 | 't_wrong_image' : t_wrong_image, 58 | 't_real_caption' : t_real_caption, 59 | 't_z' : t_z 60 | } 61 | 62 | variables = { 63 | 'd_vars' :
d_vars, 64 | 'g_vars' : g_vars 65 | } 66 | 67 | loss = { 68 | 'g_loss' : g_loss, 69 | 'd_loss' : d_loss 70 | } 71 | 72 | outputs = { 73 | 'generator' : fake_image 74 | } 75 | 76 | checks = { 77 | 'd_loss1': d_loss1, 78 | 'd_loss2': d_loss2, 79 | 'd_loss3' : d_loss3, 80 | 'disc_real_image_logits' : disc_real_image_logits, 81 | 'disc_wrong_image_logits' : disc_wrong_image_logits, 82 | 'disc_fake_image_logits' : disc_fake_image_logits 83 | } 84 | 85 | return input_tensors, variables, loss, outputs, checks 86 | 87 | def build_generator(self): 88 | img_size = self.options['image_size'] 89 | t_real_caption = tf.placeholder('float32', [self.options['batch_size'], self.options['caption_vector_length']], name = 'real_caption_input') 90 | t_z = tf.placeholder('float32', [self.options['batch_size'], self.options['z_dim']]) 91 | fake_image = self.sampler(t_z, t_real_caption) 92 | 93 | input_tensors = { 94 | 't_real_caption' : t_real_caption, 95 | 't_z' : t_z 96 | } 97 | 98 | outputs = { 99 | 'generator' : fake_image 100 | } 101 | 102 | return input_tensors, outputs 103 | 104 | # Sample Images for a text embedding 105 | def sampler(self, t_z, t_text_embedding): 106 | tf.get_variable_scope().reuse_variables() 107 | 108 | s = self.options['image_size'] 109 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) 110 | 111 | reduced_text_embedding = ops.lrelu( ops.linear(t_text_embedding, self.options['t_dim'], 'g_embedding') ) 112 | z_concat = tf.concat(1, [t_z, reduced_text_embedding]) 113 | z_ = ops.linear(z_concat, self.options['gf_dim']*8*s16*s16, 'g_h0_lin') 114 | h0 = tf.reshape(z_, [-1, s16, s16, self.options['gf_dim'] * 8]) 115 | h0 = tf.nn.relu(self.g_bn0(h0, train = False)) 116 | 117 | h1 = ops.deconv2d(h0, [self.options['batch_size'], s8, s8, self.options['gf_dim']*4], name='g_h1') 118 | h1 = tf.nn.relu(self.g_bn1(h1, train = False)) 119 | 120 | h2 = ops.deconv2d(h1, [self.options['batch_size'], s4, s4, self.options['gf_dim']*2], name='g_h2') 121 | h2 = tf.nn.relu(self.g_bn2(h2, train = False)) 122 | 123 | h3 = ops.deconv2d(h2, [self.options['batch_size'], s2, s2, self.options['gf_dim']*1], name='g_h3') 124 | h3 = tf.nn.relu(self.g_bn3(h3, train = False)) 125 | 126 | h4 = ops.deconv2d(h3, [self.options['batch_size'], s, s, 3], name='g_h4') 127 | 128 | return (tf.tanh(h4)/2. + 0.5) 129 | 130 | # GENERATOR IMPLEMENTATION based on : https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py 131 | def generator(self, t_z, t_text_embedding): 132 | 133 | s = self.options['image_size'] 134 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) 135 | 136 | reduced_text_embedding = ops.lrelu( ops.linear(t_text_embedding, self.options['t_dim'], 'g_embedding') ) 137 | z_concat = tf.concat(1, [t_z, reduced_text_embedding]) 138 | z_ = ops.linear(z_concat, self.options['gf_dim']*8*s16*s16, 'g_h0_lin') 139 | h0 = tf.reshape(z_, [-1, s16, s16, self.options['gf_dim'] * 8]) 140 | h0 = tf.nn.relu(self.g_bn0(h0)) 141 | 142 | h1 = ops.deconv2d(h0, [self.options['batch_size'], s8, s8, self.options['gf_dim']*4], name='g_h1') 143 | h1 = tf.nn.relu(self.g_bn1(h1)) 144 | 145 | h2 = ops.deconv2d(h1, [self.options['batch_size'], s4, s4, self.options['gf_dim']*2], name='g_h2') 146 | h2 = tf.nn.relu(self.g_bn2(h2)) 147 | 148 | h3 = ops.deconv2d(h2, [self.options['batch_size'], s2, s2, self.options['gf_dim']*1], name='g_h3') 149 | h3 = tf.nn.relu(self.g_bn3(h3)) 150 | 151 | h4 = ops.deconv2d(h3, [self.options['batch_size'], s, s, 3], name='g_h4') 152 | 153 | return (tf.tanh(h4)/2. + 0.5)
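	# Shape walk-through of the generator above, assuming the default options (image_size=64, gf_dim=64, t_dim=256, z_dim=100): z (100) is concatenated with lrelu(linear(caption)) (256) into z_concat (356), projected to 4*4*512 and reshaped to (4, 4, 512), then upsampled through (8, 8, 256) -> (16, 16, 128) -> (32, 32, 64) -> (64, 64, 3); tanh(h4)/2 + 0.5 maps the output into [0, 1], the same range as the float images produced by Utils/image_processing.load_image_array.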
154 | 155 | # DISCRIMINATOR IMPLEMENTATION based on : https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py 156 | def discriminator(self, image, t_text_embedding, reuse=False): 157 | if reuse: 158 | tf.get_variable_scope().reuse_variables() 159 | 160 | h0 = ops.lrelu(ops.conv2d(image, self.options['df_dim'], name = 'd_h0_conv')) #32 161 | h1 = ops.lrelu( self.d_bn1(ops.conv2d(h0, self.options['df_dim']*2, name = 'd_h1_conv'))) #16 162 | h2 = ops.lrelu( self.d_bn2(ops.conv2d(h1, self.options['df_dim']*4, name = 'd_h2_conv'))) #8 163 | h3 = ops.lrelu( self.d_bn3(ops.conv2d(h2, self.options['df_dim']*8, name = 'd_h3_conv'))) #4 164 | 165 | # ADD TEXT EMBEDDING TO THE NETWORK 166 | reduced_text_embeddings = ops.lrelu(ops.linear(t_text_embedding, self.options['t_dim'], 'd_embedding')) 167 | reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,1) 168 | reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,2) 169 | tiled_embeddings = tf.tile(reduced_text_embeddings, [1,4,4,1], name='tiled_embeddings') 170 | 171 | h3_concat = tf.concat( 3, [h3, tiled_embeddings], name='h3_concat') 172 | h3_new = ops.lrelu( self.d_bn4(ops.conv2d(h3_concat, self.options['df_dim']*8, 1,1,1,1, name = 'd_h3_conv_new'))) #4 173 | 174 | h4 = ops.linear(tf.reshape(h3_new, [self.options['batch_size'], -1]), 1, 'd_h3_lin') 175 | 176 | return tf.nn.sigmoid(h4), h4 177 | -------------------------------------------------------------------------------- /Python 3 Codes/ops.py: -------------------------------------------------------------------------------- 1 | # REUSED CODE FROM https://github.com/carpedm20/DCGAN-tensorflow/blob/master/ops.py 2 | import math 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tensorflow.python.framework import ops 7 | 8 | 9 | class batch_norm(object): 10 | """Code modification of http://stackoverflow.com/a/33950177""" 11 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 12 | with tf.variable_scope(name): 13 | self.epsilon = epsilon 14 | self.momentum = momentum 15 | 16 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum) 17 | self.name = name 18 | 19 | def __call__(self, x, train=True): 20 | shape = x.get_shape().as_list() 21 | 22 | if train: 23 | with tf.variable_scope(self.name) as scope: 24 | self.beta = tf.get_variable("beta", [shape[-1]], 25 | initializer=tf.constant_initializer(0.)) 26 | self.gamma = tf.get_variable("gamma", [shape[-1]], 27 | initializer=tf.random_normal_initializer(1., 0.02)) 28 | 29 | try: 30 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments') 31 | except: 32 | batch_mean, batch_var = tf.nn.moments(x, [0, 1], name='moments') 33 | 34 | ema_apply_op = self.ema.apply([batch_mean, batch_var]) 35 | self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var) 36 | 37 | with tf.control_dependencies([ema_apply_op]): 38 | mean, var = tf.identity(batch_mean), tf.identity(batch_var) 39 | else: 40 | mean, var = self.ema_mean, self.ema_var 41 | 42 | normed = tf.nn.batch_norm_with_global_normalization( 43 | x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True) 44 | 45 | return normed 46 |
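# Usage note for batch_norm: a train=True call on a given scope (e.g. via GAN.generator in model.build_model) must be constructed before a train=False call on the same scope (e.g. GAN.sampler), because train=False reads self.ema_mean / self.ema_var, which are only set once a train-mode pass has created and applied the exponential moving averages.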
47 | def binary_cross_entropy(preds, targets, name=None): 48 | """Computes binary cross entropy given `preds`. 49 | For brevity, let `x = targets`, `z = preds`. The logistic loss is 50 | loss(x, z) = - sum_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i])) 51 | Args: 52 | preds: A `Tensor` of type `float32` or `float64`. 53 | targets: A `Tensor` of the same type and shape as `preds`. 54 | """ 55 | eps = 1e-12 56 | with ops.op_scope([preds, targets], name, "bce_loss") as name: 57 | preds = ops.convert_to_tensor(preds, name="preds") 58 | targets = ops.convert_to_tensor(targets, name="targets") 59 | return tf.reduce_mean(-(targets * tf.log(preds + eps) + 60 | (1. - targets) * tf.log(1. - preds + eps))) 61 | 62 | def conv_cond_concat(x, y): 63 | """Concatenate conditioning vector on feature map axis.""" 64 | x_shapes = x.get_shape() 65 | y_shapes = y.get_shape() 66 | return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])]) 67 | 68 | def conv2d(input_, output_dim, 69 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 70 | name="conv2d"): 71 | with tf.variable_scope(name): 72 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 73 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 74 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 75 | 76 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 77 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 78 | 79 | return conv 80 | 81 | def deconv2d(input_, output_shape, 82 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 83 | name="deconv2d", with_w=False): 84 | with tf.variable_scope(name): 85 | # filter : [height, width, output_channels, in_channels] 86 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 87 | initializer=tf.random_normal_initializer(stddev=stddev)) 88 | 89 | try: 90 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 91 | strides=[1, d_h, d_w, 1]) 92 | 93 | # Support for versions of TensorFlow before 0.7.0 94 | except AttributeError: 95 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, 96 | strides=[1, d_h, d_w, 1]) 97 | 98 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 99 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 100 | 101 | if with_w: 102 | return deconv, w, biases 103 | else: 104 | return deconv 105 | 106 | def lrelu(x, leak=0.2, name="lrelu"): 107 | return tf.maximum(x, leak*x) 108 | 109 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 110 | shape = input_.get_shape().as_list() 111 | 112 | with tf.variable_scope(scope or "Linear"): 113 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 114 | tf.random_normal_initializer(stddev=stddev)) 115 | bias = tf.get_variable("bias", [output_size], 116 | initializer=tf.constant_initializer(bias_start)) 117 | if with_w: 118 | return tf.matmul(input_, matrix) + bias, matrix, bias 119 | else: 120 | return tf.matmul(input_, matrix) + bias 121 | -------------------------------------------------------------------------------- /Python 3 Codes/skipthoughts.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Skip-thought vectors 3 | https://github.com/ryankiros/skip-thoughts 4 | ''' 5 | import os 6 | import warnings 7 | import theano 8 | import theano.tensor as tensor 9 | 10 | import pickle as pkl 11 | import numpy 12 | import copy 13 | import nltk 14 | 15 | from collections import OrderedDict, defaultdict 16 | from scipy.linalg import norm 17 | from nltk.tokenize import word_tokenize 18 | 19 | profile = False 20 | 21 | #-----------------------------------------------------------------------------# 22 | # Specify model and table locations here 23 |
#-----------------------------------------------------------------------------# 24 | path_to_models = 'Data/skipthoughts/' 25 | path_to_tables = 'Data/skipthoughts/' 26 | #-----------------------------------------------------------------------------# 27 | 28 | path_to_umodel = path_to_models + 'uni_skip.npz' 29 | path_to_bmodel = path_to_models + 'bi_skip.npz' 30 | 31 | 32 | def load_model(): 33 | """ 34 | Load the model with saved tables 35 | """ 36 | # Load model options 37 | print('Loading model parameters...') 38 | with open('%s.pkl'%path_to_umodel, 'rb') as f: 39 | uoptions = pkl.load(f) 40 | with open('%s.pkl'%path_to_bmodel, 'rb') as f: 41 | boptions = pkl.load(f) 42 | 43 | # Load parameters 44 | uparams = init_params(uoptions) 45 | uparams = load_params(path_to_umodel, uparams) 46 | utparams = init_tparams(uparams) 47 | bparams = init_params_bi(boptions) 48 | bparams = load_params(path_to_bmodel, bparams) 49 | btparams = init_tparams(bparams) 50 | 51 | # Extractor functions 52 | print('Compiling encoders...') 53 | embedding, x_mask, ctxw2v = build_encoder(utparams, uoptions) 54 | f_w2v = theano.function([embedding, x_mask], ctxw2v, name='f_w2v') 55 | embedding, x_mask, ctxw2v = build_encoder_bi(btparams, boptions) 56 | f_w2v2 = theano.function([embedding, x_mask], ctxw2v, name='f_w2v2') 57 | 58 | # Tables 59 | print('Loading tables...') 60 | utable, btable = load_tables() 61 | 62 | # Store everything we need in a dictionary 63 | print('Packing up...') 64 | model = {} 65 | model['uoptions'] = uoptions 66 | model['boptions'] = boptions 67 | model['utable'] = utable 68 | model['btable'] = btable 69 | model['f_w2v'] = f_w2v 70 | model['f_w2v2'] = f_w2v2 71 | 72 | return model 73 | 74 | 75 | def load_tables(): 76 | """ 77 | Load the tables 78 | """ 79 | words = [] 80 | utable = numpy.load(path_to_tables + 'utable.npy',encoding='latin1') 81 | btable = numpy.load(path_to_tables + 'btable.npy',encoding='latin1') 82 | f = open(path_to_tables + 'dictionary.txt', 'rb') 83 | for line in f: 84 | words.append(line.decode('utf-8').strip()) 85 | f.close() 86 | utable = OrderedDict(list(zip(words, utable))) 87 | btable = OrderedDict(list(zip(words, btable))) 88 | return utable, btable 89 | 90 | 91 | def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False): 92 | """ 93 | Encode sentences in the list X. Each entry will return a vector 94 | """ 95 | # first, do preprocessing 96 | X = preprocess(X) 97 | 98 | # word dictionary and init 99 | d = defaultdict(lambda : 0) 100 | for w in list(model['utable'].keys()): 101 | d[w] = 1 102 | ufeatures = numpy.zeros((len(X), model['uoptions']['dim']), dtype='float32') 103 | bfeatures = numpy.zeros((len(X), 2 * model['boptions']['dim']), dtype='float32') 104 | 105 | # length dictionary 106 | ds = defaultdict(list) 107 | captions = [s.split() for s in X] 108 | for i,s in enumerate(captions): 109 | ds[len(s)].append(i) 110 | 111 | # Get features. 
This encodes by length, in order to avoid wasting computation 112 | for k in list(ds.keys()): 113 | if verbose: 114 | print(k) 115 | numbatches = len(ds[k]) // batch_size + 1 116 | for minibatch in range(numbatches): 117 | caps = ds[k][minibatch::numbatches] 118 | 119 | if use_eos: 120 | uembedding = numpy.zeros((k+1, len(caps), model['uoptions']['dim_word']), dtype='float32') 121 | bembedding = numpy.zeros((k+1, len(caps), model['boptions']['dim_word']), dtype='float32') 122 | else: 123 | uembedding = numpy.zeros((k, len(caps), model['uoptions']['dim_word']), dtype='float32') 124 | bembedding = numpy.zeros((k, len(caps), model['boptions']['dim_word']), dtype='float32') 125 | for ind, c in enumerate(caps): 126 | caption = captions[c] 127 | for j in range(len(caption)): 128 | if d[caption[j]] > 0: 129 | uembedding[j,ind] = model['utable'][caption[j]] 130 | bembedding[j,ind] = model['btable'][caption[j]] 131 | else: 132 | uembedding[j,ind] = model['utable']['UNK'] 133 | bembedding[j,ind] = model['btable']['UNK'] 134 | if use_eos: 135 | uembedding[-1,ind] = model['utable']['<eos>'] 136 | bembedding[-1,ind] = model['btable']['<eos>'] 137 | if use_eos: 138 | uff = model['f_w2v'](uembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32')) 139 | bff = model['f_w2v2'](bembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32')) 140 | else: 141 | uff = model['f_w2v'](uembedding, numpy.ones((len(caption),len(caps)), dtype='float32')) 142 | bff = model['f_w2v2'](bembedding, numpy.ones((len(caption),len(caps)), dtype='float32')) 143 | if use_norm: 144 | for j in range(len(uff)): 145 | uff[j] /= norm(uff[j]) 146 | bff[j] /= norm(bff[j]) 147 | for ind, c in enumerate(caps): 148 | ufeatures[c] = uff[ind] 149 | bfeatures[c] = bff[ind] 150 | 151 | features = numpy.c_[ufeatures, bfeatures] 152 | return features 153 | 154 | 155 | def preprocess(text): 156 | """ 157 | Preprocess text for encoder 158 | """ 159 | X = [] 160 | sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') 161 | for t in text: 162 | sents = sent_detector.tokenize(t) 163 | result = '' 164 | for s in sents: 165 | tokens = word_tokenize(s) 166 | result += ' ' + ' '.join(tokens) 167 | X.append(result) 168 | return X 169 | 170 | 171 | def nn(model, text, vectors, query, k=5): 172 | """ 173 | Return the nearest neighbour sentences to query 174 | text: list of sentences 175 | vectors: the corresponding representations for text 176 | query: a string to search 177 | """ 178 | qf = encode(model, [query]) 179 | qf /= norm(qf) 180 | scores = numpy.dot(qf, vectors.T).flatten() 181 | sorted_args = numpy.argsort(scores)[::-1] 182 | sentences = [text[a] for a in sorted_args[:k]] 183 | print(('QUERY: ' + query)) 184 | print('NEAREST: ') 185 | for i, s in enumerate(sentences): 186 | print((s, sorted_args[i])) 187 | 188 | 189 | def word_features(table): 190 | """ 191 | Extract word features into a normalized matrix 192 | """ 193 | features = numpy.zeros((len(table), 620), dtype='float32') 194 | keys = list(table.keys()) 195 | for i in range(len(table)): 196 | f = table[keys[i]] 197 | features[i] = f / norm(f) 198 | return features 199 | 200 | 201 | def nn_words(table, wordvecs, query, k=10): 202 | """ 203 | Get the nearest neighbour words 204 | """ 205 | keys = list(table.keys()) 206 | qf = table[query] 207 | scores = numpy.dot(qf, wordvecs.T).flatten() 208 | sorted_args = numpy.argsort(scores)[::-1] 209 | words = [keys[a] for a in sorted_args[:k]] 210 | print(('QUERY: ' + query)) 211 | print('NEAREST: ') 212 | for i, w in
enumerate(words): 213 | print(w) 214 | 215 | 216 | def _p(pp, name): 217 | """ 218 | make prefix-appended name 219 | """ 220 | return '%s_%s'%(pp, name) 221 | 222 | 223 | def init_tparams(params): 224 | """ 225 | initialize Theano shared variables according to the initial parameters 226 | """ 227 | tparams = OrderedDict() 228 | for kk, pp in list(params.items()): 229 | tparams[kk] = theano.shared(params[kk], name=kk) 230 | return tparams 231 | 232 | 233 | def load_params(path, params): 234 | """ 235 | load parameters 236 | """ 237 | pp = numpy.load(path) 238 | for kk, vv in list(params.items()): 239 | if kk not in pp: 240 | warnings.warn('%s is not in the archive'%kk) 241 | continue 242 | params[kk] = pp[kk] 243 | return params 244 | 245 | 246 | # layers: 'name': ('parameter initializer', 'feedforward') 247 | layers = {'gru': ('param_init_gru', 'gru_layer')} 248 | 249 | def get_layer(name): 250 | fns = layers[name] 251 | return (eval(fns[0]), eval(fns[1])) 252 | 253 | 254 | def init_params(options): 255 | """ 256 | initialize all parameters needed for the encoder 257 | """ 258 | params = OrderedDict() 259 | 260 | # embedding 261 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 262 | 263 | # encoder: GRU 264 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder', 265 | nin=options['dim_word'], dim=options['dim']) 266 | return params 267 | 268 | 269 | def init_params_bi(options): 270 | """ 271 | initialize all paramters needed for bidirectional encoder 272 | """ 273 | params = OrderedDict() 274 | 275 | # embedding 276 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 277 | 278 | # encoder: GRU 279 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder', 280 | nin=options['dim_word'], dim=options['dim']) 281 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder_r', 282 | nin=options['dim_word'], dim=options['dim']) 283 | return params 284 | 285 | 286 | def build_encoder(tparams, options): 287 | """ 288 | build an encoder, given pre-computed word embeddings 289 | """ 290 | # word embedding (source) 291 | embedding = tensor.tensor3('embedding', dtype='float32') 292 | x_mask = tensor.matrix('x_mask', dtype='float32') 293 | 294 | # encoder 295 | proj = get_layer(options['encoder'])[1](tparams, embedding, options, 296 | prefix='encoder', 297 | mask=x_mask) 298 | ctx = proj[0][-1] 299 | 300 | return embedding, x_mask, ctx 301 | 302 | 303 | def build_encoder_bi(tparams, options): 304 | """ 305 | build bidirectional encoder, given pre-computed word embeddings 306 | """ 307 | # word embedding (source) 308 | embedding = tensor.tensor3('embedding', dtype='float32') 309 | embeddingr = embedding[::-1] 310 | x_mask = tensor.matrix('x_mask', dtype='float32') 311 | xr_mask = x_mask[::-1] 312 | 313 | # encoder 314 | proj = get_layer(options['encoder'])[1](tparams, embedding, options, 315 | prefix='encoder', 316 | mask=x_mask) 317 | projr = get_layer(options['encoder'])[1](tparams, embeddingr, options, 318 | prefix='encoder_r', 319 | mask=xr_mask) 320 | 321 | ctx = tensor.concatenate([proj[0][-1], projr[0][-1]], axis=1) 322 | 323 | return embedding, x_mask, ctx 324 | 325 | 326 | # some utilities 327 | def ortho_weight(ndim): 328 | W = numpy.random.randn(ndim, ndim) 329 | u, s, v = numpy.linalg.svd(W) 330 | return u.astype('float32') 331 | 332 | 333 | def norm_weight(nin,nout=None, scale=0.1, ortho=True): 334 | if nout == None: 335 | nout = nin 336 | if nout == nin and ortho: 337 | W = 
ortho_weight(nin) 338 | else: 339 | W = numpy.random.uniform(low=-scale, high=scale, size=(nin, nout)) 340 | return W.astype('float32') 341 | 342 | 343 | def param_init_gru(options, params, prefix='gru', nin=None, dim=None): 344 | """ 345 | parameter init for GRU 346 | """ 347 | if nin == None: 348 | nin = options['dim_proj'] 349 | if dim == None: 350 | dim = options['dim_proj'] 351 | W = numpy.concatenate([norm_weight(nin,dim), 352 | norm_weight(nin,dim)], axis=1) 353 | params[_p(prefix,'W')] = W 354 | params[_p(prefix,'b')] = numpy.zeros((2 * dim,)).astype('float32') 355 | U = numpy.concatenate([ortho_weight(dim), 356 | ortho_weight(dim)], axis=1) 357 | params[_p(prefix,'U')] = U 358 | 359 | Wx = norm_weight(nin, dim) 360 | params[_p(prefix,'Wx')] = Wx 361 | Ux = ortho_weight(dim) 362 | params[_p(prefix,'Ux')] = Ux 363 | params[_p(prefix,'bx')] = numpy.zeros((dim,)).astype('float32') 364 | 365 | return params 366 | 367 | 368 | def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs): 369 | """ 370 | Forward pass through GRU layer 371 | """ 372 | nsteps = state_below.shape[0] 373 | if state_below.ndim == 3: 374 | n_samples = state_below.shape[1] 375 | else: 376 | n_samples = 1 377 | 378 | dim = tparams[_p(prefix,'Ux')].shape[1] 379 | 380 | if mask == None: 381 | mask = tensor.alloc(1., state_below.shape[0], 1) 382 | 383 | def _slice(_x, n, dim): 384 | if _x.ndim == 3: 385 | return _x[:, :, n*dim:(n+1)*dim] 386 | return _x[:, n*dim:(n+1)*dim] 387 | 388 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + tparams[_p(prefix, 'b')] 389 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + tparams[_p(prefix, 'bx')] 390 | U = tparams[_p(prefix, 'U')] 391 | Ux = tparams[_p(prefix, 'Ux')] 392 | 393 | def _step_slice(m_, x_, xx_, h_, U, Ux): 394 | preact = tensor.dot(h_, U) 395 | preact += x_ 396 | 397 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim)) 398 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim)) 399 | 400 | preactx = tensor.dot(h_, Ux) 401 | preactx = preactx * r 402 | preactx = preactx + xx_ 403 | 404 | h = tensor.tanh(preactx) 405 | 406 | h = u * h_ + (1. - u) * h 407 | h = m_[:,None] * h + (1. 
- m_)[:,None] * h_ 408 | 409 | return h 410 | 411 | seqs = [mask, state_below_, state_belowx] 412 | _step = _step_slice 413 | 414 | rval, updates = theano.scan(_step, 415 | sequences=seqs, 416 | outputs_info = [tensor.alloc(0., n_samples, dim)], 417 | non_sequences = [tparams[_p(prefix, 'U')], 418 | tparams[_p(prefix, 'Ux')]], 419 | name=_p(prefix, '_layers'), 420 | n_steps=nsteps, 421 | profile=profile, 422 | strict=True) 423 | rval = [rval] 424 | return rval 425 | 426 | -------------------------------------------------------------------------------- /Python 3 Codes/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import model 4 | import argparse 5 | import pickle 6 | from os.path import join 7 | import h5py 8 | from Utils import image_processing 9 | import scipy.misc 10 | import random 11 | import json 12 | import os 13 | import shutil 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--z_dim', type=int, default=100, 18 | help='Noise dimension') 19 | 20 | parser.add_argument('--t_dim', type=int, default=256, 21 | help='Text feature dimension') 22 | 23 | parser.add_argument('--batch_size', type=int, default=64, 24 | help='Batch Size') 25 | 26 | parser.add_argument('--image_size', type=int, default=64, 27 | help='Image Size a, a x a') 28 | 29 | parser.add_argument('--gf_dim', type=int, default=64, 30 | help='Number of conv in the first layer gen.') 31 | 32 | parser.add_argument('--df_dim', type=int, default=64, 33 | help='Number of conv in the first layer discr.') 34 | 35 | parser.add_argument('--gfc_dim', type=int, default=1024, 36 | help='Dimension of gen units for fully connected layer 1024') 37 | 38 | parser.add_argument('--caption_vector_length', type=int, default=2400, 39 | help='Caption Vector Length') 40 | 41 | parser.add_argument('--data_dir', type=str, default="Data", 42 | help='Data Directory') 43 | 44 | parser.add_argument('--learning_rate', type=float, default=0.0002, 45 | help='Learning Rate') 46 | 47 | parser.add_argument('--beta1', type=float, default=0.5, 48 | help='Momentum for Adam Update') 49 | 50 | parser.add_argument('--epochs', type=int, default=600, 51 | help='Max number of epochs') 52 | 53 | parser.add_argument('--save_every', type=int, default=30, 54 | help='Save Model/Samples every x iterations over batches') 55 | 56 | parser.add_argument('--resume_model', type=str, default=None, 57 | help='Pre-Trained Model Path, to resume from') 58 | 59 | parser.add_argument('--data_set', type=str, default="flowers", 60 | help='Data set: MS-COCO, flowers') 61 | 62 | args = parser.parse_args() 63 | model_options = { 64 | 'z_dim' : args.z_dim, 65 | 't_dim' : args.t_dim, 66 | 'batch_size' : args.batch_size, 67 | 'image_size' : args.image_size, 68 | 'gf_dim' : args.gf_dim, 69 | 'df_dim' : args.df_dim, 70 | 'gfc_dim' : args.gfc_dim, 71 | 'caption_vector_length' : args.caption_vector_length 72 | } 73 | 74 | 75 | gan = model.GAN(model_options) 76 | input_tensors, variables, loss, outputs, checks = gan.build_model() 77 | 78 | d_optim = tf.train.AdamOptimizer(args.learning_rate, beta1 = args.beta1).minimize(loss['d_loss'], var_list=variables['d_vars']) 79 | g_optim = tf.train.AdamOptimizer(args.learning_rate, beta1 = args.beta1).minimize(loss['g_loss'], var_list=variables['g_vars']) 80 | 81 | sess = tf.InteractiveSession() 82 | tf.initialize_all_variables().run() 83 | 84 | saver = tf.train.Saver() 85 | if args.resume_model: 86 | saver.restore(sess,
args.resume_model) 87 | 88 | loaded_data = load_training_data(args.data_dir, args.data_set) 89 | 90 | for i in range(args.epochs): 91 | batch_no = 0 92 | while batch_no*args.batch_size < loaded_data['data_length']: 93 | real_images, wrong_images, caption_vectors, z_noise, image_files = get_training_batch(batch_no, args.batch_size, 94 | args.image_size, args.z_dim, args.caption_vector_length, 'train', args.data_dir, args.data_set, loaded_data) 95 | 96 | # DISCR UPDATE 97 | check_ts = [ checks['d_loss1'] , checks['d_loss2'], checks['d_loss3']] 98 | _, d_loss, gen, d1, d2, d3 = sess.run([d_optim, loss['d_loss'], outputs['generator']] + check_ts, 99 | feed_dict = { 100 | input_tensors['t_real_image'] : real_images, 101 | input_tensors['t_wrong_image'] : wrong_images, 102 | input_tensors['t_real_caption'] : caption_vectors, 103 | input_tensors['t_z'] : z_noise, 104 | }) 105 | 106 | print("d1", d1) 107 | print("d2", d2) 108 | print("d3", d3) 109 | print("D", d_loss) 110 | 111 | # GEN UPDATE 112 | _, g_loss, gen = sess.run([g_optim, loss['g_loss'], outputs['generator']], 113 | feed_dict = { 114 | input_tensors['t_real_image'] : real_images, 115 | input_tensors['t_wrong_image'] : wrong_images, 116 | input_tensors['t_real_caption'] : caption_vectors, 117 | input_tensors['t_z'] : z_noise, 118 | }) 119 | 120 | # GEN UPDATE TWICE, to make sure d_loss does not go to 0 121 | _, g_loss, gen = sess.run([g_optim, loss['g_loss'], outputs['generator']], 122 | feed_dict = { 123 | input_tensors['t_real_image'] : real_images, 124 | input_tensors['t_wrong_image'] : wrong_images, 125 | input_tensors['t_real_caption'] : caption_vectors, 126 | input_tensors['t_z'] : z_noise, 127 | }) 128 | 129 | print("LOSSES", d_loss, g_loss, batch_no, i, len(loaded_data['image_list'])/ args.batch_size) 130 | batch_no += 1 131 | if (batch_no % args.save_every) == 0: 132 | print("Saving Images, Model") 133 | save_for_vis(args.data_dir, real_images, gen, image_files) 134 | save_path = saver.save(sess, "Data/Models/latest_model_{}_temp.ckpt".format(args.data_set)) 135 | if i%5 == 0: 136 | save_path = saver.save(sess, "Data/Models/model_after_{}_epoch_{}.ckpt".format(args.data_set, i)) 137 | 138 | def load_training_data(data_dir, data_set): 139 | if data_set == 'flowers': 140 | h = h5py.File(join(data_dir, 'flower_tv.hdf5'), 'r') 141 | flower_captions = {} 142 | for ds in h.items(): 143 | flower_captions[ds[0]] = np.array(ds[1]) 144 | image_list = [key for key in flower_captions] 145 | image_list.sort() 146 | 147 | img_75 = int(len(image_list)*0.75) 148 | training_image_list = image_list[0:img_75] 149 | random.shuffle(training_image_list) 150 | 151 | return { 152 | 'image_list' : training_image_list, 153 | 'captions' : flower_captions, 154 | 'data_length' : len(training_image_list) 155 | } 156 | 157 | else: 158 | with open(join(data_dir, 'meta_train.pkl'), 'rb') as f: 159 | meta_data = pickle.load(f) 160 | # No preloading for MS-COCO 161 | return meta_data 162 | 163 | def save_for_vis(data_dir, real_images, generated_images, image_files): 164 | 165 | shutil.rmtree( join(data_dir, 'samples') ) 166 | os.makedirs( join(data_dir, 'samples') ) 167 | 168 | for i in range(0, real_images.shape[0]): 169 | real_image_255 = np.zeros( (64,64,3), dtype=np.uint8) 170 | real_images_255 = (real_images[i,:,:,:]) 171 | scipy.misc.imsave( join(data_dir, 'samples/{}_{}.jpg'.format(i, image_files[i].split('/')[-1] )) , real_images_255) 172 | 173 | fake_image_255 = np.zeros( (64,64,3), dtype=np.uint8) 174 | fake_images_255 = (generated_images[i,:,:,:]) 175 | scipy.misc.imsave(join(data_dir, 'samples/fake_image_{}.jpg'.format(i)), fake_images_255)
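# Note on the batching below: for MS-COCO the "wrong" images are simply the real images shifted by one position within the batch, while for flowers a random image is drawn, which can occasionally coincide with the matching image (hence the "Improve this selection of wrong image" comment inside get_training_batch).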
176 | 177 | 178 | def get_training_batch(batch_no, batch_size, image_size, z_dim, 179 | caption_vector_length, split, data_dir, data_set, loaded_data = None): 180 | if data_set == 'mscoco': 181 | with h5py.File( join(data_dir, 'tvs/'+split + '_tvs_' + str(batch_no)), 'r') as hf: 182 | caption_vectors = np.array(hf.get('tv')) 183 | caption_vectors = caption_vectors[:,0:caption_vector_length] 184 | with h5py.File( join(data_dir, 'tvs/'+split + '_tv_image_id_' + str(batch_no)), 'r') as hf: 185 | image_ids = np.array(hf.get('tv')) 186 | 187 | real_images = np.zeros((batch_size, 64, 64, 3)) 188 | wrong_images = np.zeros((batch_size, 64, 64, 3)) 189 | 190 | image_files = [] 191 | for idx, image_id in enumerate(image_ids): 192 | image_file = join(data_dir, '%s2014/COCO_%s2014_%.12d.jpg'%(split, split, image_id) ) 193 | image_array = image_processing.load_image_array(image_file, image_size) 194 | real_images[idx,:,:,:] = image_array 195 | image_files.append(image_file) 196 | 197 | # TODO>> As of Now, wrong images are just shuffled real images. 198 | first_image = real_images[0,:,:,:] 199 | for i in range(0, batch_size): 200 | if i < batch_size - 1: 201 | wrong_images[i,:,:,:] = real_images[i+1,:,:,:] 202 | else: 203 | wrong_images[i,:,:,:] = first_image 204 | 205 | z_noise = np.random.uniform(-1, 1, [batch_size, z_dim]) 206 | 207 | 208 | return real_images, wrong_images, caption_vectors, z_noise, image_files 209 | 210 | if data_set == 'flowers': 211 | real_images = np.zeros((batch_size, 64, 64, 3)) 212 | wrong_images = np.zeros((batch_size, 64, 64, 3)) 213 | captions = np.zeros((batch_size, caption_vector_length)) 214 | 215 | cnt = 0 216 | image_files = [] 217 | for i in range(batch_no * batch_size, batch_no * batch_size + batch_size): 218 | idx = i % len(loaded_data['image_list']) 219 | image_file = join(data_dir, 'flowers/jpg/'+loaded_data['image_list'][idx]) 220 | image_array = image_processing.load_image_array(image_file, image_size) 221 | real_images[cnt,:,:,:] = image_array 222 | 223 | # Improve this selection of wrong image 224 | wrong_image_id = random.randint(0,len(loaded_data['image_list'])-1) 225 | wrong_image_file = join(data_dir, 'flowers/jpg/'+loaded_data['image_list'][wrong_image_id]) 226 | wrong_image_array = image_processing.load_image_array(wrong_image_file, image_size) 227 | wrong_images[cnt, :,:,:] = wrong_image_array 228 | 229 | random_caption = random.randint(0,4) 230 | captions[cnt,:] = loaded_data['captions'][ loaded_data['image_list'][idx] ][ random_caption ][0:caption_vector_length] 231 | image_files.append( image_file ) 232 | cnt += 1 233 | 234 | z_noise = np.random.uniform(-1, 1, [batch_size, z_dim]) 235 | return real_images, wrong_images, captions, z_noise, image_files 236 | 237 | if __name__ == '__main__': 238 | main() 239 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text To Image Synthesis Using Thought Vectors 2 | 3 | [![Join the chat at https://gitter.im/text-to-image/Lobby](https://badges.gitter.im/text-to-image/Lobby.svg)](https://gitter.im/text-to-image/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 4 | 5 | This is an experimental tensorflow implementation of synthesizing images from captions using [Skip Thought Vectors][1]. The images are synthesized using the GAN-CLS Algorithm from the paper [Generative Adversarial Text-to-Image Synthesis][2]. This implementation is built on top of the excellent [DCGAN in Tensorflow][3]. The following is the model architecture; the blue bars represent the Skip Thought Vectors for the captions, and a short sketch of the GAN-CLS objective is given below the diagram. 6 | 7 | ![Model architecture](http://i.imgur.com/dNl2HkZ.jpg) 8 | 9 | Image Source : [Generative Adversarial Text-to-Image Synthesis][2] Paper
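To make the objective concrete, the snippet below restates the GAN-CLS losses from ```model.py``` in plain NumPy (a hedged sketch: the function and array names are illustrative, not part of this repository's API; the real losses are built with ```tf.nn.sigmoid_cross_entropy_with_logits```). The discriminator is pushed to accept (real image, matching caption) and to reject both (real image, mismatched caption) and (generated image, matching caption), while the generator tries to get that last pair accepted.

```
import numpy as np

def gan_cls_losses(d_real_right, d_real_wrong, d_fake_right, eps=1e-12):
    # d_* are sigmoid outputs of the discriminator, each in (0, 1)
    d_loss = np.mean(-np.log(d_real_right + eps))        # real image + right text -> 1
    d_loss += np.mean(-np.log(1. - d_real_wrong + eps))  # real image + wrong text -> 0
    d_loss += np.mean(-np.log(1. - d_fake_right + eps))  # fake image + right text -> 0
    g_loss = np.mean(-np.log(d_fake_right + eps))        # generator tries to fool D
    return d_loss, g_loss

# a discriminator that is doing well yields a low d_loss and a high g_loss
print(gan_cls_losses(np.array([0.9]), np.array([0.1]), np.array([0.2])))
```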
10 | 11 | ## Requirements 12 | - Python 2.7.6 13 | - [Tensorflow][4] 14 | - [h5py][5] 15 | - [Theano][6] : for skip thought vectors 16 | - [scikit-learn][7] : for skip thought vectors 17 | - [NLTK][8] : for skip thought vectors 18 | 19 | ## Datasets 20 | - All the steps below for downloading the datasets and models can be performed automatically by running `python download_datasets.py`. Several gigabytes of files will be downloaded and extracted. 21 | - The model is currently trained on the [flowers dataset][9]. Download the images from [this link][9] and save them in ```Data/flowers/jpg```. Also download the captions from [this link][10]. Extract the archive, copy the ```text_c10``` folder and paste it in ```Data/flowers```. 22 | - Download the pretrained models and vocabulary for skip thought vectors as per the instructions given [here][13]. Save the downloaded files in ```Data/skipthoughts```. 23 | - Make empty directories in Data, ```Data/samples```, ```Data/val_samples``` and ```Data/Models```. They will be used for sampling the generated images and saving the trained models. 24 | 25 | ## Usage 26 | - Data Processing : Extract the skip thought vectors for the flowers data set using : 27 | ``` 28 | python data_loader.py --data_set="flowers" 29 | ``` 30 | - Training 31 | * Basic usage `python train.py --data_set="flowers"` 32 | * Options 33 | - `z_dim`: Noise Dimension. Default is 100. 34 | - `t_dim`: Text feature dimension. Default is 256. 35 | - `batch_size`: Batch Size. Default is 64. 36 | - `image_size`: Image dimension. Default is 64. 37 | - `gf_dim`: Number of conv in the first layer generator. Default is 64. 38 | - `df_dim`: Number of conv in the first layer discriminator. Default is 64. 39 | - `gfc_dim`: Dimension of gen units for fully connected layer. Default is 1024. 40 | - `caption_vector_length`: Length of the caption vector. Default is 2400. 41 | - `data_dir`: Data Directory. Default is `Data/`. 42 | - `learning_rate`: Learning Rate. Default is 0.0002. 43 | - `beta1`: Momentum for adam update. Default is 0.5. 44 | - `epochs`: Max number of epochs. Default is 600. 45 | - `resume_model`: Resume training from a pretrained model path. 46 | - `data_set`: Data Set to train on. Default is flowers. 47 | 48 | - Generating Images from Captions 49 | * Write the captions in a text file, and save it as ```Data/sample_captions.txt```. Generate the skip thought vectors for these captions using: 50 | ``` 51 | python generate_thought_vectors.py --caption_file="Data/sample_captions.txt" 52 | ``` 53 | * Generate the Images for the thought vectors using: 54 | ``` 55 | python generate_images.py --model_path=<path to the trained model> --n_images=8 56 | ``` 57 | ```n_images``` specifies the number of images to be generated per caption. The generated images will be saved in ```Data/val_samples/```. Run ```python generate_images.py --help``` for more options; see also the note on ```caption_vector_length``` below. 58 |
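A note on ```caption_vector_length```: ```skipthoughts.encode``` returns combine-skip vectors, i.e. a 2400-dimensional uni-skip part concatenated with a 2400-dimensional bi-skip part. The default of 2400 therefore keeps only the uni-skip half, which is exactly the slice taken in ```train.py``` and ```generate_images.py```. The snippet below is only an illustration of the shapes involved, not code from this repository.

```
import numpy as np

# stand-in for one encoded caption from skipthoughts.encode (combine-skip)
combine_skip = np.zeros(4800, dtype='float32')
uni_skip = combine_skip[0:2400]  # the slice fed to the GAN (caption_vector_length)
bi_skip = combine_skip[2400:]    # unused by the current model
print(uni_skip.shape, bi_skip.shape)
```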
61 | 62 | | Caption | Generated Images | 63 | | ------------- | -----:| 64 | | the flower shown has yellow anther red pistil and bright red petals | ![](http://i.imgur.com/SknZ3Sg.jpg) | 65 | | this flower has petals that are yellow, white and purple and has dark lines | ![](http://i.imgur.com/8zsv9Nc.jpg) | 66 | | the petals on this flower are white with a yellow center | ![](http://i.imgur.com/vvzv1cE.jpg) | 67 | | this flower has a lot of small round pink petals. | ![](http://i.imgur.com/w0zK1DC.jpg) | 68 | | this flower is orange in color, and has petals that are ruffled and rounded. | ![](http://i.imgur.com/VfBbRP1.jpg) | 69 | | the flower has yellow petals and the center of it is brown | ![](http://i.imgur.com/IAuOGZY.jpg) | 70 | 71 | 72 | ## Implementation Details 73 | - Only the uni-skip vectors from the skip thought vectors are used. I have not tried training the model with combine-skip vectors. 74 | - The model was trained for around 200 epochs on a GPU. This took roughly 2-3 days. 75 | - The images generated are 64 x 64 in dimension. 76 | - While processing the batches before training, the images are flipped horizontally with a probability of 0.5. 77 | - The train-val split is 0.75. 78 | 79 | ## Pre-trained Models 80 | - Download the pretrained model from [here][14] and save it in ```Data/Models```. Use this path for generating the images. 81 | 82 | ## TODO 83 | - Train the model on the MS-COCO data set, and generate more generic images. 84 | - Try different embedding options for captions (other than skip thought vectors). Also try to train the caption embedding RNN along with the GAN-CLS model. 85 | 86 | ## References 87 | - [Generative Adversarial Text-to-Image Synthesis][2] Paper 88 | - [Generative Adversarial Text-to-Image Synthesis][11] Code 89 | - [Skip Thought Vectors][1] Paper 90 | - [Skip Thought Vectors][12] Code 91 | - [DCGAN in Tensorflow][3] 92 | - [DCGAN in Tensorlayer][15] 93 | 94 | ## Alternate Implementations 95 | - [Text to Image in Torch by Scott Reed][11] 96 | - [Text to Image in Tensorlayer by Dong Hao][16] 97 | 98 | ## License 99 | MIT 100 | 101 | 102 | [1]:http://arxiv.org/abs/1506.06726 103 | [2]:http://arxiv.org/abs/1605.05396 104 | [3]:https://github.com/carpedm20/DCGAN-tensorflow 105 | [4]:https://github.com/tensorflow/tensorflow 106 | [5]:http://www.h5py.org/ 107 | [6]:https://github.com/Theano/Theano 108 | [7]:http://scikit-learn.org/stable/index.html 109 | [8]:http://www.nltk.org/ 110 | [9]:http://www.robots.ox.ac.uk/~vgg/data/flowers/102/ 111 | [10]:https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view 112 | [11]:https://github.com/reedscot/icml2016 113 | [12]:https://github.com/ryankiros/skip-thoughts 114 | [13]:https://github.com/ryankiros/skip-thoughts#getting-started 115 | [14]:https://bitbucket.org/paarth_neekhara/texttomimagemodel/raw/74a4bbaeee26fe31e148a54c4f495694680e2c31/latest_model_flowers_temp.ckpt 116 | [15]:https://github.com/zsdonghao/dcgan 117 | [16]:https://github.com/zsdonghao/text-to-image 118 | --------------------------------------------------------------------------------
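A minimal sanity check for the caption vectors written by ```generate_thought_vectors.py``` (a sketch, not one of the repository scripts; it assumes the default ```Data/sample_caption_vectors.hdf5``` path and the ```vectors``` dataset name used by that script):
```
import h5py

# defaults written by generate_thought_vectors.py
with h5py.File('Data/sample_caption_vectors.hdf5', 'r') as h:
    vectors = h['vectors'][:]

# skipthoughts.encode returns combine-skip vectors: 2400 uni-skip
# dimensions followed by 2400 bi-skip dimensions, 4800 per caption
print(vectors.shape)  # e.g. (n_captions, 4800)
```
```train.py``` and ```generate_images.py``` keep only the first ```caption_vector_length``` (default 2400) dimensions of each vector, i.e. the uni-skip part, which matches the implementation note above.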
/Utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paarthneekhara/text-to-image/b5475c140ecb254096c493850ae09c1bc9da8f7a/Utils/__init__.py -------------------------------------------------------------------------------- /Utils/image_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import misc 3 | import random 4 | import skimage 5 | import skimage.io 6 | import skimage.transform 7 | 8 | def load_image_array(image_file, image_size): 9 | img = skimage.io.imread(image_file) 10 | # GRAYSCALE 11 | if len(img.shape) == 2: 12 | img_new = np.ndarray( (img.shape[0], img.shape[1], 3), dtype = 'uint8') 13 | img_new[:,:,0] = img 14 | img_new[:,:,1] = img 15 | img_new[:,:,2] = img 16 | img = img_new 17 | 18 | img_resized = skimage.transform.resize(img, (image_size, image_size)) 19 | 20 | # FLIP HORIZONTALLY WITH PROBABILITY 0.5 21 | if random.random() > 0.5: 22 | img_resized = np.fliplr(img_resized) 23 | 24 | 25 | return img_resized.astype('float32') 26 | 27 | if __name__ == '__main__': 28 | # TEST>>> 29 | arr = load_image_array('sample.jpg', 64) 30 | print arr.mean() 31 | # rev = np.fliplr(arr) 32 | misc.imsave( 'rev.jpg', arr) -------------------------------------------------------------------------------- /Utils/ops.py: -------------------------------------------------------------------------------- 1 | # REUSED CODE FROM https://github.com/carpedm20/DCGAN-tensorflow/blob/master/ops.py 2 | import math 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tensorflow.python.framework import ops 7 | 8 | 9 | class batch_norm(object): 10 | """Code modification of http://stackoverflow.com/a/33950177""" 11 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 12 | with tf.variable_scope(name): 13 | self.epsilon = epsilon 14 | self.momentum = momentum 15 | 16 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum) 17 | self.name = name 18 | 19 | def __call__(self, x, train=True): 20 | shape = x.get_shape().as_list() 21 | 22 | if train: 23 | with tf.variable_scope(self.name) as scope: 24 | self.beta = tf.get_variable("beta", [shape[-1]], 25 | initializer=tf.constant_initializer(0.)) 26 | self.gamma = tf.get_variable("gamma", [shape[-1]], 27 | initializer=tf.random_normal_initializer(1., 0.02)) 28 | 29 | try: 30 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments') 31 | except: 32 | batch_mean, batch_var = tf.nn.moments(x, [0, 1], name='moments') 33 | 34 | ema_apply_op = self.ema.apply([batch_mean, batch_var]) 35 | self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var) 36 | 37 | with tf.control_dependencies([ema_apply_op]): 38 | mean, var = tf.identity(batch_mean), tf.identity(batch_var) 39 | else: 40 | mean, var = self.ema_mean, self.ema_var 41 | 42 | normed = tf.nn.batch_norm_with_global_normalization( 43 | x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True) 44 | 45 | return normed 46 | 47 | def binary_cross_entropy(preds, targets, name=None): 48 | """Computes binary cross entropy given `preds`. 49 | 50 | For brevity, let `x = preds`, `z = targets`. The logistic loss is 51 | 52 | loss(x, z) = - sum_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i])) 53 | 54 | Args: 55 | preds: A `Tensor` of type `float32` or `float64`. 56 | targets: A `Tensor` of the same type and shape as `preds`. 57 | """ 58 | eps = 1e-12 59 | with ops.op_scope([preds, targets], name, "bce_loss") as name: 60 | preds = ops.convert_to_tensor(preds, name="preds") 61 | targets = ops.convert_to_tensor(targets, name="targets") 62 | return tf.reduce_mean(-(targets * tf.log(preds + eps) + 63 | (1. - targets) * tf.log(1. - preds + eps)))
64 | 65 | def conv_cond_concat(x, y): 66 | """Concatenate conditioning vector on feature map axis.""" 67 | x_shapes = x.get_shape() 68 | y_shapes = y.get_shape() 69 | return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])]) 70 | 71 | def conv2d(input_, output_dim, 72 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 73 | name="conv2d"): 74 | with tf.variable_scope(name): 75 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 76 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 77 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 78 | 79 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 80 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 81 | 82 | return conv 83 | 84 | def deconv2d(input_, output_shape, 85 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 86 | name="deconv2d", with_w=False): 87 | with tf.variable_scope(name): 88 | # filter : [height, width, output_channels, in_channels] 89 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 90 | initializer=tf.random_normal_initializer(stddev=stddev)) 91 | 92 | try: 93 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 94 | strides=[1, d_h, d_w, 1]) 95 | 96 | # Support for versions of TensorFlow before 0.7.0 97 | except AttributeError: 98 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, 99 | strides=[1, d_h, d_w, 1]) 100 | 101 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 102 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 103 | 104 | if with_w: 105 | return deconv, w, biases 106 | else: 107 | return deconv 108 | 109 | def lrelu(x, leak=0.2, name="lrelu"): 110 | return tf.maximum(x, leak*x) 111 | 112 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 113 | shape = input_.get_shape().as_list() 114 | 115 | with tf.variable_scope(scope or "Linear"): 116 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 117 | tf.random_normal_initializer(stddev=stddev)) 118 | bias = tf.get_variable("bias", [output_size], 119 | initializer=tf.constant_initializer(bias_start)) 120 | if with_w: 121 | return tf.matmul(input_, matrix) + bias, matrix, bias 122 | else: 123 | return tf.matmul(input_, matrix) + bias -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from os.path import join, isfile 4 | import re 5 | import numpy as np 6 | import pickle 7 | import argparse 8 | import skipthoughts 9 | import h5py 10 | 11 | # DID NOT TRAIN IT ON MS COCO YET 12 | def save_caption_vectors_ms_coco(data_dir, split, batch_size): 13 | meta_data = {} 14 | ic_file = join(data_dir, 'annotations/captions_{}2014.json'.format(split)) 15 | with open(ic_file) as f: 16 | ic_data = json.loads(f.read()) 17 | 18 | meta_data['data_length'] = len(ic_data['annotations']) 19 | with open(join(data_dir, 'meta_{}.pkl'.format(split)), 'wb') as f: 20 | pickle.dump(meta_data, f) 21 | 22 | model = skipthoughts.load_model() 23 | batch_no = 0 24 | print "Total Batches", len(ic_data['annotations'])/batch_size 25 | 26 | while batch_no*batch_size < len(ic_data['annotations']): 27 | captions = [] 28 | image_ids = [] 29 | idx = batch_no 30 | for i in range(batch_no*batch_size,
(batch_no+1)*batch_size): 31 | idx = i%len(ic_data['annotations']) 32 | captions.append(ic_data['annotations'][idx]['caption']) 33 | image_ids.append(ic_data['annotations'][idx]['image_id']) 34 | 35 | print captions 36 | print image_ids 37 | # Thought Vectors 38 | tv_batch = skipthoughts.encode(model, captions) 39 | h5f_tv_batch = h5py.File( join(data_dir, 'tvs/'+split + '_tvs_' + str(batch_no)), 'w') 40 | h5f_tv_batch.create_dataset('tv', data=tv_batch) 41 | h5f_tv_batch.close() 42 | 43 | h5f_tv_batch_image_ids = h5py.File( join(data_dir, 'tvs/'+split + '_tv_image_id_' + str(batch_no)), 'w') 44 | h5f_tv_batch_image_ids.create_dataset('tv', data=image_ids) 45 | h5f_tv_batch_image_ids.close() 46 | 47 | print "Batches Done", batch_no, len(ic_data['annotations'])/batch_size 48 | batch_no += 1 49 | 50 | 51 | def save_caption_vectors_flowers(data_dir): 52 | import time 53 | 54 | img_dir = join(data_dir, 'flowers/jpg') 55 | image_files = [f for f in os.listdir(img_dir) if 'jpg' in f] 56 | print image_files[300:400] 57 | print len(image_files) 58 | image_captions = { img_file : [] for img_file in image_files } 59 | 60 | caption_dir = join(data_dir, 'flowers/text_c10') 61 | class_dirs = [] 62 | for i in range(1, 103): 63 | class_dir_name = 'class_%.5d'%(i) 64 | class_dirs.append( join(caption_dir, class_dir_name)) 65 | 66 | for class_dir in class_dirs: 67 | caption_files = [f for f in os.listdir(class_dir) if 'txt' in f] 68 | for cap_file in caption_files: 69 | with open(join(class_dir,cap_file)) as f: 70 | captions = f.read().split('\n') 71 | img_file = cap_file[0:11] + ".jpg" 72 | # 5 captions per image 73 | image_captions[img_file] += [cap for cap in captions if len(cap) > 0][0:5] 74 | 75 | print len(image_captions) 76 | 77 | model = skipthoughts.load_model() 78 | encoded_captions = {} 79 | 80 | 81 | for i, img in enumerate(image_captions): 82 | st = time.time() 83 | encoded_captions[img] = skipthoughts.encode(model, image_captions[img]) 84 | print i, len(image_captions), img 85 | print "Seconds", time.time() - st 86 | 87 | 88 | h = h5py.File(join(data_dir, 'flower_tv.hdf5')) 89 | for key in encoded_captions: 90 | h.create_dataset(key, data=encoded_captions[key]) 91 | h.close() 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--split', type=str, default='train', 96 | help='train/val') 97 | parser.add_argument('--data_dir', type=str, default='Data', 98 | help='Data directory') 99 | parser.add_argument('--batch_size', type=int, default=64, 100 | help='Batch Size') 101 | parser.add_argument('--data_set', type=str, default='flowers', 102 | help='Data Set : Flowers, MS-COCO') 103 | args = parser.parse_args() 104 | 105 | if args.data_set == 'flowers': 106 | save_caption_vectors_flowers(args.data_dir) 107 | else: 108 | save_caption_vectors_ms_coco(args.data_dir, args.split, args.batch_size) 109 | 110 | if __name__ == '__main__': 111 | main() -------------------------------------------------------------------------------- /download_datasets.py: -------------------------------------------------------------------------------- 1 | # downloads/extracts datasets described in the README.md 2 | 3 | import os 4 | import sys 5 | import errno 6 | import tarfile 7 | 8 | if sys.version_info >= (3,): 9 | from urllib.request import urlretrieve 10 | else: 11 | from urllib import urlretrieve 12 | 13 | DATA_DIR = 'Data' 14 | 15 | 16 | # http://stackoverflow.com/questions/273192/how-to-check-if-a-directory-exists-and-create-it-if-necessary 17 | def make_sure_path_exists(path): 18 | 
try: 19 | os.makedirs(path) 20 | except OSError as exception: 21 | if exception.errno != errno.EEXIST: 22 | raise 23 | 24 | 25 | def create_data_paths(): 26 | if not os.path.isdir(DATA_DIR): 27 | raise EnvironmentError('Needs to be run from project directory containing ' + DATA_DIR) 28 | needed_paths = [ 29 | os.path.join(DATA_DIR, 'samples'), 30 | os.path.join(DATA_DIR, 'val_samples'), 31 | os.path.join(DATA_DIR, 'Models'), 32 | ] 33 | for p in needed_paths: 34 | make_sure_path_exists(p) 35 | 36 | 37 | # adapted from http://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python 38 | def dl_progress_hook(count, blockSize, totalSize): 39 | percent = int(count * blockSize * 100 / totalSize) 40 | sys.stdout.write("\r" + "...%d%%" % percent) 41 | sys.stdout.flush() 42 | 43 | 44 | def download_dataset(data_name): 45 | if data_name == 'flowers': 46 | print('== Flowers dataset ==') 47 | flowers_dir = os.path.join(DATA_DIR, 'flowers') 48 | flowers_jpg_tgz = os.path.join(flowers_dir, '102flowers.tgz') 49 | make_sure_path_exists(flowers_dir) 50 | 51 | # the original google drive link at https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view 52 | # from https://github.com/reedscot/icml2016 is problematic to download automatically, so included 53 | # the text_c10 directory from that archive as a bzipped file in the repo 54 | captions_tbz = os.path.join(DATA_DIR, 'flowers_text_c10.tar.bz2') 55 | print('Extracting ' + captions_tbz) 56 | captions_tar = tarfile.open(captions_tbz, 'r:bz2') 57 | captions_tar.extractall(flowers_dir) 58 | 59 | flowers_url = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' 60 | print('Downloading ' + flowers_jpg_tgz + ' from ' + flowers_url) 61 | urlretrieve(flowers_url, flowers_jpg_tgz, 62 | reporthook=dl_progress_hook) 63 | print('Extracting ' + flowers_jpg_tgz) 64 | flowers_jpg_tar = tarfile.open(flowers_jpg_tgz, 'r:gz') 65 | flowers_jpg_tar.extractall(flowers_dir) # archive contains jpg/ folder 66 | 67 | elif data_name == 'skipthoughts': 68 | print('== Skipthoughts models ==') 69 | SKIPTHOUGHTS_DIR = os.path.join(DATA_DIR, 'skipthoughts') 70 | SKIPTHOUGHTS_BASE_URL = 'http://www.cs.toronto.edu/~rkiros/models/' 71 | make_sure_path_exists(SKIPTHOUGHTS_DIR) 72 | 73 | # following https://github.com/ryankiros/skip-thoughts#getting-started 74 | skipthoughts_files = [ 75 | 'dictionary.txt', 'utable.npy', 'btable.npy', 'uni_skip.npz', 'uni_skip.npz.pkl', 'bi_skip.npz', 76 | 'bi_skip.npz.pkl', 77 | ] 78 | for filename in skipthoughts_files: 79 | src_url = SKIPTHOUGHTS_BASE_URL + filename 80 | print('Downloading ' + src_url) 81 | urlretrieve(src_url, os.path.join(SKIPTHOUGHTS_DIR, filename), 82 | reporthook=dl_progress_hook) 83 | 84 | elif data_name == 'nltk_punkt': 85 | import nltk 86 | print('== NLTK pre-trained Punkt tokenizer for English ==') 87 | nltk.download('punkt') 88 | 89 | elif data_name == 'pretrained_model': 90 | print('== Pretrained model ==') 91 | MODEL_DIR = os.path.join(DATA_DIR, 'Models') 92 | pretrained_model_filename = 'latest_model_flowers_temp.ckpt' 93 | src_url = 'https://bitbucket.org/paarth_neekhara/texttomimagemodel/raw/74a4bbaeee26fe31e148a54c4f495694680e2c31/' + pretrained_model_filename 94 | print('Downloading ' + src_url) 95 | urlretrieve( 96 | src_url, 97 | os.path.join(MODEL_DIR, pretrained_model_filename), 98 | reporthook=dl_progress_hook, 99 | ) 100 | 101 | else: 102 | raise ValueError('Unknown dataset name: ' + data_name) 103 | 104 | 105 | def main(): 106 | create_data_paths() 107 | # 
TODO: make configurable via command-line 108 | download_dataset('flowers') 109 | download_dataset('skipthoughts') 110 | download_dataset('nltk_punkt') 111 | download_dataset('pretrained_model') 112 | print('Done') 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /generate_images.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import model 4 | import argparse 5 | import pickle 6 | from os.path import join 7 | import h5py 8 | from Utils import image_processing 9 | import scipy.misc 10 | import random 11 | import json 12 | import os 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--z_dim', type=int, default=100, 18 | help='Noise Dimension') 19 | 20 | parser.add_argument('--t_dim', type=int, default=256, 21 | help='Text feature dimension') 22 | 23 | parser.add_argument('--image_size', type=int, default=64, 24 | help='Image Size') 25 | 26 | parser.add_argument('--gf_dim', type=int, default=64, 27 | help='Number of conv in the first layer gen.') 28 | 29 | parser.add_argument('--df_dim', type=int, default=64, 30 | help='Number of conv in the first layer discr.') 31 | 32 | parser.add_argument('--gfc_dim', type=int, default=1024, 33 | help='Dimension of generator units for the fully connected layer (1024)') 34 | 35 | parser.add_argument('--caption_vector_length', type=int, default=2400, 36 | help='Caption Vector Length') 37 | 38 | parser.add_argument('--data_dir', type=str, default="Data", 39 | help='Data Directory') 40 | 41 | parser.add_argument('--model_path', type=str, default='Data/Models/latest_model_flowers_temp.ckpt', 42 | help='Trained Model Path') 43 | 44 | parser.add_argument('--n_images', type=int, default=5, 45 | help='Number of Images per Caption') 46 | 47 | parser.add_argument('--caption_thought_vectors', type=str, default='Data/sample_caption_vectors.hdf5', 48 | help='Caption Thought Vector File') 49 | 50 | 51 | args = parser.parse_args() 52 | model_options = { 53 | 'z_dim' : args.z_dim, 54 | 't_dim' : args.t_dim, 55 | 'batch_size' : args.n_images, 56 | 'image_size' : args.image_size, 57 | 'gf_dim' : args.gf_dim, 58 | 'df_dim' : args.df_dim, 59 | 'gfc_dim' : args.gfc_dim, 60 | 'caption_vector_length' : args.caption_vector_length 61 | } 62 | 63 | gan = model.GAN(model_options) 64 | _, _, _, _, _ = gan.build_model() 65 | sess = tf.InteractiveSession() 66 | saver = tf.train.Saver() 67 | saver.restore(sess, args.model_path) 68 | 69 | input_tensors, outputs = gan.build_generator() 70 | 71 | h = h5py.File( args.caption_thought_vectors ) 72 | caption_vectors = np.array(h['vectors']) 73 | caption_image_dic = {} 74 | for cn, caption_vector in enumerate(caption_vectors): 75 | 76 | caption_images = [] 77 | z_noise = np.random.uniform(-1, 1, [args.n_images, args.z_dim]) 78 | caption = [ caption_vector[0:args.caption_vector_length] ] * args.n_images 79 | 80 | [ gen_image ] = sess.run( [ outputs['generator'] ], 81 | feed_dict = { 82 | input_tensors['t_real_caption'] : caption, 83 | input_tensors['t_z'] : z_noise, 84 | } ) 85 | 86 | caption_images = [gen_image[i,:,:,:] for i in range(0, args.n_images)] 87 | caption_image_dic[ cn ] = caption_images 88 | print "Generated", cn 89 | 90 | for f in os.listdir( join(args.data_dir, 'val_samples')): 91 | if os.path.isfile(join(args.data_dir, 'val_samples', f)): 92 | os.unlink(join(args.data_dir, 'val_samples/' + f)) 93 |
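# The loop below stitches the n_images samples for each caption into one
# row image: generated images alternate with 5-pixel black spacers, and the
# result is written to val_samples/combined_image_{cn}.jpg.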
94 | for cn in range(0, len(caption_vectors)): 95 | caption_images = [] 96 | for i, im in enumerate( caption_image_dic[ cn ] ): 97 | # im_name = "caption_{}_{}.jpg".format(cn, i) 98 | # scipy.misc.imsave( join(args.data_dir, 'val_samples/{}'.format(im_name)) , im) 99 | caption_images.append( im ) 100 | caption_images.append( np.zeros((64, 5, 3)) ) 101 | combined_image = np.concatenate( caption_images[0:-1], axis = 1 ) 102 | scipy.misc.imsave( join(args.data_dir, 'val_samples/combined_image_{}.jpg'.format(cn)) , combined_image) 103 | 104 | 105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- /generate_thought_vectors.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, isfile 3 | import re 4 | import numpy as np 5 | import pickle 6 | import argparse 7 | import skipthoughts 8 | import h5py 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--caption_file', type=str, default='Data/sample_captions.txt', 13 | help='caption file') 14 | parser.add_argument('--data_dir', type=str, default='Data', 15 | help='Data Directory') 16 | 17 | args = parser.parse_args() 18 | with open( args.caption_file ) as f: 19 | captions = f.read().split('\n') 20 | 21 | captions = [cap for cap in captions if len(cap) > 0] 22 | print captions 23 | model = skipthoughts.load_model() 24 | caption_vectors = skipthoughts.encode(model, captions) 25 | 26 | if os.path.isfile(join(args.data_dir, 'sample_caption_vectors.hdf5')): 27 | os.remove(join(args.data_dir, 'sample_caption_vectors.hdf5')) 28 | h = h5py.File(join(args.data_dir, 'sample_caption_vectors.hdf5')) 29 | h.create_dataset('vectors', data=caption_vectors) 30 | h.close() 31 | 32 | if __name__ == '__main__': 33 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from Utils import ops 3 | 4 | class GAN: 5 | ''' 6 | OPTIONS 7 | z_dim : Noise dimension 100 8 | t_dim : Text feature dimension 256 9 | image_size : Image Dimension 64 10 | gf_dim : Number of conv filters in the first layer of the generator 64 11 | df_dim : Number of conv filters in the first layer of the discriminator 64 12 | gfc_dim : Dimension of generator units for the fully connected layer 1024 13 | caption_vector_length : Caption Vector Length 2400 14 | batch_size : Batch Size 64 15 | ''' 16 | def __init__(self, options): 17 | self.options = options 18 | 19 | self.g_bn0 = ops.batch_norm(name='g_bn0') 20 | self.g_bn1 = ops.batch_norm(name='g_bn1') 21 | self.g_bn2 = ops.batch_norm(name='g_bn2') 22 | self.g_bn3 = ops.batch_norm(name='g_bn3') 23 | 24 | self.d_bn1 = ops.batch_norm(name='d_bn1') 25 | self.d_bn2 = ops.batch_norm(name='d_bn2') 26 | self.d_bn3 = ops.batch_norm(name='d_bn3') 27 | self.d_bn4 = ops.batch_norm(name='d_bn4') 28 | 29 | 30 | def build_model(self): 31 | img_size = self.options['image_size'] 32 | t_real_image = tf.placeholder('float32', [self.options['batch_size'],img_size, img_size, 3 ], name = 'real_image') 33 | t_wrong_image = tf.placeholder('float32', [self.options['batch_size'],img_size, img_size, 3 ], name = 'wrong_image') 34 | t_real_caption = tf.placeholder('float32', [self.options['batch_size'], self.options['caption_vector_length']], name = 'real_caption_input') 35 | t_z = tf.placeholder('float32', [self.options['batch_size'], self.options['z_dim']]) 36 | 37 | fake_image = self.generator(t_z, t_real_caption) 38 |
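# GAN-CLS (from the Reed et al. paper referenced in the README): the same
# discriminator, with shared weights (reuse = True below), scores three
# image/text pairs:
#   (real image,  matching caption) -> trained towards "real"
#   (wrong image, matching caption) -> trained towards "fake"
#   (fake image,  matching caption) -> trained towards "fake"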
39 | disc_real_image, disc_real_image_logits = self.discriminator(t_real_image, t_real_caption) 40 | disc_wrong_image, disc_wrong_image_logits = self.discriminator(t_wrong_image, t_real_caption, reuse = True) 41 | disc_fake_image, disc_fake_image_logits = self.discriminator(fake_image, t_real_caption, reuse = True) 42 | 43 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_image_logits, tf.ones_like(disc_fake_image))) 44 | 45 | d_loss1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_real_image_logits, tf.ones_like(disc_real_image))) 46 | d_loss2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_wrong_image_logits, tf.zeros_like(disc_wrong_image))) 47 | d_loss3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_image_logits, tf.zeros_like(disc_fake_image))) 48 | 49 | d_loss = d_loss1 + d_loss2 + d_loss3 50 | 51 | t_vars = tf.trainable_variables() 52 | d_vars = [var for var in t_vars if 'd_' in var.name] 53 | g_vars = [var for var in t_vars if 'g_' in var.name] 54 | 55 | input_tensors = { 56 | 't_real_image' : t_real_image, 57 | 't_wrong_image' : t_wrong_image, 58 | 't_real_caption' : t_real_caption, 59 | 't_z' : t_z 60 | } 61 | 62 | variables = { 63 | 'd_vars' : d_vars, 64 | 'g_vars' : g_vars 65 | } 66 | 67 | loss = { 68 | 'g_loss' : g_loss, 69 | 'd_loss' : d_loss 70 | } 71 | 72 | outputs = { 73 | 'generator' : fake_image 74 | } 75 | 76 | checks = { 77 | 'd_loss1': d_loss1, 78 | 'd_loss2': d_loss2, 79 | 'd_loss3' : d_loss3, 80 | 'disc_real_image_logits' : disc_real_image_logits, 81 | 'disc_wrong_image_logits' : disc_wrong_image_logits, 82 | 'disc_fake_image_logits' : disc_fake_image_logits 83 | } 84 | 85 | return input_tensors, variables, loss, outputs, checks 86 | 87 | def build_generator(self): 88 | img_size = self.options['image_size'] 89 | t_real_caption = tf.placeholder('float32', [self.options['batch_size'], self.options['caption_vector_length']], name = 'real_caption_input') 90 | t_z = tf.placeholder('float32', [self.options['batch_size'], self.options['z_dim']]) 91 | fake_image = self.sampler(t_z, t_real_caption) 92 | 93 | input_tensors = { 94 | 't_real_caption' : t_real_caption, 95 | 't_z' : t_z 96 | } 97 | 98 | outputs = { 99 | 'generator' : fake_image 100 | } 101 | 102 | return input_tensors, outputs 103 | 104 | # Sample Images for a text embedding 105 | def sampler(self, t_z, t_text_embedding): 106 | tf.get_variable_scope().reuse_variables() 107 | 108 | s = self.options['image_size'] 109 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) 110 | 111 | reduced_text_embedding = ops.lrelu( ops.linear(t_text_embedding, self.options['t_dim'], 'g_embedding') ) 112 | z_concat = tf.concat(1, [t_z, reduced_text_embedding]) 113 | z_ = ops.linear(z_concat, self.options['gf_dim']*8*s16*s16, 'g_h0_lin') 114 | h0 = tf.reshape(z_, [-1, s16, s16, self.options['gf_dim'] * 8]) 115 | h0 = tf.nn.relu(self.g_bn0(h0, train = False)) 116 | 117 | h1 = ops.deconv2d(h0, [self.options['batch_size'], s8, s8, self.options['gf_dim']*4], name='g_h1') 118 | h1 = tf.nn.relu(self.g_bn1(h1, train = False)) 119 | 120 | h2 = ops.deconv2d(h1, [self.options['batch_size'], s4, s4, self.options['gf_dim']*2], name='g_h2') 121 | h2 = tf.nn.relu(self.g_bn2(h2, train = False)) 122 | 123 | h3 = ops.deconv2d(h2, [self.options['batch_size'], s2, s2, self.options['gf_dim']*1], name='g_h3') 124 | h3 = tf.nn.relu(self.g_bn3(h3, train = False)) 125 | 126 | h4 = ops.deconv2d(h3, [self.options['batch_size'], s, s, 3], name='g_h4') 127 | 128 | return (tf.tanh(h4)/2. + 0.5) 129 |
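# generator() below is the training-time twin of sampler() above: the same
# architecture and variable names, but batch norm runs in training mode
# (train defaults to True), whereas sampler() reuses the trained variables
# with train = False for inference.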
130 | # GENERATOR IMPLEMENTATION based on : https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py 131 | def generator(self, t_z, t_text_embedding): 132 | 133 | s = self.options['image_size'] 134 | s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) 135 | 136 | reduced_text_embedding = ops.lrelu( ops.linear(t_text_embedding, self.options['t_dim'], 'g_embedding') ) 137 | z_concat = tf.concat(1, [t_z, reduced_text_embedding]) 138 | z_ = ops.linear(z_concat, self.options['gf_dim']*8*s16*s16, 'g_h0_lin') 139 | h0 = tf.reshape(z_, [-1, s16, s16, self.options['gf_dim'] * 8]) 140 | h0 = tf.nn.relu(self.g_bn0(h0)) 141 | 142 | h1 = ops.deconv2d(h0, [self.options['batch_size'], s8, s8, self.options['gf_dim']*4], name='g_h1') 143 | h1 = tf.nn.relu(self.g_bn1(h1)) 144 | 145 | h2 = ops.deconv2d(h1, [self.options['batch_size'], s4, s4, self.options['gf_dim']*2], name='g_h2') 146 | h2 = tf.nn.relu(self.g_bn2(h2)) 147 | 148 | h3 = ops.deconv2d(h2, [self.options['batch_size'], s2, s2, self.options['gf_dim']*1], name='g_h3') 149 | h3 = tf.nn.relu(self.g_bn3(h3)) 150 | 151 | h4 = ops.deconv2d(h3, [self.options['batch_size'], s, s, 3], name='g_h4') 152 | 153 | return (tf.tanh(h4)/2. + 0.5) 154 | 155 | # DISCRIMINATOR IMPLEMENTATION based on : https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py 156 | def discriminator(self, image, t_text_embedding, reuse=False): 157 | if reuse: 158 | tf.get_variable_scope().reuse_variables() 159 | 160 | h0 = ops.lrelu(ops.conv2d(image, self.options['df_dim'], name = 'd_h0_conv')) #32 161 | h1 = ops.lrelu( self.d_bn1(ops.conv2d(h0, self.options['df_dim']*2, name = 'd_h1_conv'))) #16 162 | h2 = ops.lrelu( self.d_bn2(ops.conv2d(h1, self.options['df_dim']*4, name = 'd_h2_conv'))) #8 163 | h3 = ops.lrelu( self.d_bn3(ops.conv2d(h2, self.options['df_dim']*8, name = 'd_h3_conv'))) #4 164 | 165 | # ADD TEXT EMBEDDING TO THE NETWORK 166 | reduced_text_embeddings = ops.lrelu(ops.linear(t_text_embedding, self.options['t_dim'], 'd_embedding')) 167 | reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,1) 168 | reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,2) 169 | tiled_embeddings = tf.tile(reduced_text_embeddings, [1,4,4,1], name='tiled_embeddings') 170 | 171 | h3_concat = tf.concat( 3, [h3, tiled_embeddings], name='h3_concat') 172 | h3_new = ops.lrelu( self.d_bn4(ops.conv2d(h3_concat, self.options['df_dim']*8, 1,1,1,1, name = 'd_h3_conv_new'))) #4 173 | 174 | h4 = ops.linear(tf.reshape(h3_new, [self.options['batch_size'], -1]), 1, 'd_h3_lin') 175 | 176 | return tf.nn.sigmoid(h4), h4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.6.0 2 | nltk==3.2.1 3 | numpy==1.11.2 4 | scipy==0.16.0 5 | scikit_image==0.12.3 6 | tensorflow==0.11.0rc0 7 | Theano==0.8.2 8 | -------------------------------------------------------------------------------- /skipthoughts.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Skip-thought vectors 3 | https://github.com/ryankiros/skip-thoughts 4 | ''' 5 | import os 6 | import warnings  # used by load_params below 7 | import theano 8 | import theano.tensor as tensor 9 | 10 | import cPickle as pkl 11 | import numpy 12 | import copy 13 | import nltk 14 | 15 | from collections import OrderedDict, defaultdict 16 | from scipy.linalg import norm 17 | from nltk.tokenize import word_tokenize 18 | 19 | profile = False 20 | 21 |
#-----------------------------------------------------------------------------# 22 | # Specify model and table locations here 23 | #-----------------------------------------------------------------------------# 24 | path_to_models = 'Data/skipthoughts/' 25 | path_to_tables = 'Data/skipthoughts/' 26 | #-----------------------------------------------------------------------------# 27 | 28 | path_to_umodel = path_to_models + 'uni_skip.npz' 29 | path_to_bmodel = path_to_models + 'bi_skip.npz' 30 | 31 | 32 | def load_model(): 33 | """ 34 | Load the model with saved tables 35 | """ 36 | # Load model options 37 | print 'Loading model parameters...' 38 | with open('%s.pkl'%path_to_umodel, 'rb') as f: 39 | uoptions = pkl.load(f) 40 | with open('%s.pkl'%path_to_bmodel, 'rb') as f: 41 | boptions = pkl.load(f) 42 | 43 | # Load parameters 44 | uparams = init_params(uoptions) 45 | uparams = load_params(path_to_umodel, uparams) 46 | utparams = init_tparams(uparams) 47 | bparams = init_params_bi(boptions) 48 | bparams = load_params(path_to_bmodel, bparams) 49 | btparams = init_tparams(bparams) 50 | 51 | # Extractor functions 52 | print 'Compiling encoders...' 53 | embedding, x_mask, ctxw2v = build_encoder(utparams, uoptions) 54 | f_w2v = theano.function([embedding, x_mask], ctxw2v, name='f_w2v') 55 | embedding, x_mask, ctxw2v = build_encoder_bi(btparams, boptions) 56 | f_w2v2 = theano.function([embedding, x_mask], ctxw2v, name='f_w2v2') 57 | 58 | # Tables 59 | print 'Loading tables...' 60 | utable, btable = load_tables() 61 | 62 | # Store everything we need in a dictionary 63 | print 'Packing up...' 64 | model = {} 65 | model['uoptions'] = uoptions 66 | model['boptions'] = boptions 67 | model['utable'] = utable 68 | model['btable'] = btable 69 | model['f_w2v'] = f_w2v 70 | model['f_w2v2'] = f_w2v2 71 | 72 | return model 73 | 74 | 75 | def load_tables(): 76 | """ 77 | Load the tables 78 | """ 79 | words = [] 80 | utable = numpy.load(path_to_tables + 'utable.npy') 81 | btable = numpy.load(path_to_tables + 'btable.npy') 82 | f = open(path_to_tables + 'dictionary.txt', 'rb') 83 | for line in f: 84 | words.append(line.decode('utf-8').strip()) 85 | f.close() 86 | utable = OrderedDict(zip(words, utable)) 87 | btable = OrderedDict(zip(words, btable)) 88 | return utable, btable 89 | 90 | 91 | def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False): 92 | """ 93 | Encode sentences in the list X. Each entry will return a vector 94 | """ 95 | # first, do preprocessing 96 | X = preprocess(X) 97 | 98 | # word dictionary and init 99 | d = defaultdict(lambda : 0) 100 | for w in model['utable'].keys(): 101 | d[w] = 1 102 | ufeatures = numpy.zeros((len(X), model['uoptions']['dim']), dtype='float32') 103 | bfeatures = numpy.zeros((len(X), 2 * model['boptions']['dim']), dtype='float32') 104 | 105 | # length dictionary 106 | ds = defaultdict(list) 107 | captions = [s.split() for s in X] 108 | for i,s in enumerate(captions): 109 | ds[len(s)].append(i) 110 | 111 | # Get features. 
This encodes by length, in order to avoid wasting computation 112 | for k in ds.keys(): 113 | if verbose: 114 | print k 115 | numbatches = len(ds[k]) / batch_size + 1 116 | for minibatch in range(numbatches): 117 | caps = ds[k][minibatch::numbatches] 118 | 119 | if use_eos: 120 | uembedding = numpy.zeros((k+1, len(caps), model['uoptions']['dim_word']), dtype='float32') 121 | bembedding = numpy.zeros((k+1, len(caps), model['boptions']['dim_word']), dtype='float32') 122 | else: 123 | uembedding = numpy.zeros((k, len(caps), model['uoptions']['dim_word']), dtype='float32') 124 | bembedding = numpy.zeros((k, len(caps), model['boptions']['dim_word']), dtype='float32') 125 | for ind, c in enumerate(caps): 126 | caption = captions[c] 127 | for j in range(len(caption)): 128 | if d[caption[j]] > 0: 129 | uembedding[j,ind] = model['utable'][caption[j]] 130 | bembedding[j,ind] = model['btable'][caption[j]] 131 | else: 132 | uembedding[j,ind] = model['utable']['UNK'] 133 | bembedding[j,ind] = model['btable']['UNK'] 134 | if use_eos: 135 | uembedding[-1,ind] = model['utable']['<eos>'] 136 | bembedding[-1,ind] = model['btable']['<eos>'] 137 | if use_eos: 138 | uff = model['f_w2v'](uembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32')) 139 | bff = model['f_w2v2'](bembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32')) 140 | else: 141 | uff = model['f_w2v'](uembedding, numpy.ones((len(caption),len(caps)), dtype='float32')) 142 | bff = model['f_w2v2'](bembedding, numpy.ones((len(caption),len(caps)), dtype='float32')) 143 | if use_norm: 144 | for j in range(len(uff)): 145 | uff[j] /= norm(uff[j]) 146 | bff[j] /= norm(bff[j]) 147 | for ind, c in enumerate(caps): 148 | ufeatures[c] = uff[ind] 149 | bfeatures[c] = bff[ind] 150 | 151 | features = numpy.c_[ufeatures, bfeatures] 152 | return features 153 | 154 | 155 | def preprocess(text): 156 | """ 157 | Preprocess text for encoder 158 | """ 159 | X = [] 160 | sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') 161 | for t in text: 162 | sents = sent_detector.tokenize(t) 163 | result = '' 164 | for s in sents: 165 | tokens = word_tokenize(s) 166 | result += ' ' + ' '.join(tokens) 167 | X.append(result) 168 | return X 169 | 170 | 171 | def nn(model, text, vectors, query, k=5): 172 | """ 173 | Return the nearest neighbour sentences to query 174 | text: list of sentences 175 | vectors: the corresponding representations for text 176 | query: a string to search 177 | """ 178 | qf = encode(model, [query]) 179 | qf /= norm(qf) 180 | scores = numpy.dot(qf, vectors.T).flatten() 181 | sorted_args = numpy.argsort(scores)[::-1] 182 | sentences = [text[a] for a in sorted_args[:k]] 183 | print 'QUERY: ' + query 184 | print 'NEAREST: ' 185 | for i, s in enumerate(sentences): 186 | print s, sorted_args[i] 187 | 188 | 189 | def word_features(table): 190 | """ 191 | Extract word features into a normalized matrix 192 | """ 193 | features = numpy.zeros((len(table), 620), dtype='float32') 194 | keys = table.keys() 195 | for i in range(len(table)): 196 | f = table[keys[i]] 197 | features[i] = f / norm(f) 198 | return features 199 | 200 | 201 | def nn_words(table, wordvecs, query, k=10): 202 | """ 203 | Get the nearest neighbour words 204 | """ 205 | keys = table.keys() 206 | qf = table[query] 207 | scores = numpy.dot(qf, wordvecs.T).flatten() 208 | sorted_args = numpy.argsort(scores)[::-1] 209 | words = [keys[a] for a in sorted_args[:k]] 210 | print 'QUERY: ' + query 211 | print 'NEAREST: ' 212 | for i, w in enumerate(words): 213 | print w 214 | 215 |
216 | def _p(pp, name): 217 | """ 218 | make prefix-appended name 219 | """ 220 | return '%s_%s'%(pp, name) 221 | 222 | 223 | def init_tparams(params): 224 | """ 225 | initialize Theano shared variables according to the initial parameters 226 | """ 227 | tparams = OrderedDict() 228 | for kk, pp in params.iteritems(): 229 | tparams[kk] = theano.shared(params[kk], name=kk) 230 | return tparams 231 | 232 | 233 | def load_params(path, params): 234 | """ 235 | load parameters 236 | """ 237 | pp = numpy.load(path) 238 | for kk, vv in params.iteritems(): 239 | if kk not in pp: 240 | warnings.warn('%s is not in the archive'%kk) 241 | continue 242 | params[kk] = pp[kk] 243 | return params 244 | 245 | 246 | # layers: 'name': ('parameter initializer', 'feedforward') 247 | layers = {'gru': ('param_init_gru', 'gru_layer')} 248 | 249 | def get_layer(name): 250 | fns = layers[name] 251 | return (eval(fns[0]), eval(fns[1])) 252 | 253 | 254 | def init_params(options): 255 | """ 256 | initialize all parameters needed for the encoder 257 | """ 258 | params = OrderedDict() 259 | 260 | # embedding 261 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 262 | 263 | # encoder: GRU 264 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder', 265 | nin=options['dim_word'], dim=options['dim']) 266 | return params 267 | 268 | 269 | def init_params_bi(options): 270 | """ 271 | initialize all paramters needed for bidirectional encoder 272 | """ 273 | params = OrderedDict() 274 | 275 | # embedding 276 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 277 | 278 | # encoder: GRU 279 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder', 280 | nin=options['dim_word'], dim=options['dim']) 281 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder_r', 282 | nin=options['dim_word'], dim=options['dim']) 283 | return params 284 | 285 | 286 | def build_encoder(tparams, options): 287 | """ 288 | build an encoder, given pre-computed word embeddings 289 | """ 290 | # word embedding (source) 291 | embedding = tensor.tensor3('embedding', dtype='float32') 292 | x_mask = tensor.matrix('x_mask', dtype='float32') 293 | 294 | # encoder 295 | proj = get_layer(options['encoder'])[1](tparams, embedding, options, 296 | prefix='encoder', 297 | mask=x_mask) 298 | ctx = proj[0][-1] 299 | 300 | return embedding, x_mask, ctx 301 | 302 | 303 | def build_encoder_bi(tparams, options): 304 | """ 305 | build bidirectional encoder, given pre-computed word embeddings 306 | """ 307 | # word embedding (source) 308 | embedding = tensor.tensor3('embedding', dtype='float32') 309 | embeddingr = embedding[::-1] 310 | x_mask = tensor.matrix('x_mask', dtype='float32') 311 | xr_mask = x_mask[::-1] 312 | 313 | # encoder 314 | proj = get_layer(options['encoder'])[1](tparams, embedding, options, 315 | prefix='encoder', 316 | mask=x_mask) 317 | projr = get_layer(options['encoder'])[1](tparams, embeddingr, options, 318 | prefix='encoder_r', 319 | mask=xr_mask) 320 | 321 | ctx = tensor.concatenate([proj[0][-1], projr[0][-1]], axis=1) 322 | 323 | return embedding, x_mask, ctx 324 | 325 | 326 | # some utilities 327 | def ortho_weight(ndim): 328 | W = numpy.random.randn(ndim, ndim) 329 | u, s, v = numpy.linalg.svd(W) 330 | return u.astype('float32') 331 | 332 | 333 | def norm_weight(nin,nout=None, scale=0.1, ortho=True): 334 | if nout == None: 335 | nout = nin 336 | if nout == nin and ortho: 337 | W = ortho_weight(nin) 338 | else: 339 | W = 
numpy.random.uniform(low=-scale, high=scale, size=(nin, nout)) 340 | return W.astype('float32') 341 | 342 | 343 | def param_init_gru(options, params, prefix='gru', nin=None, dim=None): 344 | """ 345 | parameter init for GRU 346 | """ 347 | if nin == None: 348 | nin = options['dim_proj'] 349 | if dim == None: 350 | dim = options['dim_proj'] 351 | W = numpy.concatenate([norm_weight(nin,dim), 352 | norm_weight(nin,dim)], axis=1) 353 | params[_p(prefix,'W')] = W 354 | params[_p(prefix,'b')] = numpy.zeros((2 * dim,)).astype('float32') 355 | U = numpy.concatenate([ortho_weight(dim), 356 | ortho_weight(dim)], axis=1) 357 | params[_p(prefix,'U')] = U 358 | 359 | Wx = norm_weight(nin, dim) 360 | params[_p(prefix,'Wx')] = Wx 361 | Ux = ortho_weight(dim) 362 | params[_p(prefix,'Ux')] = Ux 363 | params[_p(prefix,'bx')] = numpy.zeros((dim,)).astype('float32') 364 | 365 | return params 366 | 367 | 368 | def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs): 369 | """ 370 | Forward pass through GRU layer 371 | """ 372 | nsteps = state_below.shape[0] 373 | if state_below.ndim == 3: 374 | n_samples = state_below.shape[1] 375 | else: 376 | n_samples = 1 377 | 378 | dim = tparams[_p(prefix,'Ux')].shape[1] 379 | 380 | if mask == None: 381 | mask = tensor.alloc(1., state_below.shape[0], 1) 382 | 383 | def _slice(_x, n, dim): 384 | if _x.ndim == 3: 385 | return _x[:, :, n*dim:(n+1)*dim] 386 | return _x[:, n*dim:(n+1)*dim] 387 | 388 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + tparams[_p(prefix, 'b')] 389 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + tparams[_p(prefix, 'bx')] 390 | U = tparams[_p(prefix, 'U')] 391 | Ux = tparams[_p(prefix, 'Ux')] 392 | 393 | def _step_slice(m_, x_, xx_, h_, U, Ux): 394 | preact = tensor.dot(h_, U) 395 | preact += x_ 396 | 397 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim)) 398 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim)) 399 | 400 | preactx = tensor.dot(h_, Ux) 401 | preactx = preactx * r 402 | preactx = preactx + xx_ 403 | 404 | h = tensor.tanh(preactx) 405 | 406 | h = u * h_ + (1. - u) * h 407 | h = m_[:,None] * h + (1. 
- m_)[:,None] * h_ 408 | 409 | return h 410 | 411 | seqs = [mask, state_below_, state_belowx] 412 | _step = _step_slice 413 | 414 | rval, updates = theano.scan(_step, 415 | sequences=seqs, 416 | outputs_info = [tensor.alloc(0., n_samples, dim)], 417 | non_sequences = [tparams[_p(prefix, 'U')], 418 | tparams[_p(prefix, 'Ux')]], 419 | name=_p(prefix, '_layers'), 420 | n_steps=nsteps, 421 | profile=profile, 422 | strict=True) 423 | rval = [rval] 424 | return rval 425 | 426 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import model 4 | import argparse 5 | import pickle 6 | from os.path import join 7 | import h5py 8 | from Utils import image_processing 9 | import scipy.misc 10 | import random 11 | import json 12 | import os 13 | import shutil 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--z_dim', type=int, default=100, 18 | help='Noise dimension') 19 | 20 | parser.add_argument('--t_dim', type=int, default=256, 21 | help='Text feature dimension') 22 | 23 | parser.add_argument('--batch_size', type=int, default=64, 24 | help='Batch Size') 25 | 26 | parser.add_argument('--image_size', type=int, default=64, 27 | help='Image Size a, a x a') 28 | 29 | parser.add_argument('--gf_dim', type=int, default=64, 30 | help='Number of conv in the first layer gen.') 31 | 32 | parser.add_argument('--df_dim', type=int, default=64, 33 | help='Number of conv in the first layer discr.') 34 | 35 | parser.add_argument('--gfc_dim', type=int, default=1024, 36 | help='Dimension of generator units for the fully connected layer (1024)') 37 | 38 | parser.add_argument('--caption_vector_length', type=int, default=2400, 39 | help='Caption Vector Length') 40 | 41 | parser.add_argument('--data_dir', type=str, default="Data", 42 | help='Data Directory') 43 | 44 | parser.add_argument('--learning_rate', type=float, default=0.0002, 45 | help='Learning Rate') 46 | 47 | parser.add_argument('--beta1', type=float, default=0.5, 48 | help='Momentum for Adam Update') 49 | 50 | parser.add_argument('--epochs', type=int, default=600, 51 | help='Max number of epochs') 52 | 53 | parser.add_argument('--save_every', type=int, default=30, 54 | help='Save Model/Samples every x iterations over batches') 55 | 56 | parser.add_argument('--resume_model', type=str, default=None, 57 | help='Pre-Trained Model Path, to resume from') 58 | 59 | parser.add_argument('--data_set', type=str, default="flowers", 60 | help='Data set: MS-COCO, flowers') 61 | 62 | args = parser.parse_args() 63 | model_options = { 64 | 'z_dim' : args.z_dim, 65 | 't_dim' : args.t_dim, 66 | 'batch_size' : args.batch_size, 67 | 'image_size' : args.image_size, 68 | 'gf_dim' : args.gf_dim, 69 | 'df_dim' : args.df_dim, 70 | 'gfc_dim' : args.gfc_dim, 71 | 'caption_vector_length' : args.caption_vector_length 72 | } 73 | 74 | 75 | gan = model.GAN(model_options) 76 | input_tensors, variables, loss, outputs, checks = gan.build_model() 77 | 78 | d_optim = tf.train.AdamOptimizer(args.learning_rate, beta1 = args.beta1).minimize(loss['d_loss'], var_list=variables['d_vars']) 79 | g_optim = tf.train.AdamOptimizer(args.learning_rate, beta1 = args.beta1).minimize(loss['g_loss'], var_list=variables['g_vars']) 80 | 81 | sess = tf.InteractiveSession() 82 | tf.initialize_all_variables().run() 83 | 84 | saver = tf.train.Saver() 85 | if args.resume_model: 86 | saver.restore(sess, 
args.resume_model) 87 | 88 | loaded_data = load_training_data(args.data_dir, args.data_set) 89 | 90 | for i in range(args.epochs): 91 | batch_no = 0 92 | while batch_no*args.batch_size < loaded_data['data_length']: 93 | real_images, wrong_images, caption_vectors, z_noise, image_files = get_training_batch(batch_no, args.batch_size, 94 | args.image_size, args.z_dim, args.caption_vector_length, 'train', args.data_dir, args.data_set, loaded_data) 95 | 96 | # DISCR UPDATE 97 | check_ts = [ checks['d_loss1'] , checks['d_loss2'], checks['d_loss3']] 98 | _, d_loss, gen, d1, d2, d3 = sess.run([d_optim, loss['d_loss'], outputs['generator']] + check_ts, 99 | feed_dict = { 100 | input_tensors['t_real_image'] : real_images, 101 | input_tensors['t_wrong_image'] : wrong_images, 102 | input_tensors['t_real_caption'] : caption_vectors, 103 | input_tensors['t_z'] : z_noise, 104 | }) 105 | 106 | print "d1", d1 107 | print "d2", d2 108 | print "d3", d3 109 | print "D", d_loss 110 | 111 | # GEN UPDATE 112 | _, g_loss, gen = sess.run([g_optim, loss['g_loss'], outputs['generator']], 113 | feed_dict = { 114 | input_tensors['t_real_image'] : real_images, 115 | input_tensors['t_wrong_image'] : wrong_images, 116 | input_tensors['t_real_caption'] : caption_vectors, 117 | input_tensors['t_z'] : z_noise, 118 | }) 119 | 120 | # GEN UPDATE TWICE, to make sure d_loss does not go to 0 121 | _, g_loss, gen = sess.run([g_optim, loss['g_loss'], outputs['generator']], 122 | feed_dict = { 123 | input_tensors['t_real_image'] : real_images, 124 | input_tensors['t_wrong_image'] : wrong_images, 125 | input_tensors['t_real_caption'] : caption_vectors, 126 | input_tensors['t_z'] : z_noise, 127 | }) 128 | 129 | print "LOSSES", d_loss, g_loss, batch_no, i, len(loaded_data['image_list'])/ args.batch_size 130 | batch_no += 1 131 | if (batch_no % args.save_every) == 0: 132 | print "Saving Images, Model" 133 | save_for_vis(args.data_dir, real_images, gen, image_files) 134 | save_path = saver.save(sess, "Data/Models/latest_model_{}_temp.ckpt".format(args.data_set)) 135 | if i%5 == 0: 136 | save_path = saver.save(sess, "Data/Models/model_after_{}_epoch_{}.ckpt".format(args.data_set, i)) 137 | 138 | def load_training_data(data_dir, data_set): 139 | if data_set == 'flowers': 140 | h = h5py.File(join(data_dir, 'flower_tv.hdf5')) 141 | flower_captions = {} 142 | for ds in h.iteritems(): 143 | flower_captions[ds[0]] = np.array(ds[1]) 144 | image_list = [key for key in flower_captions] 145 | image_list.sort() 146 | 147 | img_75 = int(len(image_list)*0.75) 148 | training_image_list = image_list[0:img_75] 149 | random.shuffle(training_image_list) 150 | 151 | return { 152 | 'image_list' : training_image_list, 153 | 'captions' : flower_captions, 154 | 'data_length' : len(training_image_list) 155 | } 156 | 157 | else: 158 | with open(join(data_dir, 'meta_train.pkl')) as f: 159 | meta_data = pickle.load(f) 160 | # No preloading for MS-COCO 161 | return meta_data 162 | 163 | def save_for_vis(data_dir, real_images, generated_images, image_files): 164 | 165 | shutil.rmtree( join(data_dir, 'samples') ) 166 | os.makedirs( join(data_dir, 'samples') ) 167 | 168 | for i in range(0, real_images.shape[0]): 169 | real_image_255 = np.zeros( (64,64,3), dtype=np.uint8) 170 | real_images_255 = (real_images[i,:,:,:]) 171 | scipy.misc.imsave( join(data_dir, 'samples/{}_{}.jpg'.format(i, image_files[i].split('/')[-1] )) , real_images_255) 172 | 173 | fake_image_255 = np.zeros( (64,64,3), dtype=np.uint8) 174 | fake_images_255 = (generated_images[i,:,:,:]) 175 | 
scipy.misc.imsave(join(data_dir, 'samples/fake_image_{}.jpg'.format(i)), fake_images_255) 176 | 177 | 178 | def get_training_batch(batch_no, batch_size, image_size, z_dim, 179 | caption_vector_length, split, data_dir, data_set, loaded_data = None): 180 | if data_set == 'mscoco': 181 | with h5py.File( join(data_dir, 'tvs/'+split + '_tvs_' + str(batch_no))) as hf: 182 | caption_vectors = np.array(hf.get('tv')) 183 | caption_vectors = caption_vectors[:,0:caption_vector_length] 184 | with h5py.File( join(data_dir, 'tvs/'+split + '_tv_image_id_' + str(batch_no))) as hf: 185 | image_ids = np.array(hf.get('tv')) 186 | 187 | real_images = np.zeros((batch_size, 64, 64, 3)) 188 | wrong_images = np.zeros((batch_size, 64, 64, 3)) 189 | 190 | image_files = [] 191 | for idx, image_id in enumerate(image_ids): 192 | image_file = join(data_dir, '%s2014/COCO_%s2014_%.12d.jpg'%(split, split, image_id) ) 193 | image_array = image_processing.load_image_array(image_file, image_size) 194 | real_images[idx,:,:,:] = image_array 195 | image_files.append(image_file) 196 | 197 | # TODO>> As of Now, wrong images are just shuffled real images. 198 | first_image = real_images[0,:,:,:] 199 | for i in range(0, batch_size): 200 | if i < batch_size - 1: 201 | wrong_images[i,:,:,:] = real_images[i+1,:,:,:] 202 | else: 203 | wrong_images[i,:,:,:] = first_image 204 | 205 | z_noise = np.random.uniform(-1, 1, [batch_size, z_dim]) 206 | 207 | 208 | return real_images, wrong_images, caption_vectors, z_noise, image_files 209 | 210 | if data_set == 'flowers': 211 | real_images = np.zeros((batch_size, 64, 64, 3)) 212 | wrong_images = np.zeros((batch_size, 64, 64, 3)) 213 | captions = np.zeros((batch_size, caption_vector_length)) 214 | 215 | cnt = 0 216 | image_files = [] 217 | for i in range(batch_no * batch_size, batch_no * batch_size + batch_size): 218 | idx = i % len(loaded_data['image_list']) 219 | image_file = join(data_dir, 'flowers/jpg/'+loaded_data['image_list'][idx]) 220 | image_array = image_processing.load_image_array(image_file, image_size) 221 | real_images[cnt,:,:,:] = image_array 222 | 223 | # Improve this selection of wrong image 224 | wrong_image_id = random.randint(0,len(loaded_data['image_list'])-1) 225 | wrong_image_file = join(data_dir, 'flowers/jpg/'+loaded_data['image_list'][wrong_image_id]) 226 | wrong_image_array = image_processing.load_image_array(wrong_image_file, image_size) 227 | wrong_images[cnt, :,:,:] = wrong_image_array 228 | 229 | random_caption = random.randint(0,4) 230 | captions[cnt,:] = loaded_data['captions'][ loaded_data['image_list'][idx] ][ random_caption ][0:caption_vector_length] 231 | image_files.append( image_file ) 232 | cnt += 1 233 | 234 | z_noise = np.random.uniform(-1, 1, [batch_size, z_dim]) 235 | return real_images, wrong_images, captions, z_noise, image_files 236 | 237 | if __name__ == '__main__': 238 | main() 239 | --------------------------------------------------------------------------------
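The `# TODO>> As of Now, wrong images are just shuffled real images.` note in `get_training_batch` could be addressed by drawing a random derangement of the batch instead of the fixed one-position shift used for mscoco (the flowers branch's independent uniform draw can even pick the matching image). A minimal sketch under those assumptions; `sample_wrong_images` is a hypothetical helper, not part of the repository:
```
import numpy as np

def sample_wrong_images(real_images):
    # Hypothetical helper: pair every image with a different image from the
    # same batch by redrawing a permutation until it has no fixed point.
    # Assumes batch_size > 1.
    batch_size = real_images.shape[0]
    perm = np.random.permutation(batch_size)
    while np.any(perm == np.arange(batch_size)):
        perm = np.random.permutation(batch_size)
    return real_images[perm]
```
This only guarantees a mismatch at the index level; captions of visually similar flowers can still loosely describe the "wrong" image, which is the harder part of the TODO.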