├── README.md
├── data_loader.py
├── download_skipthought.py
├── imgs
│   └── net.jpeg
├── main.py
├── nets
│   ├── __init__.py
│   ├── discriminator.py
│   └── generator.py
├── predict.py
├── skipthoughts.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# Text to Image Synthesis using Skip-thought Vectors

**The code for this model has not been tested because of time constraints. All the steps needed to train/test the model are given below. Feel free to create a pull request if you are interested in testing this model.**

## Description
This is a PyTorch implementation of the paper Generative Adversarial Text-to-Image Synthesis [http://arxiv.org/abs/1605.05396] using skip-thought vectors for the caption embedding. The implementation is based on DCGAN. Below is the model architecture, where the blue bars represent the skip-thought vectors for the captions.

[Figure]
Image Source : Paper

## Setup and Installation
* Python==3.6.6
* PyTorch==0.4.0
* TorchVision==0.2.1
* Theano

## Dataset
* This model can be trained on the flowers dataset. Download the flowers dataset from here[] and save the images in the Data folder as Data/flowers.
* Now download the corresponding captions from here[]. After extracting, copy the text_c10 folder and paste it into the Data folder as Data/text_c10.

## Skip-Thought Model
* Download the pretrained models and vocabulary for the skip-thought vectors as per the instructions given below. Save the downloaded files in Data/skipthoughts.

* Some of the files are quite large (>2 GB), so make sure there is enough disk space available.

* Run the command below to download the skip-thought model and all other required files:
```shell
$ python download_skipthought.py
```

## Usage
* Data pre-processing:
```shell
$ python data_loader.py
```

* Training arguments:
```shell
dataset          : Dataset used. Default = flowers
batch_size       : Batch size. Default = 1
num_epochs       : Number of epochs to train. Default = 200
img_size         : Size of the image. Default = 64
z_dim            : Latent variable dimension. Default = 100
text_embed_dim   : Embedding dimension of the caption. Default = 4800
text_reduced_dim : Reduced embedding dimension of the caption. Default = 1024
learning_rate    : Learning rate. Default = 0.0002
beta1            : Hyperparameter of the Adam optimizer. Default = 0.5
beta2            : Hyperparameter of the Adam optimizer. Default = 0.999
l1_coeff         : Coefficient for the L1 loss. Default = 50
resume_epoch     : Epoch from which to resume training. Default = 1
```
* Train the model by running the command below (an example with explicit flags follows):
```shell
$ python main.py
```
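Any of the training arguments listed above can be passed as command-line flags; they map one-to-one to the options defined in `main.py`. The values below are only an illustration:
```shell
$ python main.py --dataset=flowers --batch_size=64 --num_epochs=300 --learning_rate=0.0002 --l1_coeff=50
```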
* Test the model by giving a custom input caption:
```shell
$ python predict.py --text="Input caption to be used to generate the image"
```
The generated image will be saved to the directory given by --save_img inside the Data folder (Data/test by default).

## Model key-points

* Skip-Thought is an efficient model for sentence embedding and is based on the concept of word embeddings (word2vec or GloVe). It returns a numpy array of dimension 4800, in which the first 2400 dimensions come from the uni-skip model and the last 2400 dimensions from the bi-skip model. We use the combine-skip vectors because, experimentally, they perform the best (see the snippet after this list).

* The Text2Image model is a Generative Adversarial Network built on top of the DCGAN. It consists of a Discriminator network and a Generator network.

* The Discriminator not only classifies the images produced by the Generator as fake, but also penalizes real images that are paired with a caption that does not describe them. In short, fake examples fall into two categories:
  * Fake image + correct caption
  * Real image + incorrect caption

* Images are 64 x 64 in dimension.
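The snippet below shows how a caption is turned into this 4800-dimensional combine-skip vector using the `skipthoughts.py` module shipped with this repo (the caption string and variable names are only illustrative):

```python
import skipthoughts

# Load the pretrained uni-skip / bi-skip encoders and vocabulary tables
# from Data/skipthoughts (downloaded by download_skipthought.py).
model = skipthoughts.load_model()

# encode() expects a list of sentences and returns one 4800-d vector per sentence.
vectors = skipthoughts.encode(model, ["a flower with long yellow petals and a brown center"])

uni_skip = vectors[:, :2400]   # first 2400 dims : uni-skip encoder
bi_skip = vectors[:, 2400:]    # last 2400 dims  : bi-skip encoder
print(vectors.shape)           # (1, 4800)
```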
## Generated Images
Following are some of the images generated by this model:
[A table of few 5-6 images along with their captions]

## TODO
Implementation of the same model using an autoencoder for sentence embedding


## References
* Generative Adversarial Text-to-Image Synthesis - http://arxiv.org/abs/1605.05396
* TensorFlow implementation - https://github.com/paarthneekhara/text-to-image
* Skip-Thought Model - https://github.com/ryankiros/skip-thoughts


## License
MIT
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import os
import torch
import skipthoughts
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.autograd import Variable
from torch.utils.data import Dataset

# Each batch will have 3 things : true image, one of its captions, and false image (real image but
# corresponding to an incorrect caption).
# The Discriminator is trained in such a way that true_img + caption corresponds to a real example and
# false_img + caption corresponds to a fake example.


class Text2ImageDataset(Dataset):

    def __init__(self, data_dir):
        self.data_dir = data_dir

        self.load_flower_dataset()

    def load_flower_dataset(self):
        # Builds two things : a list of image file names, and a dictionary of up to 5 captions per image,
        # with the image file name as the key and the captions as the values.

        print ("------------------ Loading images ------------------")
        self.img_files = []
        for f in os.listdir(os.path.join(self.data_dir, 'flowers')):
            self.img_files.append(f)

        print ('Total number of images : {}'.format(len(self.img_files)))

        print ("------------------ Loading captions ----------------")
        self.img_captions = {}
        for class_dir in tqdm(os.listdir(os.path.join(self.data_dir, 'text_c10'))):
            if 't7' not in class_dir:
                class_path = os.path.join(self.data_dir, 'text_c10', class_dir)
                for cap_file in os.listdir(class_path):
                    if 'txt' in cap_file:
                        with open(os.path.join(class_path, cap_file)) as f:
                            # Keep non-empty lines only; up to 5 captions per image
                            captions = [line for line in f.read().split('\n') if line.strip()]
                        img_file = cap_file[:11] + '.jpg'
                        self.img_captions[img_file] = captions[:5]

        print ("--------------- Loading Skip-thought Model ---------------")
        model = skipthoughts.load_model()
        self.encoded_captions = {}

        print ("------------ Encoding of image captions STARTED ------------")
        for img_file in self.img_captions:
            # encode() returns a numpy array of shape (num_captions, 4800)
            self.encoded_captions[img_file] = skipthoughts.encode(model, self.img_captions[img_file])

        print ("------------- Encoding of image captions DONE -------------")

    def read_image(self, image_file_name):
        image = Image.open(os.path.join(self.data_dir, 'flowers/' + image_file_name))
        # Resize to a fixed (64, 64, 3) array and scale to [0, 1] to match the generator's output range
        image = np.array(image.convert('RGB').resize((64, 64)), dtype=np.float32) / 255.
        return image

    def get_false_img(self, index):
        false_img_id = np.random.randint(len(self.img_files))
        if false_img_id != index:
            return self.img_files[false_img_id]

        return self.get_false_img(index)

    def __len__(self):

        return len(self.img_files)

    def __getitem__(self, index):

        sample = {}
        sample['true_imgs'] = torch.FloatTensor(self.read_image(self.img_files[index]))
        sample['false_imgs'] = torch.FloatTensor(self.read_image(self.get_false_img(index)))
        # Pick one of the (up to 5) caption embeddings at random for this sample
        caption_embeds = self.encoded_captions[self.img_files[index]]
        sample['true_embed'] = torch.FloatTensor(caption_embeds[np.random.randint(caption_embeds.shape[0])])

        return sample

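
# Minimal smoke test so that `python data_loader.py` (the pre-processing step in the README)
# exercises the dataset. Paths are assumptions: it expects Data/flowers, Data/text_c10 and the
# skip-thought files under Data/skipthoughts to already exist, and the constructor encodes
# every caption, so this can take a while.
if __name__ == '__main__':
    dataset = Text2ImageDataset('Data')
    sample = dataset[0]
    for key, value in sample.items():
        # Expected: true_imgs / false_imgs -> (64, 64, 3), true_embed -> (4800,)
        print(key, tuple(value.size()))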
--------------------------------------------------------------------------------
/download_skipthought.py:
--------------------------------------------------------------------------------
import os

print ('Downloading Skip-Thought Model ...........')
# wget's -P flag saves the files directly into Data/skipthoughts, where skipthoughts.py expects them
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/dictionary.txt')
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/utable.npy')
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/btable.npy')
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/uni_skip.npz')
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/uni_skip.npz.pkl')
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/bi_skip.npz')
os.system('wget -P Data/skipthoughts http://www.cs.toronto.edu/~rkiros/models/bi_skip.npz.pkl')

print ('Download Completed ............')
--------------------------------------------------------------------------------
/imgs/net.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prakashpandey9/Text2Image-PyTorch/1cafacdc284590c30c635e7e519a5acaabd4463c/imgs/net.jpeg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import torch
import train
import argparse
import numpy as np

from train import GAN_CLS
from torch.utils.data import DataLoader
from data_loader import Text2ImageDataset


def check_dir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print ('{} created'.format(dir_name))


def check_args(args):
    # Make all directories if they don't exist

    # --checkpoint_dir
    check_dir(args.checkpoint_dir)

    # --sample_dir
    check_dir(args.sample_dir)

    # --log_dir
    check_dir(args.log_dir)

    # --final_model dir
    check_dir(args.final_model)

    # --epoch
    assert args.num_epochs > 0, 'Number of epochs must be greater than 0'

    # --batch_size
    assert args.batch_size > 0, 'Batch size must be greater than zero'

    # --z_dim
    assert args.z_dim > 0, 'Size of the noise vector must be greater than zero'

    return args


def main():

    parser = argparse.ArgumentParser()

    parser.add_argument_group('Dataset related arguments')
    parser.add_argument('--data_dir', type=str, default="Data",
                        help='Data Directory')

    parser.add_argument('--dataset', type=str, default="flowers",
                        help='Dataset to train')

    parser.add_argument_group('Model saving path and steps related arguments')
    parser.add_argument('--log_step', type=int, default=100,
                        help='Save INFO into logger after every x iterations')

    parser.add_argument('--sample_step', type=int, default=100,
                        help='Save generated image after every x iterations')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints',
                        help='Directory for saving model checkpoints')

    parser.add_argument('--sample_dir', type=str, default='sample',
                        help='Directory for saving generated sample images')

    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory for saving log files')

    parser.add_argument('--final_model', type=str, default='final_model',
                        help='Directory for saving the final trained model')

    parser.add_argument_group('Model training related arguments')
    parser.add_argument('--num_epochs', type=int, default=200,
                        help='Total number of epochs to train')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch Size')

    parser.add_argument('--img_size', type=int, default=64,
                        help='Size of the image')

    parser.add_argument('--z_dim', type=int, default=100,
                        help='Size of the latent variable')

    parser.add_argument('--text_embed_dim', type=int, default=4800,
                        help='Size of the embedding for the captions')

    parser.add_argument('--text_reduced_dim', type=int, default=1024,
                        help='Reduced dimension of the caption encoding')

    parser.add_argument('--learning_rate', type=float, default=0.0002,
                        help='Learning Rate')

    parser.add_argument('--beta1', type=float, default=0.5,
                        help='Hyperparameter of the Adam optimizer')

    parser.add_argument('--beta2', type=float, default=0.999,
                        help='Hyperparameter of the Adam optimizer')

    parser.add_argument('--l1_coeff', type=float, default=50,
                        help='Coefficient for the L1 Loss')

    parser.add_argument('--resume_epoch', type=int, default=1,
                        help='Epoch from which to resume training')

    args = parser.parse_args()

    check_args(args)

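    # Text2ImageDataset expects the directory layout described in the README
    # (Data/flowers, Data/text_c10 and Data/skipthoughts inside --data_dir).
    # Caption encoding happens inside the dataset constructor, so the first
    # run can take a while before training actually starts.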
    dataset = Text2ImageDataset(args.data_dir)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    gan = GAN_CLS(args, data_loader)

    gan.build_model()
    gan.train_model()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prakashpandey9/Text2Image-PyTorch/1cafacdc284590c30c635e7e519a5acaabd4463c/nets/__init__.py
--------------------------------------------------------------------------------
/nets/discriminator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F


class Discriminator(nn.Module):
    def __init__(self, batch_size, img_size, text_embed_dim, text_reduced_dim):
        super(Discriminator, self).__init__()

        self.batch_size = batch_size
        self.img_size = img_size
        self.in_channels = 3  # RGB images; img_size is an int, so the channel count cannot be read from it
        self.text_embed_dim = text_embed_dim
        self.text_reduced_dim = text_reduced_dim

        # Defining the discriminator network architecture
        self.d_net = nn.Sequential(
            nn.Conv2d(self.in_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True))

        # output of d_net : (batch_size, 512, 4, 4) for 64 x 64 inputs
        # text.size()     : (batch_size, text_embed_dim)

        # Linear layer to reduce the dimensionality of the caption embedding from
        # text_embed_dim to text_reduced_dim (named text_reduce so it does not
        # overwrite the integer attribute self.text_reduced_dim)
        self.text_reduce = nn.Linear(self.text_embed_dim, self.text_reduced_dim)

        self.cat_net = nn.Sequential(
            nn.Conv2d(512 + self.text_reduced_dim, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True))

        self.linear = nn.Linear(2 * 2 * 512, 1)

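    # Conditioning strategy (following the GAN-CLS paper): the caption embedding is
    # reduced to text_reduced_dim, replicated spatially over the 4 x 4 feature map
    # produced by d_net, and concatenated with it along the channel axis before the
    # final convolution. Shape comments below assume channels-first (N, C, H, W) tensors.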
    def forward(self, image, text):
        """ Given the image and its caption embedding, predict whether the image
        is real or fake.

        Arguments
        ---------
        image : torch.FloatTensor
            image.size() = (batch_size, 3, 64, 64)

        text : torch.FloatTensor
            Output of the skip-thought embedding model for the caption
            text.size() = (batch_size, text_embed_dim)

        --------
        Returns
        --------
        output : Probability of the image being real (sigmoid of the logit)
        logit  : Final score of the discriminator

        """

        d_net_out = self.d_net(image)                               # (batch_size, 512, 4, 4)
        text_reduced = self.text_reduce(text)                       # (batch_size, text_reduced_dim)
        text_reduced = text_reduced.unsqueeze(2)                    # (batch_size, text_reduced_dim, 1)
        text_reduced = text_reduced.unsqueeze(3)                    # (batch_size, text_reduced_dim, 1, 1)
        text_reduced = text_reduced.expand(-1, self.text_reduced_dim, 4, 4)

        concat_out = torch.cat((d_net_out, text_reduced), 1)        # (batch_size, 512 + text_reduced_dim, 4, 4)

        concat_out = self.cat_net(concat_out)                       # (batch_size, 512, 2, 2)
        concat_out = concat_out.view(concat_out.size(0), -1)        # (batch_size, 2 * 2 * 512)
        logit = self.linear(concat_out)                             # (batch_size, 1)

        output = torch.sigmoid(logit)

        return output, logit

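
# Minimal shape check (illustrative only, not used by training): run a random image batch
# and random caption embeddings through the discriminator and print the output sizes.
# The constructor arguments mirror the defaults used in main.py.
if __name__ == '__main__':
    disc = Discriminator(batch_size=4, img_size=64, text_embed_dim=4800, text_reduced_dim=1024)
    images = torch.randn(4, 3, 64, 64)
    captions = torch.randn(4, 4800)
    output, logit = disc(images, captions)
    print(output.size(), logit.size())  # expected: (4, 1) and (4, 1)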
--------------------------------------------------------------------------------
/nets/generator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F


class Generator(nn.Module):
    def __init__(self, batch_size, img_size, z_dim, text_embed_dim, text_reduced_dim):
        super(Generator, self).__init__()

        self.img_size = img_size
        self.z_dim = z_dim
        self.text_embed_dim = text_embed_dim
        self.text_reduced_dim = text_reduced_dim

        # Reduce the caption embedding, then project noise + reduced text to a 512 x 4 x 4 volume
        self.text_reduce = nn.Linear(text_embed_dim, text_reduced_dim)
        self.concat = nn.Linear(z_dim + text_reduced_dim, 64 * 8 * 4 * 4)

        # Defining the generator network architecture
        self.d_net = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, text, z):
        """ Given a caption embedding and a latent variable z (noise), generate an image.

        Arguments
        ---------
        text : torch.FloatTensor
            Output of the skip-thought embedding model for the caption
            text.size() = (batch_size, text_embed_dim)

        z : torch.FloatTensor
            Latent variable or noise
            z.size() = (batch_size, z_dim)

        --------
        Returns
        --------
        output : An image batch of shape (batch_size, 3, 64, 64)

        """
        reduced_text = self.text_reduce(text)          # (batch_size, text_reduced_dim)
        concat = torch.cat((reduced_text, z), 1)       # (batch_size, text_reduced_dim + z_dim)
        concat = self.concat(concat)                   # (batch_size, 64 * 8 * 4 * 4)
        concat = concat.view(-1, 64 * 8, 4, 4)         # (batch_size, 512, 4, 4)
        d_net_out = self.d_net(concat)                 # (batch_size, 3, 64, 64), values in [-1, 1]
        output = d_net_out / 2. + 0.5                  # rescale to [0, 1]

        return output
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
import skipthoughts

import torch
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision.utils import save_image

from nets.generator import Generator


def test():

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch Size')
    parser.add_argument('--img_size', type=int, default=64,
                        help='Size of the image')
    parser.add_argument('--z_dim', type=int, default=100,
                        help='Size of the latent variable')
    parser.add_argument('--final_model', type=str, default='final_model',
                        help='Directory containing the final trained model')
    parser.add_argument('--save_img', type=str, default='test',
                        help='Directory (inside Data) in which to save the predicted image')
    parser.add_argument('--text_embed_dim', type=int, default=4800,
                        help='Size of the embedding for the captions')
    parser.add_argument('--text_reduced_dim', type=int, default=1024,
                        help='Reduced dimension of the caption encoding')
    parser.add_argument('--text', type=str, required=True,
                        help='Input text to be converted into an image')

    config = parser.parse_args()
    save_dir = os.path.join('Data', config.save_img)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    start_time = time.time()
    gen = Generator(batch_size=config.batch_size,
                    img_size=config.img_size,
                    z_dim=config.z_dim,
                    text_embed_dim=config.text_embed_dim,
                    text_reduced_dim=config.text_reduced_dim)

    # Loading the trained model
    G_path = os.path.join(config.final_model, '{}-G.pth'.format('final'))
    gen.load_state_dict(torch.load(G_path))
    gen.cuda()
    gen.eval()

    z = Variable(torch.randn(config.batch_size, config.z_dim)).cuda()
    model = skipthoughts.load_model()
    # encode() expects a list of sentences and returns a numpy array of shape (1, 4800)
    text_embed = skipthoughts.encode(model, [config.text])
    text_embed = Variable(torch.FloatTensor(text_embed)).cuda()
    output_img = gen(text_embed, z)
    # Save under a fixed file name inside the chosen directory
    save_path = os.path.join(save_dir, 'generated.png')
    save_image(output_img.data.cpu(), save_path, nrow=1, padding=0)

    print ('Generated image saved to {}'.format(save_path))
    print ('Time taken for the task : {}'.format(time.time() - start_time))


if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------
/skipthoughts.py:
--------------------------------------------------------------------------------
'''
Skip-thought vectors

From : https://github.com/ryankiros/skip-thoughts

'''
import os
import warnings

import theano
import theano.tensor as tensor

import pickle as pkl
import numpy
import copy
import nltk

from collections import OrderedDict, defaultdict
from scipy.linalg import norm
from nltk.tokenize import word_tokenize

profile = False

#-----------------------------------------------------------------------------#
# Specify model and table locations here
#-----------------------------------------------------------------------------#
path_to_models = 'Data/skipthoughts/'
path_to_tables = 'Data/skipthoughts/'
#-----------------------------------------------------------------------------#

path_to_umodel = path_to_models + 'uni_skip.npz'
path_to_bmodel
= path_to_models + 'bi_skip.npz' 32 | 33 | 34 | def load_model(): 35 | """ 36 | Load the model with saved tables 37 | """ 38 | # Load model options 39 | print('Loading model parameters...') 40 | with open('%s.pkl' % path_to_umodel, 'rb') as f: 41 | uoptions = pkl.load(f) 42 | with open('%s.pkl' % path_to_bmodel, 'rb') as f: 43 | boptions = pkl.load(f) 44 | 45 | # Load parameters 46 | uparams = init_params(uoptions) 47 | uparams = load_params(path_to_umodel, uparams) 48 | utparams = init_tparams(uparams) 49 | bparams = init_params_bi(boptions) 50 | bparams = load_params(path_to_bmodel, bparams) 51 | btparams = init_tparams(bparams) 52 | 53 | # Extractor functions 54 | print('Compiling encoders...') 55 | embedding, x_mask, ctxw2v = build_encoder(utparams, uoptions) 56 | f_w2v = theano.function([embedding, x_mask], ctxw2v, name='f_w2v') 57 | embedding, x_mask, ctxw2v = build_encoder_bi(btparams, boptions) 58 | f_w2v2 = theano.function([embedding, x_mask], ctxw2v, name='f_w2v2') 59 | 60 | # Tables 61 | print('Loading tables...') 62 | utable, btable = load_tables() 63 | 64 | # Store everything we need in a dictionary 65 | print('Packing up...') 66 | model = {} 67 | model['uoptions'] = uoptions 68 | model['boptions'] = boptions 69 | model['utable'] = utable 70 | model['btable'] = btable 71 | model['f_w2v'] = f_w2v 72 | model['f_w2v2'] = f_w2v2 73 | 74 | return model 75 | 76 | 77 | def load_tables(): 78 | """ 79 | Load the tables 80 | """ 81 | words = [] 82 | utable = numpy.load(path_to_tables + 'utable.npy', encoding='latin1') 83 | btable = numpy.load(path_to_tables + 'btable.npy', encoding='latin1') 84 | f = open(path_to_tables + 'dictionary.txt', 'rb') 85 | for line in f: 86 | words.append(line.decode('utf-8').strip()) 87 | f.close() 88 | utable = OrderedDict(list(zip(words, utable))) 89 | btable = OrderedDict(list(zip(words, btable))) 90 | return utable, btable 91 | 92 | 93 | def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False): 94 | """ 95 | Encode sentences in the list X. Each entry will return a vector 96 | """ 97 | # first, do preprocessing 98 | X = preprocess(X) 99 | 100 | # word dictionary and init 101 | d = defaultdict(lambda: 0) 102 | for w in list(model['utable'].keys()): 103 | d[w] = 1 104 | ufeatures = numpy.zeros((len(X), model['uoptions']['dim']), dtype='float32') 105 | bfeatures = numpy.zeros((len(X), 2 * model['boptions']['dim']), dtype='float32') 106 | 107 | # length dictionary 108 | ds = defaultdict(list) 109 | captions = [s.split() for s in X] 110 | for i, s in enumerate(captions): 111 | ds[len(s)].append(i) 112 | 113 | # Get features. 
This encodes by length, in order to avoid wasting computation 114 | for k in list(ds.keys()): 115 | if verbose: 116 | print(k) 117 | numbatches = len(ds[k]) // batch_size + 1 118 | for minibatch in range(numbatches): 119 | caps = ds[k][minibatch::numbatches] 120 | 121 | if use_eos: 122 | uembedding = numpy.zeros((k + 1, len(caps), model['uoptions']['dim_word']), dtype='float32') 123 | bembedding = numpy.zeros((k + 1, len(caps), model['boptions']['dim_word']), dtype='float32') 124 | else: 125 | uembedding = numpy.zeros((k, len(caps), model['uoptions']['dim_word']), dtype='float32') 126 | bembedding = numpy.zeros((k, len(caps), model['boptions']['dim_word']), dtype='float32') 127 | for ind, c in enumerate(caps): 128 | caption = captions[c] 129 | for j in range(len(caption)): 130 | if d[caption[j]] > 0: 131 | uembedding[j, ind] = model['utable'][caption[j]] 132 | bembedding[j, ind] = model['btable'][caption[j]] 133 | else: 134 | uembedding[j, ind] = model['utable']['UNK'] 135 | bembedding[j, ind] = model['btable']['UNK'] 136 | if use_eos: 137 | uembedding[-1, ind] = model['utable'][''] 138 | bembedding[-1, ind] = model['btable'][''] 139 | if use_eos: 140 | uff = model['f_w2v'](uembedding, numpy.ones((len(caption) + 1, len(caps)), dtype='float32')) 141 | bff = model['f_w2v2'](bembedding, numpy.ones((len(caption) + 1, len(caps)), dtype='float32')) 142 | else: 143 | uff = model['f_w2v'](uembedding, numpy.ones((len(caption), len(caps)), dtype='float32')) 144 | bff = model['f_w2v2'](bembedding, numpy.ones((len(caption), len(caps)), dtype='float32')) 145 | if use_norm: 146 | for j in range(len(uff)): 147 | uff[j] /= norm(uff[j]) 148 | bff[j] /= norm(bff[j]) 149 | for ind, c in enumerate(caps): 150 | ufeatures[c] = uff[ind] 151 | bfeatures[c] = bff[ind] 152 | 153 | features = numpy.c_[ufeatures, bfeatures] 154 | return features 155 | 156 | 157 | def preprocess(text): 158 | """ 159 | Preprocess text for encoder 160 | """ 161 | X = [] 162 | sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') 163 | for t in text: 164 | sents = sent_detector.tokenize(t) 165 | result = '' 166 | for s in sents: 167 | tokens = word_tokenize(s) 168 | result += ' ' + ' '.join(tokens) 169 | X.append(result) 170 | return X 171 | 172 | 173 | def nn(model, text, vectors, query, k=5): 174 | """ 175 | Return the nearest neighbour sentences to query 176 | text: list of sentences 177 | vectors: the corresponding representations for text 178 | query: a string to search 179 | """ 180 | qf = encode(model, [query]) 181 | qf /= norm(qf) 182 | scores = numpy.dot(qf, vectors.T).flatten() 183 | sorted_args = numpy.argsort(scores)[::-1] 184 | sentences = [text[a] for a in sorted_args[:k]] 185 | print(('QUERY: ' + query)) 186 | print('NEAREST: ') 187 | for i, s in enumerate(sentences): 188 | print((s, sorted_args[i])) 189 | 190 | 191 | def word_features(table): 192 | """ 193 | Extract word features into a normalized matrix 194 | """ 195 | features = numpy.zeros((len(table), 620), dtype='float32') 196 | keys = list(table.keys()) 197 | for i in range(len(table)): 198 | f = table[keys[i]] 199 | features[i] = f / norm(f) 200 | return features 201 | 202 | 203 | def nn_words(table, wordvecs, query, k=10): 204 | """ 205 | Get the nearest neighbour words 206 | """ 207 | keys = list(table.keys()) 208 | qf = table[query] 209 | scores = numpy.dot(qf, wordvecs.T).flatten() 210 | sorted_args = numpy.argsort(scores)[::-1] 211 | words = [keys[a] for a in sorted_args[:k]] 212 | print(('QUERY: ' + query)) 213 | print('NEAREST: ') 214 | for i, w 
in enumerate(words): 215 | print(w) 216 | 217 | 218 | def _p(pp, name): 219 | """ 220 | make prefix-appended name 221 | """ 222 | return '%s_%s' % (pp, name) 223 | 224 | 225 | def init_tparams(params): 226 | """ 227 | initialize Theano shared variables according to the initial parameters 228 | """ 229 | tparams = OrderedDict() 230 | for kk, pp in list(params.items()): 231 | tparams[kk] = theano.shared(params[kk], name=kk) 232 | return tparams 233 | 234 | 235 | def load_params(path, params): 236 | """ 237 | load parameters 238 | """ 239 | pp = numpy.load(path) 240 | for kk, vv in list(params.items()): 241 | if kk not in pp: 242 | warnings.warn('%s is not in the archive' % kk) 243 | continue 244 | params[kk] = pp[kk] 245 | return params 246 | 247 | 248 | # layers: 'name': ('parameter initializer', 'feedforward') 249 | layers = {'gru': ('param_init_gru', 'gru_layer')} 250 | 251 | 252 | def get_layer(name): 253 | fns = layers[name] 254 | return (eval(fns[0]), eval(fns[1])) 255 | 256 | 257 | def init_params(options): 258 | """ 259 | initialize all parameters needed for the encoder 260 | """ 261 | params = OrderedDict() 262 | 263 | # embedding 264 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 265 | 266 | # encoder: GRU 267 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder', 268 | nin=options['dim_word'], dim=options['dim']) 269 | return params 270 | 271 | 272 | def init_params_bi(options): 273 | """ 274 | initialize all paramters needed for bidirectional encoder 275 | """ 276 | params = OrderedDict() 277 | 278 | # embedding 279 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 280 | 281 | # encoder: GRU 282 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder', 283 | nin=options['dim_word'], dim=options['dim']) 284 | params = get_layer(options['encoder'])[0](options, params, prefix='encoder_r', 285 | nin=options['dim_word'], dim=options['dim']) 286 | return params 287 | 288 | 289 | def build_encoder(tparams, options): 290 | """ 291 | build an encoder, given pre-computed word embeddings 292 | """ 293 | # word embedding (source) 294 | embedding = tensor.tensor3('embedding', dtype='float32') 295 | x_mask = tensor.matrix('x_mask', dtype='float32') 296 | 297 | # encoder 298 | proj = get_layer(options['encoder'])[1](tparams, embedding, options, 299 | prefix='encoder', 300 | mask=x_mask) 301 | ctx = proj[0][-1] 302 | 303 | return embedding, x_mask, ctx 304 | 305 | 306 | def build_encoder_bi(tparams, options): 307 | """ 308 | build bidirectional encoder, given pre-computed word embeddings 309 | """ 310 | # word embedding (source) 311 | embedding = tensor.tensor3('embedding', dtype='float32') 312 | embeddingr = embedding[::-1] 313 | x_mask = tensor.matrix('x_mask', dtype='float32') 314 | xr_mask = x_mask[::-1] 315 | 316 | # encoder 317 | proj = get_layer(options['encoder'])[1](tparams, embedding, options, 318 | prefix='encoder', 319 | mask=x_mask) 320 | projr = get_layer(options['encoder'])[1](tparams, embeddingr, options, 321 | prefix='encoder_r', 322 | mask=xr_mask) 323 | 324 | ctx = tensor.concatenate([proj[0][-1], projr[0][-1]], axis=1) 325 | 326 | return embedding, x_mask, ctx 327 | 328 | 329 | # some utilities 330 | def ortho_weight(ndim): 331 | W = numpy.random.randn(ndim, ndim) 332 | u, s, v = numpy.linalg.svd(W) 333 | return u.astype('float32') 334 | 335 | 336 | def norm_weight(nin, nout=None, scale=0.1, ortho=True): 337 | if nout == None: 338 | nout = nin 339 | if nout == nin and ortho: 
340 | W = ortho_weight(nin) 341 | else: 342 | W = numpy.random.uniform(low=-scale, high=scale, size=(nin, nout)) 343 | return W.astype('float32') 344 | 345 | 346 | def param_init_gru(options, params, prefix='gru', nin=None, dim=None): 347 | """ 348 | parameter init for GRU 349 | """ 350 | if nin == None: 351 | nin = options['dim_proj'] 352 | if dim == None: 353 | dim = options['dim_proj'] 354 | W = numpy.concatenate([norm_weight(nin, dim), 355 | norm_weight(nin, dim)], axis=1) 356 | params[_p(prefix, 'W')] = W 357 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32') 358 | U = numpy.concatenate([ortho_weight(dim), 359 | ortho_weight(dim)], axis=1) 360 | params[_p(prefix, 'U')] = U 361 | 362 | Wx = norm_weight(nin, dim) 363 | params[_p(prefix, 'Wx')] = Wx 364 | Ux = ortho_weight(dim) 365 | params[_p(prefix, 'Ux')] = Ux 366 | params[_p(prefix, 'bx')] = numpy.zeros((dim,)).astype('float32') 367 | 368 | return params 369 | 370 | 371 | def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs): 372 | """ 373 | Forward pass through GRU layer 374 | """ 375 | nsteps = state_below.shape[0] 376 | if state_below.ndim == 3: 377 | n_samples = state_below.shape[1] 378 | else: 379 | n_samples = 1 380 | 381 | dim = tparams[_p(prefix, 'Ux')].shape[1] 382 | 383 | if mask == None: 384 | mask = tensor.alloc(1., state_below.shape[0], 1) 385 | 386 | def _slice(_x, n, dim): 387 | if _x.ndim == 3: 388 | return _x[:, :, n * dim:(n + 1) * dim] 389 | return _x[:, n * dim:(n + 1) * dim] 390 | 391 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + tparams[_p(prefix, 'b')] 392 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + tparams[_p(prefix, 'bx')] 393 | U = tparams[_p(prefix, 'U')] 394 | Ux = tparams[_p(prefix, 'Ux')] 395 | 396 | def _step_slice(m_, x_, xx_, h_, U, Ux): 397 | preact = tensor.dot(h_, U) 398 | preact += x_ 399 | 400 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim)) 401 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim)) 402 | 403 | preactx = tensor.dot(h_, Ux) 404 | preactx = preactx * r 405 | preactx = preactx + xx_ 406 | 407 | h = tensor.tanh(preactx) 408 | 409 | h = u * h_ + (1. - u) * h 410 | h = m_[:, None] * h + (1. 
- m_)[:, None] * h_

        return h

    seqs = [mask, state_below_, state_belowx]
    _step = _step_slice

    rval, updates = theano.scan(_step,
                                sequences=seqs,
                                outputs_info=[tensor.alloc(0., n_samples, dim)],
                                non_sequences=[tparams[_p(prefix, 'U')],
                                               tparams[_p(prefix, 'Ux')]],
                                name=_p(prefix, '_layers'),
                                n_steps=nsteps,
                                profile=profile,
                                strict=True)
    rval = [rval]

    return rval
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import time
import datetime
import logging
import itertools

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision.utils import save_image

import numpy as np
from nets.discriminator import Discriminator
from nets.generator import Generator


class GAN_CLS(object):
    def __init__(self, args, data_loader, SUPERVISED=True):
        """
        Arguments :
        ----------
        args : Arguments defined in the argument parser of main.py
        data_loader : An instance of torch.utils.data.DataLoader for loading our dataset in batches
        SUPERVISED : Flag kept from the original interface (currently unused)

        """

        self.data_loader = data_loader
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size

        self.log_step = args.log_step
        self.sample_step = args.sample_step

        self.log_dir = args.log_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.final_model = args.final_model

        self.dataset = args.dataset
        # Checkpoint frequency: main.py exposes no separate flag for this, so the
        # sample frequency is reused here as an assumed default.
        self.model_save_step = args.sample_step

        self.img_size = args.img_size
        self.z_dim = args.z_dim
        self.text_embed_dim = args.text_embed_dim
        self.text_reduced_dim = args.text_reduced_dim
        self.learning_rate = args.learning_rate
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1_coeff = args.l1_coeff
        self.resume_epoch = args.resume_epoch
        self.SUPERVISED = SUPERVISED

        # Logger setting
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        # log_dir is a directory (created in main.py), so log into a file inside it
        self.file_handler = logging.FileHandler(os.path.join(self.log_dir, 'training.log'))
        self.file_handler.setFormatter(self.formatter)
        self.logger.addHandler(self.file_handler)

        self.build_model()

    def build_model(self):
        """ A function defining the following instances :

        ----- Generator
        ----- Discriminator
        ----- Optimizer for the Generator
        ----- Optimizer for the Discriminator
        ----- Loss functions

        """

        # ---------------------------------------------------------------------
        # 1. Network Initialization
        # ---------------------------------------------------------------------
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)

        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))

        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))

        self.cls_gan_optim = optim.Adam(itertools.chain(self.gen.parameters(),
                                                        self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print ('------------- Generator Model Info ---------------')
        self.print_network(self.gen, 'G')
        print ('------------------------------------------------')

        print ('------------- Discriminator Model Info ---------------')
        self.print_network(self.disc, 'D')
        print ('------------------------------------------------')

        self.gen.cuda()
        self.disc.cuda()
        self.criterion = nn.BCELoss().cuda()
        # self.CE_loss = nn.CrossEntropyLoss().cuda()
        # self.MSE_loss = nn.MSELoss().cuda()
        self.gen.train()
        self.disc.train()

    def print_network(self, model, name):
        """ A function for printing the total number of model parameters """
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()

        print(model)
        print(name)
        print("Total number of parameters: {}".format(num_params))

    def load_checkpoints(self, resume_epoch):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_epoch))
        G_path = os.path.join(self.checkpoint_dir, '{}-G.ckpt'.format(resume_epoch))
        D_path = os.path.join(self.checkpoint_dir, '{}-D.ckpt'.format(resume_epoch))
        self.gen.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.disc.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
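    # Training objective (GAN-CLS): in each step the discriminator scores three kinds of
    # pairs, and the generator additionally gets an L1 term pulling the generated image
    # towards the matching real image (weighted by l1_coeff).
    #   real image + matching caption       -> target ~1 (smoothed to 0.9)
    #   generated image + matching caption  -> target 0
    #   real image + mismatching caption    -> target 0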
    def train_model(self):

        data_loader = self.data_loader

        start_epoch = 0
        # resume_epoch defaults to 1, which means "start from scratch"; only load
        # checkpoints when resuming from a later epoch
        if self.resume_epoch > 1:
            start_epoch = self.resume_epoch
            self.load_checkpoints(self.resume_epoch)

        print ('--------------- Model Training Started ---------------')
        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs):
            for idx, batch in enumerate(data_loader):
                true_imgs = batch['true_imgs']
                true_embed = batch['true_embed']
                false_imgs = batch['false_imgs']

                real_labels = torch.ones(true_imgs.size(0), 1)
                fake_labels = torch.zeros(true_imgs.size(0), 1)

                # One-sided label smoothing for the real targets (1.0 -> 0.9)
                smooth_real_labels = real_labels - 0.1

                # The dataset yields (H, W, C) images; convert to channels-first (C, H, W)
                true_imgs = Variable(true_imgs.float().permute(0, 3, 1, 2)).cuda()
                true_embed = Variable(true_embed.float()).cuda()
                false_imgs = Variable(false_imgs.float().permute(0, 3, 1, 2)).cuda()

                real_labels = Variable(real_labels).cuda()
                smooth_real_labels = Variable(smooth_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # ---------------------------------------------------------------
                # 2. Training the generator
                # ---------------------------------------------------------------
                self.gen.zero_grad()
                z = Variable(torch.randn(true_imgs.size(0), self.z_dim)).cuda()
                fake_imgs = self.gen(true_embed, z)
                fake_out, fake_logit = self.disc(fake_imgs, true_embed)

                gen_loss = self.criterion(fake_out, real_labels) + \
                    self.l1_coeff * F.l1_loss(fake_imgs, true_imgs)

                gen_loss.backward()
                self.gen_optim.step()

                # ---------------------------------------------------------------
                # 3. Training the discriminator
                # ---------------------------------------------------------------
                self.disc.zero_grad()
                false_out, false_logit = self.disc(false_imgs, true_embed)
                true_out, true_logit = self.disc(true_imgs, true_embed)
                # Re-score the generated images from a detached tensor so that no gradient
                # flows back into the generator during the discriminator update
                fake_out, fake_logit = self.disc(fake_imgs.detach(), true_embed)
                disc_loss = self.criterion(true_out, smooth_real_labels) + \
                    self.criterion(fake_out, fake_labels) + \
                    self.criterion(false_out, fake_labels)

                disc_loss.backward()
                self.disc_optim.step()

                # self.cls_gan_optim.step()

                # Logging
                loss = {}
                loss['G_loss'] = gen_loss.item()
                loss['D_loss'] = disc_loss.item()

                # ---------------------------------------------------------------
                # 4. Logging INFO into log_dir
                # ---------------------------------------------------------------
                if (idx + 1) % self.log_step == 0:
                    end_time = time.time() - start_time
                    end_time = datetime.timedelta(seconds=end_time)
                    log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format(end_time, epoch + 1,
                                                                         self.num_epochs, idx)

                    for net, loss_value in loss.items():
                        log += ", {}: {:.4f}".format(net, loss_value)
                    self.logger.info(log)
                    print (log)

                # ---------------------------------------------------------------
                # 5. Saving generated images
                # ---------------------------------------------------------------
                if (idx + 1) % self.sample_step == 0:
                    # Stack real and generated images along the height dimension
                    concat_imgs = torch.cat((true_imgs, fake_imgs), 2)
                    save_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(idx + 1))
                    # Both tensors are already in [0, 1]
                    concat_imgs = concat_imgs.clamp(0, 1)
                    save_image(concat_imgs.data.cpu(), save_path, nrow=1, padding=0)
                    print ('Saved real and fake images into {}...'.format(self.sample_dir))

                # ---------------------------------------------------------------
                # 6. Saving the checkpoints & final model
                # ---------------------------------------------------------------
                if (idx + 1) % self.model_save_step == 0:
                    G_path = os.path.join(self.checkpoint_dir, '{}-G.ckpt'.format(idx + 1))
                    D_path = os.path.join(self.checkpoint_dir, '{}-D.ckpt'.format(idx + 1))
                    torch.save(self.gen.state_dict(), G_path)
                    torch.save(self.disc.state_dict(), D_path)
                    print('Saved model checkpoints into {}...'.format(self.checkpoint_dir))

        print ('--------------- Model Training Completed ---------------')
        # Saving the final model into the final_model directory
        G_path = os.path.join(self.final_model, '{}-G.pth'.format('final'))
        D_path = os.path.join(self.final_model, '{}-D.pth'.format('final'))
        torch.save(self.gen.state_dict(), G_path)
        torch.save(self.disc.state_dict(), D_path)
        print('Saved final model into {}...'.format(self.final_model))
--------------------------------------------------------------------------------