├── .gitignore ├── LICENSE ├── README.md ├── dataset.py ├── main.py ├── model.py ├── netD_final.pth ├── netG_final.pth ├── report.pdf └── results ├── a_long_bodied_bird,_with_a_black_face_and_upper_body,_with_a_yellow_underbelly_and_lower_body_0.jpg ├── a_medium_bird_with_a_white_belly,_black_rectrices,_and_a_bright_yellow_throat_0.jpg ├── a_short_billed,_fully_breasted_bird_with_a_beautifully_contrasting_crown,_this_green_and_gold_b_0.jpg ├── a_skinny_white_bird_with_brown_and_black_nape_and_back_sits_perched_on_a_thin_branch_0.jpg ├── a_small_bird_has_a_small_sharp_bill,_a_spotted_crown,_and_legs_with_a_large_tarsus_0.jpg ├── a_small_bird_with_a_blue_breast_and_a_grey_belly_with_a_sharp_small_beak_0.jpg ├── a_small_bird_with_a_bright_green_belly_and_dark_green_body_covering_0.jpg ├── a_small_bird_with_a_long_pointed_bill,_feathers_along_its_belly_are_gray,_white,_and_brown_0.jpg ├── fake_samples_epoch_243.png ├── fake_samples_epoch_377.png └── this_bird_has_a_fluffy_gray_crown_and_gray_and_white_striped_primaries_and_secondaries_0.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jay 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-to-Image-Synthesis 2 | PyTorch implementation of the paper 'Generative Adversarial Text to Image Synthesis' (ICML 2016) [ https://arxiv.org/abs/1605.05396 ] 3 | 4 | Please refer to report.pdf for an overview of the process and details of the implementation. 5 | 6 | The original paper implements four methods: 7 | - GAN 8 | - GAN-CLS [Matching-aware discriminator] 9 | - GAN-INT [Learning with manifold interpolation] 10 | - GAN-INT-CLS [Combination of the above two methods] 11 | 12 | However, this repository contains only the PyTorch implementation of GAN-CLS. 13 | 14 | Please refer to the original paper [ https://arxiv.org/abs/1605.05396 ] for details of the other methods.
15 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import PIL 10 | import os 11 | import os.path 12 | import pickle 13 | import random 14 | import numpy as np 15 | import pandas as pd 16 | 17 | 18 | class TextDataset(data.Dataset): 19 | def __init__(self, data_dir, split='train', embedding_type='cnn-rnn', 20 | imsize=64, transform=None, target_transform=None): 21 | 22 | self.transform = transform 23 | self.target_transform = target_transform 24 | self.imsize = imsize 25 | self.data = [] 26 | self.data_dir = data_dir 27 | if data_dir.find('birds') != -1: 28 | self.bbox = self.load_bbox() 29 | else: 30 | self.bbox = None 31 | split_dir = os.path.join(data_dir, split) 32 | 33 | self.filenames = self.load_filenames(split_dir) 34 | self.embeddings = self.load_embedding(split_dir, embedding_type) 35 | self.captions = self.load_all_captions() 36 | 37 | def get_img(self, img_path, bbox): 38 | img = Image.open(img_path).convert('RGB') 39 | width, height = img.size 40 | if bbox is not None: 41 | R = int(np.maximum(bbox[2], bbox[3]) * 0.75) 42 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 43 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 44 | y1 = np.maximum(0, center_y - R) 45 | y2 = np.minimum(height, center_y + R) 46 | x1 = np.maximum(0, center_x - R) 47 | x2 = np.minimum(width, center_x + R) 48 | img = img.crop([x1, y1, x2, y2]) 49 | load_size = int(self.imsize * 76 / 64) 50 | img = img.resize((load_size, load_size), PIL.Image.BILINEAR) 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | return img 54 | 55 | def load_bbox(self): 56 | data_dir = self.data_dir 57 | bbox_path = 
os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') 58 | df_bounding_boxes = pd.read_csv(bbox_path, 59 | delim_whitespace=True, 60 | header=None).astype(int) 61 | # 62 | filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') 63 | df_filenames = \ 64 | pd.read_csv(filepath, delim_whitespace=True, header=None) 65 | filenames = df_filenames[1].tolist() 66 | print('Total filenames: ', len(filenames), filenames[0]) 67 | # 68 | filename_bbox = {img_file[:-4]: [] for img_file in filenames} 69 | numImgs = len(filenames) 70 | for i in xrange(0, numImgs): 71 | # bbox = [x-left, y-top, width, height] 72 | bbox = df_bounding_boxes.iloc[i][1:].tolist() 73 | 74 | key = filenames[i][:-4] 75 | filename_bbox[key] = bbox 76 | # 77 | return filename_bbox 78 | 79 | def load_all_captions(self): 80 | caption_dict = {} 81 | for key in self.filenames: 82 | caption_name = '%s/text/%s.txt' % (self.data_dir, key) 83 | captions = self.load_captions(caption_name) 84 | caption_dict[key] = captions 85 | return caption_dict 86 | 87 | def load_captions(self, caption_name): 88 | cap_path = caption_name 89 | with open(cap_path, "r") as f: 90 | captions = f.read().decode('utf8').split('\n') 91 | captions = [cap.replace("\ufffd\ufffd", " ") 92 | for cap in captions if len(cap) > 0] 93 | return captions 94 | 95 | def load_embedding(self, data_dir, embedding_type): 96 | if embedding_type == 'cnn-rnn': 97 | embedding_filename = '/char-CNN-RNN-embeddings.pickle' 98 | elif embedding_type == 'cnn-gru': 99 | embedding_filename = '/char-CNN-GRU-embeddings.pickle' 100 | elif embedding_type == 'skip-thought': 101 | embedding_filename = '/skip-thought-embeddings.pickle' 102 | 103 | with open(data_dir + embedding_filename, 'rb') as f: 104 | embeddings = pickle.load(f) 105 | embeddings = np.array(embeddings) 106 | # embedding_shape = [embeddings.shape[-1]] 107 | print('embeddings: ', embeddings.shape) 108 | return embeddings 109 | 110 | def load_filenames(self, data_dir): 111 | filepath = 
os.path.join(data_dir, 'filenames.pickle') 112 | with open(filepath, 'rb') as f: 113 | filenames = pickle.load(f) 114 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 115 | return filenames 116 | 117 | def __getitem__(self, index): 118 | key = self.filenames[index] 119 | 120 | if self.bbox is not None: 121 | bbox = self.bbox[key] 122 | data_dir = '%s/CUB_200_2011' % self.data_dir 123 | else: 124 | bbox = None 125 | data_dir = self.data_dir 126 | 127 | captions = self.captions[key] 128 | embeddings = self.embeddings[index, :, :] 129 | img_name = '%s/images/%s.jpg' % (data_dir, key) 130 | img = self.get_img(img_name, bbox) 131 | 132 | rand_ix = random.randint(0, embeddings.shape[0]-1) 133 | embedding = embeddings[rand_ix, :] 134 | if self.target_transform is not None: 135 | embedding = self.target_transform(embedding) 136 | return img, embedding,captions[rand_ix] 137 | 138 | def __len__(self): 139 | return len(self.filenames) 140 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import StepLR 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.optim.lr_scheduler import StepLR 16 | from torch.autograd import Variable 17 | from datetime import datetime 18 | import model 19 | from dataset import TextDataset 20 | 21 | import pdb 22 | pdb.set_trace() 23 | 24 | parser = argparse.ArgumentParser() 25 | # parser.add_argument( 26 | # '--dataset', 27 | # required=True, 28 | # default='folder', 29 | # help='cifar10 | lsun | imagenet | folder | lfw 
| fake') 30 | parser.add_argument( 31 | '--dataroot', required=True, default='./data/coco', help='path to dataset') 32 | parser.add_argument( 33 | '--workers', type=int, help='number of data loading workers', default=2) 34 | parser.add_argument( 35 | '--batchSize', type=int, default=64, help='input batch size') 36 | parser.add_argument( 37 | '--imageSize', 38 | type=int, 39 | default=64, 40 | help='the height / width of the input image to network') 41 | parser.add_argument( 42 | '--nte', 43 | type=int, 44 | default=1024, 45 | help='the size of the text embedding vector') 46 | parser.add_argument( 47 | '--nt', 48 | type=int, 49 | default=256, 50 | help='the reduced size of the text embedding vector') 51 | parser.add_argument( 52 | '--nz', type=int, default=100, help='size of the latent z vector') 53 | parser.add_argument('--ngf', type=int, default=64) 54 | parser.add_argument('--ndf', type=int, default=64) 55 | parser.add_argument( 56 | '--niter', type=int, default=25, help='number of epochs to train for') 57 | parser.add_argument( 58 | '--lr', type=float, default=0.0002, help='learning rate, default=0.0002') 59 | parser.add_argument( 60 | '--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 61 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 62 | parser.add_argument( 63 | '--ngpu', type=int, default=1, help='number of GPUs to use') 64 | parser.add_argument( 65 | '--netG', default='', help="path to netG (to continue training)") 66 | parser.add_argument( 67 | '--netD', default='', help="path to netD (to continue training)") 68 | parser.add_argument( 69 | '--outf', 70 | default='./output/', 71 | help='folder to output images and model checkpoints') 72 | parser.add_argument('--manualSeed', type=int, help='manual seed') 73 | parser.add_argument( 74 | '--eval', 75 | action='store_true', 76 | help="choose whether to train the model or show demo") 77 | opt = parser.parse_args() 78 | print(opt) 79 | 80 | try: 81 | output_dir = os.path.join(opt.outf, 82 | datetime.strftime(datetime.now(), "%Y%m%d_%H%M")) 83 | os.makedirs(output_dir) 84 | except OSError: 85 | pass 86 | 87 | if opt.manualSeed is None: 88 | opt.manualSeed = random.randint( 89 | 1, 10000 90 | ) #use random.randint(1, 10000) for randomness, shouldnt be done when we want to continue training from a checkpoint 91 | print("Random Seed: ", opt.manualSeed) 92 | random.seed(opt.manualSeed) 93 | torch.manual_seed(opt.manualSeed) 94 | if opt.cuda: 95 | torch.cuda.manual_seed_all(opt.manualSeed) 96 | 97 | cudnn.benchmark = True 98 | 99 | if torch.cuda.is_available() and not opt.cuda: 100 | print( 101 | "WARNING: You have a CUDA device, so you should probably run with --cuda" 102 | ) 103 | 104 | image_transform = transforms.Compose([ 105 | transforms.RandomCrop(opt.imageSize), 106 | transforms.RandomHorizontalFlip(), 107 | transforms.ToTensor(), 108 | transforms.Normalize((0, 0, 0), (1, 1, 1)) 109 | ]) 110 | 111 | ngpu = int(opt.ngpu) 112 | nz = int(opt.nz) 113 | ngf = int(opt.ngf) 114 | ndf = int(opt.ndf) 115 | nc = 3 116 | nt = int(opt.nt) 117 | nte = int(opt.nte) 118 | 119 | 120 | # custom weights initialization called on netG and netD 121 | def weights_init(m): 122 | 
classname = m.__class__.__name__ 123 | if classname.find('Conv') != -1: 124 | m.weight.data.normal_(0.0, 0.02) 125 | # m.bias.data.fill_(0) 126 | elif classname.find('BatchNorm') != -1: 127 | m.weight.data.normal_(1.0, 0.02) 128 | m.bias.data.fill_(0) 129 | 130 | 131 | netG = model._netG(ngpu, nz, ngf, nc, nte, nt) 132 | netG.apply(weights_init) 133 | if opt.netG != '': 134 | netG.load_state_dict(torch.load(opt.netG)) 135 | print(netG) 136 | 137 | netD = model._netD(ngpu, nc, ndf, nte, nt) 138 | netD.apply(weights_init) 139 | if opt.netD != '': 140 | netD.load_state_dict(torch.load(opt.netD)) 141 | print(netD) 142 | 143 | criterion = nn.BCELoss() 144 | 145 | input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize) 146 | noise = torch.FloatTensor(opt.batchSize, nz, 1, 1) 147 | fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1) 148 | label = torch.FloatTensor(opt.batchSize) 149 | real_label = 1 150 | fake_label = 0 151 | 152 | if opt.cuda: 153 | netD.cuda() 154 | netG.cuda() 155 | criterion.cuda() 156 | input, label = input.cuda(), label.cuda() 157 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 158 | 159 | fixed_noise = Variable(fixed_noise) 160 | 161 | if not opt.eval: 162 | 163 | train_dataset = TextDataset(opt.dataroot, transform=image_transform) 164 | 165 | ## Completed - TODO: Make a new DataLoader and Dataset to include embeddings 166 | train_dataloader = torch.utils.data.DataLoader( 167 | train_dataset, 168 | batch_size=opt.batchSize, 169 | shuffle=True, 170 | num_workers=int(opt.workers)) 171 | 172 | # setup optimizer 173 | optimizerD = optim.Adam( 174 | netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 175 | optimizerG = optim.Adam( 176 | netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 177 | 178 | ## Completed TODO: Change the error loss function to include embeddings [refer main_cls.lua on the original paper repo] 179 | 180 | for epoch in range(1, opt.niter + 1): 181 | if epoch % 75 == 0: 182 | 
optimizerG.param_groups[0]['lr'] /= 2 183 | optimizerD.param_groups[0]['lr'] /= 2 184 | for i, data in enumerate(train_dataloader, 0): 185 | ############################ 186 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 187 | ########################### 188 | # train with real 189 | netD.zero_grad() 190 | real_cpu, text_embedding, _ = data 191 | batch_size = real_cpu.size(0) 192 | text_embedding = Variable(text_embedding) 193 | 194 | if opt.cuda: 195 | real_cpu = real_cpu.cuda() 196 | text_embedding = text_embedding.cuda() 197 | 198 | input.resize_as_(real_cpu).copy_(real_cpu) 199 | label.resize_(batch_size).fill_(real_label) 200 | inputv = Variable(input) 201 | labelv = Variable(label) 202 | 203 | output = netD(inputv, text_embedding) 204 | errD_real = criterion(output, labelv) ## 205 | errD_real.backward() 206 | D_x = output.data.mean() 207 | 208 | ### calculate errD_wrong 209 | inputv = torch.cat((inputv[1:], inputv[:1]), 0) 210 | output = netD(inputv, text_embedding) 211 | errD_wrong = criterion(output, labelv) * 0.5 212 | errD_wrong.backward() 213 | 214 | # train with fake 215 | noise.resize_(batch_size, nz, 1, 1).normal_(0, 1) 216 | noisev = Variable(noise) 217 | fake = netG(noisev, text_embedding) 218 | labelv = Variable(label.fill_(fake_label)) 219 | output = netD(fake.detach(), text_embedding) 220 | errD_fake = criterion(output, labelv) * 0.5 221 | errD_fake.backward() 222 | D_G_z1 = output.data.mean() 223 | 224 | errD = errD_real + errD_fake + errD_wrong 225 | # errD.backward() 226 | optimizerD.step() 227 | 228 | ############################ 229 | # (2) Update G network: maximize log(D(G(z))) 230 | ########################### 231 | netG.zero_grad() 232 | labelv = Variable(label.fill_( 233 | real_label)) # fake labels are real for generator cost 234 | output = netD(fake, text_embedding) 235 | errG = criterion(output, labelv) ## 236 | errG.backward() 237 | D_G_z2 = output.data.mean() 238 | optimizerG.step() 239 | 240 | print( 241 | 
'[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 242 | % (epoch, opt.niter, i, len(train_dataloader), errD.data[0], 243 | errG.data[0], D_x, D_G_z1, D_G_z2)) 244 | if i % 100 == 0: 245 | vutils.save_image( 246 | real_cpu, '%s/real_samples.png' % output_dir, normalize=True) 247 | fake = netG(fixed_noise, text_embedding) 248 | vutils.save_image( 249 | fake.data, 250 | '%s/fake_samples_epoch_%03d.png' % (output_dir, epoch), 251 | normalize=True) 252 | 253 | # do checkpointing 254 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, 255 | epoch)) 256 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, 257 | epoch)) 258 | 259 | else: 260 | test_dataset = TextDataset(opt.dataroot, transform=image_transform,split='test') 261 | 262 | ## Completed - TODO: Make a new DataLoader and Dataset to include embeddings 263 | test_dataloader = torch.utils.data.DataLoader( 264 | test_dataset, 265 | batch_size=opt.batchSize, 266 | shuffle=True, 267 | num_workers=int(opt.workers)) 268 | 269 | for i, data in enumerate(test_dataloader, 0): 270 | real_image, text_embedding,caption = data 271 | batch_size = real_image.size(0) 272 | text_embedding = Variable(text_embedding) 273 | 274 | if opt.cuda: 275 | real_image = real_image.cuda() 276 | text_embedding = text_embedding.cuda() 277 | 278 | input.resize_as_(real_image).copy_(real_image) 279 | inputv = Variable(input) 280 | 281 | noise.resize_(batch_size, nz, 1, 1).normal_(0, 1) 282 | noisev = Variable(noise) 283 | num_test_outputs = 10 284 | 285 | 286 | # for count in range(num_test_outputs): 287 | # print (count) 288 | count =0 289 | print (i) 290 | synthetic_image = netG(noisev, text_embedding) 291 | synthetic_image = synthetic_image.detach() 292 | for i in range(synthetic_image.size()[0]): 293 | cap = caption[i].strip(".") 294 | cap = cap.replace("/"," or ") 295 | cap = cap.replace(" ","_") 296 | if len(cap) > 95: 297 | cap = cap[:95] 298 | file_path = './eval_results/'+cap 299 
| # if not os.path.exists(file_path): 300 | # os.makedirs(file_path) 301 | try: 302 | vutils.save_image(synthetic_image[i].data,file_path+'_'+str(count)+'.jpg') 303 | # vutils.save_image(synthetic_image[i].data,os.path.join(file_path,str(count)+'.jpg')) 304 | except e: 305 | print (e) 306 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ## Completed - TODO: Change the models to include text embeddings 5 | ## Completed - TODO: Add FC to reduce the text_embedding to the size of nt 6 | class _netG(nn.Module): 7 | def __init__(self, ngpu, nz, ngf, nc, nte, nt): 8 | super(_netG, self).__init__() 9 | self.nt = nt 10 | self.ngpu = ngpu 11 | self.main = nn.Sequential( 12 | # input is Z, going into a convolution 13 | nn.ConvTranspose2d(nz + nt, ngf * 8, 4, 1, 0, bias=False), 14 | nn.BatchNorm2d(ngf * 8), 15 | # nn.ReLU(True), 16 | # state size. (ngf*8) x 4 x 4 17 | 18 | # Completed - TODO: check out paper's code and add layers if required 19 | 20 | ##there are more conv2d layers involved here in 21 | # https://github.com/reedscot/icml2016/blob/master/main_cls.lua 22 | 23 | nn.Conv2d(ngf*8,ngf*2,1,1), 24 | nn.Dropout2d(inplace=True), 25 | nn.BatchNorm2d(ngf * 2), 26 | nn.ReLU(True), 27 | # nn.SELU(True), 28 | 29 | nn.Conv2d(ngf*2,ngf*2,3,1,1), 30 | nn.Dropout2d(inplace=True), 31 | nn.BatchNorm2d(ngf * 2), 32 | nn.ReLU(True), 33 | # nn.SELU(True), 34 | 35 | nn.Conv2d(ngf*2,ngf*8,3,1,1), 36 | nn.Dropout2d(inplace=True), 37 | nn.BatchNorm2d(ngf * 8), 38 | nn.ReLU(inplace=True), 39 | # nn.SELU(True), 40 | 41 | 42 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 43 | nn.BatchNorm2d(ngf * 4), 44 | # nn.ReLU(True), 45 | # state size. 
(ngf*4) x 8 x 8 46 | 47 | # Completed - TODO: check out paper's code and add layers if required 48 | 49 | ##there are more conv2d layers involved here in 50 | # https://github.com/reedscot/icml2016/blob/master/main_cls.lua 51 | 52 | 53 | nn.Conv2d(ngf*4,ngf,1,1), 54 | nn.Dropout2d(inplace=True), 55 | nn.BatchNorm2d(ngf), 56 | nn.ReLU(True), 57 | # nn.SELU(True), 58 | 59 | nn.Conv2d(ngf,ngf,3,1,1), 60 | nn.Dropout2d(inplace=True), 61 | nn.BatchNorm2d(ngf), 62 | nn.ReLU(True), 63 | # nn.SELU(True), 64 | 65 | nn.Conv2d(ngf,ngf*4,3,1,1), 66 | nn.Dropout2d(inplace=True), 67 | nn.BatchNorm2d(ngf * 4), 68 | nn.ReLU(True), 69 | # nn.SELU(True), 70 | 71 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 72 | nn.BatchNorm2d(ngf * 2), 73 | nn.ReLU(True), 74 | # nn.SELU(True), 75 | 76 | # state size. (ngf*2) x 16 x 16 77 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 78 | nn.BatchNorm2d(ngf), 79 | nn.ReLU(True), 80 | # nn.SELU(True), 81 | 82 | # state size. (ngf) x 32 x 32 83 | nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), 84 | nn.Tanh() 85 | # state size. 
(nc) x 64 x 64 86 | ) 87 | 88 | self.encode_text = nn.Sequential( 89 | nn.Linear(nte, nt), nn.LeakyReLU(0.2, inplace=True)) 90 | 91 | def forward(self, input, text_embedding): 92 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 93 | encoded_text = nn.parallel.data_parallel(self.encode_text, text_embedding, ) 94 | input_new = torch.cat((input, encoded_text)) 95 | output = nn.parallel.data_parallel(self.main,input_new, range(self.ngpu)) 96 | else: 97 | encoded_text = self.encode_text(text_embedding).view(-1,self.nt,1,1) 98 | output = self.main(torch.cat((input, encoded_text), 1)) 99 | return output 100 | 101 | ## Completed - TODO: pass nt and text_embedding size to the G and D and add FC to reduce text_embedding_size to nt 102 | class _netD(nn.Module): 103 | def __init__(self, ngpu, nc, ndf, nte, nt): 104 | super(_netD, self).__init__() 105 | self.ngpu = ngpu 106 | self.nt = nt 107 | self.nte = nte 108 | self.main = nn.Sequential( 109 | # input is (nc) x 64 x 64 110 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 111 | nn.LeakyReLU(0.2, inplace=True), 112 | # state size. (ndf) x 32 x 32 113 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 114 | nn.BatchNorm2d(ndf * 2), 115 | nn.LeakyReLU(0.2, inplace=True), 116 | # state size. (ndf*2) x 16 x 16 117 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 118 | nn.BatchNorm2d(ndf * 4), 119 | nn.LeakyReLU(0.2, inplace=True), 120 | # state size. (ndf*4) x 8 x 8 121 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 122 | nn.BatchNorm2d(ndf * 8), 123 | 124 | nn.Conv2d(ndf*8,ndf*2,1,1), 125 | # nn.Dropout2d(inplace=True), 126 | nn.BatchNorm2d(ndf * 2), 127 | nn.LeakyReLU(0.2, inplace=True), 128 | 129 | nn.Conv2d(ndf*2,ndf*2,3,1,1), 130 | # nn.Dropout2d(inplace=True), 131 | nn.BatchNorm2d(ndf * 2), 132 | nn.LeakyReLU(0.2, inplace=True), 133 | 134 | nn.Conv2d(ndf*2,ndf*8,3,1,1), 135 | # nn.Dropout2d(inplace=True), 136 | nn.BatchNorm2d(ndf * 8), 137 | nn.LeakyReLU(0.2, inplace=True)) 138 | 139 | # state size. 
(ndf*8) x 4 x 4 140 | 141 | ## add another sequential plot after this line to add the embedding and process it to find a single ans 142 | # Completed - TODO: confirm if what we are doing is same as given in paper code 143 | self.encode_text = nn.Sequential( 144 | nn.Linear(nte, nt), 145 | nn.LeakyReLU(0.2, inplace=True) 146 | 147 | ) 148 | 149 | self.concat_image_n_text = nn.Sequential( 150 | nn.Conv2d(ndf * 8 + nt, ndf * 8, 1, 1, 0, bias=False), ## TODO: Might want to change the kernel size and stride 151 | nn.BatchNorm2d(ndf*8), 152 | nn.LeakyReLU(0.2,inplace=True), 153 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 154 | nn.Sigmoid() 155 | ) 156 | 157 | def forward(self, input, text_embedding): 158 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 159 | encoded_img = nn.parallel.data_parallel(self.main, input, 160 | range(self.ngpu)) 161 | encoded_text = nn.parallel.data_parallel(self.encode_text, text_embedding, range(self.ngpu)) 162 | ## add the same things as in the else part 163 | else: 164 | encoded_img = self.main(input) 165 | encoded_text = self.encode_text(text_embedding) 166 | encoded_text = encoded_text.view(-1, self.nt, 1,1) 167 | encoded_text = encoded_text.repeat(1, 1, 4, 4) ## can also directly expand, look into the syntax 168 | output = self.concat_image_n_text(torch.cat((encoded_img, encoded_text),1)) 169 | 170 | return output.view(-1, 1).squeeze(1) 171 | -------------------------------------------------------------------------------- /netD_final.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/netD_final.pth -------------------------------------------------------------------------------- /netG_final.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/netG_final.pth -------------------------------------------------------------------------------- /report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/report.pdf -------------------------------------------------------------------------------- /results/a_long_bodied_bird,_with_a_black_face_and_upper_body,_with_a_yellow_underbelly_and_lower_body_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_long_bodied_bird,_with_a_black_face_and_upper_body,_with_a_yellow_underbelly_and_lower_body_0.jpg -------------------------------------------------------------------------------- /results/a_medium_bird_with_a_white_belly,_black_rectrices,_and_a_bright_yellow_throat_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_medium_bird_with_a_white_belly,_black_rectrices,_and_a_bright_yellow_throat_0.jpg -------------------------------------------------------------------------------- /results/a_short_billed,_fully_breasted_bird_with_a_beautifully_contrasting_crown,_this_green_and_gold_b_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_short_billed,_fully_breasted_bird_with_a_beautifully_contrasting_crown,_this_green_and_gold_b_0.jpg -------------------------------------------------------------------------------- 
/results/a_skinny_white_bird_with_brown_and_black_nape_and_back_sits_perched_on_a_thin_branch_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_skinny_white_bird_with_brown_and_black_nape_and_back_sits_perched_on_a_thin_branch_0.jpg -------------------------------------------------------------------------------- /results/a_small_bird_has_a_small_sharp_bill,_a_spotted_crown,_and_legs_with_a_large_tarsus_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_small_bird_has_a_small_sharp_bill,_a_spotted_crown,_and_legs_with_a_large_tarsus_0.jpg -------------------------------------------------------------------------------- /results/a_small_bird_with_a_blue_breast_and_a_grey_belly_with_a_sharp_small_beak_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_small_bird_with_a_blue_breast_and_a_grey_belly_with_a_sharp_small_beak_0.jpg -------------------------------------------------------------------------------- /results/a_small_bird_with_a_bright_green_belly_and_dark_green_body_covering_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_small_bird_with_a_bright_green_belly_and_dark_green_body_covering_0.jpg -------------------------------------------------------------------------------- /results/a_small_bird_with_a_long_pointed_bill,_feathers_along_its_belly_are_gray,_white,_and_brown_0.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/a_small_bird_with_a_long_pointed_bill,_feathers_along_its_belly_are_gray,_white,_and_brown_0.jpg -------------------------------------------------------------------------------- /results/fake_samples_epoch_243.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/fake_samples_epoch_243.png -------------------------------------------------------------------------------- /results/fake_samples_epoch_377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/fake_samples_epoch_377.png -------------------------------------------------------------------------------- /results/this_bird_has_a_fluffy_gray_crown_and_gray_and_white_striped_primaries_and_secondaries_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jayybhatt/Text-to-Image-Synthesis/d51b26784657a3e1386462f8c5208a3b7095a6eb/results/this_bird_has_a_fluffy_gray_crown_and_gray_and_white_striped_primaries_and_secondaries_0.jpg --------------------------------------------------------------------------------