├── data
│   ├── images
│   │   ├── sample
│   │   │   ├── s1_001.png
│   │   │   ├── s2_001.png
│   │   │   └── real_001.jpg
│   │   ├── Network Description
│   │   │   └── network description.jpg
│   │   ├── Stage-1 (102 flowers dataset)
│   │   │   ├── fake_samples_epoch_027.png
│   │   │   └── fake_samples_epoch_102.png
│   │   └── Stage-2 (102 flowers dataset)
│   │       ├── fake_samples_epoch_003.png
│   │       ├── fake_samples_epoch_058.png
│   │       └── fake_samples_epoch_160.png
│   ├── 102-flower dataset
│   │   └── README.MD
│   ├── Bert Embeddings
│   │   └── README.MD
│   └── pre-trained models
│       └── README.MD
├── bert_embeddings.py
├── README.md
├── dataset.py
├── utils.py
├── trainer.py
└── model.py

/data/images/sample/s1_001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/sample/s1_001.png
--------------------------------------------------------------------------------
/data/images/sample/s2_001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/sample/s2_001.png
--------------------------------------------------------------------------------
/data/images/sample/real_001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/sample/real_001.jpg
--------------------------------------------------------------------------------
/data/images/Network Description/network description.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Network Description/network description.jpg
--------------------------------------------------------------------------------
/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_027.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_027.png
--------------------------------------------------------------------------------
/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_102.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_102.png
--------------------------------------------------------------------------------
/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_003.png
--------------------------------------------------------------------------------
/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_058.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_058.png
--------------------------------------------------------------------------------
/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_160.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_160.png
--------------------------------------------------------------------------------
/data/102-flower dataset/README.MD:
--------------------------------------------------------------------------------
# Data

Download the 102-Flowers dataset from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/, or alternatively from https://drive.google.com/drive/folders/1-9I293J77J40IpUCLtoycjC1T_1QWmVR?usp=sharing.
--------------------------------------------------------------------------------
/data/Bert Embeddings/README.MD:
--------------------------------------------------------------------------------
# BERT Embeddings

Run our bert_embeddings.py script on your captions to generate the embeddings. Alternatively, download our pre-processed BERT embeddings for the 102-Flowers dataset from https://drive.google.com/file/d/1XiNtxey51c3V03Xe4ELrMbqoQ_gMpvGe/view?usp=sharing.
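
The CSV produced by `bert_embeddings.py` stores one 768-dimensional embedding per caption, with a pandas index in the first column (this is also the layout that `dataset.py` expects). A minimal loading sketch, assuming the default `embbedings1.csv` filename:

```python
import pandas as pd

# First column is the index written by to_csv; the remaining
# 768 columns are the BERT embedding for each caption.
df = pd.read_csv("embbedings1.csv")
embeddings = df.iloc[:, 1:].to_numpy()  # shape: (num_captions, 768)
print(embeddings.shape)
```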
--------------------------------------------------------------------------------
/data/pre-trained models/README.MD:
--------------------------------------------------------------------------------
# Pre-Trained Models

## Stage 1 Generator
Download from https://drive.google.com/file/d/1-F0IymmrNWoM33Fb2IbZhf4o41n5FbJN/view?usp=sharing

## Stage 1 Discriminator
Download from https://drive.google.com/file/d/1-KfgdzLwfMdVvA1HvEHslroNFFlJrdpo/view?usp=sharing

## Stage 2 Generator
Download from https://drive.google.com/file/d/1-YjOU7ALKcg6KZpOxPhPETUzcvWJduIj/view?usp=sharing

## Stage 2 Discriminator
Download from https://drive.google.com/file/d/1-ckPuRtTMKsdVMX6KdyLESVDxigXzHl_/view?usp=sharing

The pre-trained models are trained for 100 epochs each. Results can easily be improved by increasing the number of epochs.
--------------------------------------------------------------------------------
/bert_embeddings.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""BERT_Embeddings.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/13FXGWgWuqa4l3Dx7WM5wXPmZmShD2ffl
"""

# !pip install transformers  # (run once in Colab to install the dependency)

import torch
import pandas as pd
import numpy as np

from transformers import BertModel, BertTokenizer

df = pd.read_csv("caption_id.csv")
sentences = df.Caption.values

# One 768-dimensional embedding per caption (BERT-base hidden size).
embeddings = np.empty((sentences.shape[0], 768))

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

model.eval()
for i in range(sentences.shape[0]):
    input_sentence = torch.tensor(tokenizer.encode(sentences[i])).unsqueeze(0)
    with torch.no_grad():
        out = model(input_sentence)
    # Last hidden layer: (sequence_length, 768) after dropping the batch dim.
    token_embeddings = out[0][0]
    # Mean-pool over tokens to get a single sentence embedding.
    embeddings[i] = token_embeddings.numpy().mean(axis=0)

bc = pd.DataFrame(embeddings)
bc.to_csv('embbedings1.csv')  # filename matches what dataset.py loads
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# StackGAN with BERT Embeddings

Synthesizing high-quality images from text descriptions is a challenging problem in computer vision with many practical applications. Samples generated by existing text-to-image approaches can roughly reflect the meaning of the given descriptions, but they fail to capture necessary details and vivid object parts. In this project, we improve upon the existing Stacked Generative Adversarial Networks (StackGAN) by introducing BERT embeddings to generate 256×256 photo-realistic images conditioned on captions.
We divide the problem into two stages. The Stage-I GAN sketches the primitive shape and colours of the object based on the given text description, yielding low-resolution images. The Stage-II GAN takes the Stage-I results and the text descriptions as inputs and generates high-resolution images with photo-realistic details: it can rectify defects in the Stage-I results and add compelling details through its refinement process.
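
### Quick start

A minimal end-to-end training sketch (illustrative rather than a fixed CLI: it assumes the flower images sit in `data/jpg/`, with `filenames.csv` and the BERT embeddings CSV in the same folder, and that the loader's batch size matches `TRAIN_BATCH_SIZE` in `trainer.py`):

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataset import TextDataset
from trainer import GANTrainer

# Normalize images to [-1, 1] to match the generators' Tanh outputs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = TextDataset('data', imsize=64, transform=transform)
# drop_last=True because the trainer pre-allocates fixed-size label tensors.
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

trainer = GANTrainer(output_dir='output')
trainer.train(loader, stage=1)  # stage=2 trains the 256x256 refinement GAN
```

For Stage-II, raise `imsize` to 256 and point `gen` in `trainer.py` at a trained Stage-1 generator checkpoint.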

### Dependencies
- Python 3.6+
- PyTorch 1.6.0
- CUDA 10.1

In addition, please `pip install` the following packages:
- `numpy`
- `pandas`
- `torchfile`

## Sample case

### Caption

### Stage-1 Image

### Stage-2 Image
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


import torch.utils.data as data
from PIL import Image
import PIL
import os
import os.path
import pickle
import random
import numpy as np
import pandas as pd

#from miscc.config import cfg


class TextDataset(data.Dataset):
    def __init__(self, data_dir, split='jpg', embedding_type='embeddings1',
                 imsize=64, transform=None, target_transform=None):

        self.transform = transform
        self.target_transform = target_transform
        self.imsize = imsize
        self.data = []
        self.data_dir = data_dir
        split_dir = os.path.join(data_dir, split)

        self.filenames = self.load_filenames(split_dir)
        self.embeddings = self.load_embedding(split_dir, embedding_type)

    def get_img(self, img_path):
        img = Image.open(img_path).convert('RGB')
        width, height = img.size
        # load_size = int(self.imsize * 76 / 64)
        load_size = int(self.imsize)
        img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def load_all_captions(self):
        caption_dict = {}
        filepath = os.path.join(self.data_dir, 'caption_id.csv')
        cap = pd.read_csv(filepath)
        for key in self.filenames:
            caption_dict[key] = cap['Caption'][cap['image_id'] == key]
        return caption_dict

    def load_embedding(self, data_dir, embedding_type):
        # Filename matches the CSV written by bert_embeddings.py.
        embedding_filename = '/embbedings1.csv'
        f = pd.read_csv(data_dir + embedding_filename)
        # Drop the index column; keep the 768 embedding columns.
        embeddings = np.array(f.iloc[:, 1:])
        return embeddings

    def load_filenames(self, data_dir):
        filepath = os.path.join(data_dir, 'filenames.csv')
        filenames = np.array(pd.read_csv(filepath)['image_id'])
        return filenames

    def __getitem__(self, index):
        key = self.filenames[index]
        data_dir = '%s/jpg' % self.data_dir
        #captions = self.captions[key]
        embedding = self.embeddings[index, :]
        img_name = '%s/%s.jpg' % (data_dir, key)
        img = self.get_img(img_name)
        if self.target_transform is not None:
            embedding = self.target_transform(embedding)
        return img, embedding

    def __len__(self):
        return len(self.filenames)
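

# Usage sketch (illustrative; assumes 102-Flowers images under data/jpg/*.jpg,
# with filenames.csv and embbedings1.csv alongside them):
#
#   dataset = TextDataset('data', imsize=64, transform=some_transform)
#   img, embedding = dataset[0]  # image tensor and a 768-d caption vector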
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""utils.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nSgFzAcLjQbj94Ow9wJnp4fZ_Y-vlqM1
"""

import os
import errno
import numpy as np

from copy import deepcopy

from torch.nn import init
import torch
import torch.nn as nn
import torchvision.utils as vutils


#############################
def KL_loss(mu, logvar):
    # KL divergence between N(mu, sigma^2) and N(0, 1):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.mean(KLD_element).mul_(-0.5)
    return KLD


def compute_discriminator_loss(netD, real_imgs, fake_imgs,
                               real_labels, fake_labels,
                               conditions, gpus):
    criterion = nn.BCELoss()
    batch_size = real_imgs.size(0)
    cond = conditions.detach()
    fake = fake_imgs.detach()
    real_features = nn.parallel.data_parallel(netD, (real_imgs), gpus)
    fake_features = nn.parallel.data_parallel(netD, (fake), gpus)
    # real pairs: real image + matching text condition
    inputs = (real_features, cond)
    real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_real = criterion(real_logits, real_labels)
    # wrong pairs: real image + mismatched text condition (shifted by one)
    inputs = (real_features[:(batch_size-1)], cond[1:])
    wrong_logits = \
        nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_wrong = criterion(wrong_logits, fake_labels[1:])
    # fake pairs: generated image + matching text condition
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, fake_labels)

    if netD.get_uncond_logits is not None:
        real_logits = \
            nn.parallel.data_parallel(netD.get_uncond_logits,
                                      (real_features), gpus)
        fake_logits = \
            nn.parallel.data_parallel(netD.get_uncond_logits,
                                      (fake_features), gpus)
        uncond_errD_real = criterion(real_logits, real_labels)
        uncond_errD_fake = criterion(fake_logits, fake_labels)
        #
        errD = ((errD_real + uncond_errD_real) / 2. +
                (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
        errD_real = (errD_real + uncond_errD_real) / 2.
        errD_fake = (errD_fake + uncond_errD_fake) / 2.
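        # The conditional and unconditional terms are averaged so the
        # matching-aware (conditional) and image-realism (unconditional)
        # signals contribute equally; the "wrong pair" term only exists in
        # the conditional branch, hence the 3-way average on the negative
        # side.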
    else:
        errD = errD_real + (errD_fake + errD_wrong) * 0.5

    return errD, errD_real, errD_wrong, errD_fake
    # return errD, errD_real.data[0], errD_wrong.data[0], errD_fake.data[0]


def compute_generator_loss(netD, fake_imgs, real_labels, conditions, gpus):
    criterion = nn.BCELoss()
    cond = conditions.detach()
    fake_features = nn.parallel.data_parallel(netD, (fake_imgs), gpus)
    # fake pairs scored against *real* labels: the generator is rewarded
    # when the discriminator believes its samples are real
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, real_labels)
    if netD.get_uncond_logits is not None:
        fake_logits = \
            nn.parallel.data_parallel(netD.get_uncond_logits,
                                      (fake_features), gpus)
        uncond_errD_fake = criterion(fake_logits, real_labels)
        errD_fake += uncond_errD_fake
    return errD_fake


#############################
def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv/linear weights,
    # N(1, 0.02) for BatchNorm scale, zeros for biases.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0.0)


VIS_COUNT = 64
#############################
def save_img_results(data_img, fake, epoch, image_dir):
    num = VIS_COUNT

    fake = fake[0:num]
    # data_img is changed to [0,1]
    if data_img is not None:
        data_img = data_img[0:num]
        vutils.save_image(
            data_img, '%s/real_samples.png' % image_dir,
            normalize=True)
        # fake.data is still [-1, 1]
        vutils.save_image(
            fake.data, '%s/fake_samples_epoch_%03d.png' %
            (image_dir, epoch), normalize=True)
    else:
        vutils.save_image(
            fake.data, '%s/lr_fake_samples_epoch_%03d.png' %
            (image_dir, epoch), normalize=True)


def save_model(netG, netD, epoch, model_dir):
    torch.save(
        netG.state_dict(),
        '%s/netG_epoch_%d.pth' % (model_dir, epoch))
    torch.save(
        netD.state_dict(),
        '%s/netD_epoch_last.pth' % (model_dir))
    print('Save G/D models')


def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""trainer.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1vRvijAMJSVqUZyjifZGktpAlcUBfr0XP
"""

# Training configuration
NET_G = ''
TRAIN_FLAG = True
TRAIN_MAX_EPOCH = 500
TRAIN_SNAPSHOT_INTERVAL = 50
TRAIN_BATCH_SIZE = 32
GPU_ID = '0'
NET_D = ''
CUDA = True
# Path to a pre-trained Stage-1 generator (user-specific; needed for Stage-2)
gen = '/content/drive/My Drive/Model/netG_epoch_150.pth'
Z_DIM = 100

TRAIN_PRETRAINED_MODEL = ''
TRAIN_PRETRAINED_EPOCH = 600
TRAIN_LR_DECAY_EPOCH = 600
TRAIN_DISCRIMINATOR_LR = 2e-4
TRAIN_GENERATOR_LR = 2e-4
TRAIN_COEFF_KL = 2.0

# Commented out IPython magic to ensure Python compatibility.
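
# Usage sketch (illustrative; assumes a DataLoader built from
# dataset.TextDataset whose batch size matches TRAIN_BATCH_SIZE):
#
#   trainer = GANTrainer(output_dir='output')
#   trainer.train(data_loader, stage=1)   # Stage-1: 64x64 images
#   trainer.train(data_loader, stage=2)   # Stage-2: 256x256 refinement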
# from __future__ import print_function
from six.moves import range
from PIL import Image

import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import os
import time

import numpy as np
import torchfile

from utils import mkdir_p
from utils import weights_init
from utils import save_img_results, save_model
from utils import KL_loss
from utils import compute_discriminator_loss, compute_generator_loss

# from torch.utils.tensorboard import summary
# from torch.utils.tensorboard import FileWriter


class GANTrainer(object):
    def __init__(self, output_dir):
        if TRAIN_FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            # self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = TRAIN_MAX_EPOCH
        self.snapshot_interval = TRAIN_SNAPSHOT_INTERVAL

        s_gpus = GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = TRAIN_BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if NET_G != '':
            state_dict = \
                torch.load(NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', NET_G)
        if NET_D != '':
            state_dict = \
                torch.load(NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', NET_D)
        if CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN #############
    def load_network_stageII(self):
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if NET_G != '':
            state_dict = \
                torch.load(NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', NET_G)
        elif gen != '':
            # Initialize the frozen Stage-1 generator from the pre-trained
            # checkpoint configured via `gen` above. (The original compared
            # the STAGE1_G *class* to '', which is always true.)
            state_dict = torch.load(gen,
                                    map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', gen)
        else:
            print("Please give the Stage1_G path")
            return

        netD = STAGE2_D()
        netD.apply(weights_init)
        if NET_D != '':
            state_dict = \
                torch.load(NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', NET_D)
        print(netD)

        if CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # ('volatile' was removed from PyTorch; fixed_noise never needs
        # gradients, so a plain tensor suffices.)
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = TRAIN_GENERATOR_LR
        discriminator_lr = TRAIN_DISCRIMINATOR_LR
        lr_decay_step = TRAIN_LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=TRAIN_DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=TRAIN_GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0

        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                txt_embedding = txt_embedding.type(torch.FloatTensor)
                real_imgs = real_imgs.type(torch.FloatTensor)
                if CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
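                # Conditioning-augmentation regularizer: KL_loss pulls the
                # text-conditioned Gaussian N(mu, sigma^2) towards N(0, 1),
                # and TRAIN_COEFF_KL weighs it against the adversarial term,
                # encouraging a smooth latent conditioning manifold.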
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * TRAIN_COEFF_KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    print('D_loss', errD)
                    print('G_loss', errG)
                    print('KL_loss', kl_loss)
                    # summary_D = summary.scalar('D_loss', errD.data[0])
                    # summary_D_r = summary.scalar('D_loss_real', errD_real)
                    # summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    # summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    # summary_G = summary.scalar('G_loss', errG.data[0])
                    # summary_KL = summary.scalar('KL_loss', kl_loss.data[0])

                    # self.summary_writer.add_summary(summary_D, count)
                    # self.summary_writer.add_summary(summary_D_r, count)
                    # self.summary_writer.add_summary(summary_D_w, count)
                    # self.summary_writer.add_summary(summary_D_f, count)
                    # self.summary_writer.add_summary(summary_G, count)
                    # self.summary_writer.add_summary(summary_KL, count)

            # save the image result for each epoch
            inputs = (txt_embedding, fixed_noise)
            lr_fake, fake, _, _ = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD, errG, kl_loss,
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        save_model(netG, netD, self.max_epoch, self.model_dir)

        # self.summary_writer.close()  # writer is disabled above

    def sample(self, datapath, stage=1):
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = NET_G[:NET_G.find('.pth')]
        mkdir_p(save_dir)

        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            if iend > num_embeddings:
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            # captions_batch = captions_list[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # map from [-1, 1] (Tanh output) to [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""model.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1vqN158R5XSGjVnSSwN51okw90X9pk0VJ
"""

import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable

TEXT_DIMENSION = 768  # BERT-base sentence embedding size
GAN_CONDITION_DIM = 128
CUDA = True
GAN_GF_DIM = 128
GAN_DF_DIM = 64
Z_DIM = 100
GAN_R_NUM = 4


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# Upscale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, out_planes),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(True))
    return block


class ResBlock(nn.Module):
    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(True),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        out = self.relu(out)
        return out


class CA_NET(nn.Module):
    # some code is modified from vae examples
    # (https://github.com/pytorch/examples/blob/master/vae/main.py)
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = TEXT_DIMENSION
        self.c_dim = GAN_CONDITION_DIM
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
        self.relu = nn.ReLU()

    def encode(self, text_embedding):
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if CUDA:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar
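
# Conditioning augmentation in a nutshell: the 768-d BERT embedding is
# projected to (mu, logvar) of a 128-d Gaussian, and the condition code is
# sampled as c = mu + eps * sigma with eps ~ N(0, 1). Sampling through this
# reparameterization trick keeps the path differentiable w.r.t. mu/logvar,
# while KL_loss in utils.py keeps the Gaussian close to N(0, 1).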

class D_GET_LOGITS(nn.Module):
    def __init__(self, ndf, nef, bcondition=True):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf
        self.ef_dim = nef
        self.bcondition = bcondition
        if bcondition:
            self.outlogits = nn.Sequential(
                conv3x3(ndf * 8 + nef, ndf * 8),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid())
        else:
            self.outlogits = nn.Sequential(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid())

    def forward(self, h_code, c_code=None):
        # conditioning output
        if self.bcondition and c_code is not None:
            c_code = c_code.view(-1, self.ef_dim, 1, 1)
            c_code = c_code.repeat(1, 1, 4, 4)
            # state size (ngf+egf) x 4 x 4
            h_c_code = torch.cat((h_code, c_code), 1)
        else:
            h_c_code = h_code

        output = self.outlogits(h_c_code)
        return output.view(-1)


# ############# Networks for stageI GAN #############
class STAGE1_G(nn.Module):
    def __init__(self):
        super(STAGE1_G, self).__init__()
        self.gf_dim = GAN_GF_DIM * 8
        self.ef_dim = GAN_CONDITION_DIM
        self.z_dim = Z_DIM
        self.define_module()

    def define_module(self):
        ninput = self.z_dim + self.ef_dim
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        # -> ngf x 4 x 4
        self.fc = nn.Sequential(
            nn.Linear(ninput, ngf * 4 * 4, bias=False),
            nn.BatchNorm1d(ngf * 4 * 4),
            nn.ReLU(True))

        # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample1 = upBlock(ngf, ngf // 2)
        # -> ngf/4 x 16 x 16
        self.upsample2 = upBlock(ngf // 2, ngf // 4)
        # -> ngf/8 x 32 x 32
        self.upsample3 = upBlock(ngf // 4, ngf // 8)
        # -> ngf/16 x 64 x 64
        self.upsample4 = upBlock(ngf // 8, ngf // 16)
        # -> 3 x 64 x 64
        self.img = nn.Sequential(
            conv3x3(ngf // 16, 3),
            nn.Tanh())

    def forward(self, text_embedding, noise):
        c_code, mu, logvar = self.ca_net(text_embedding)
        z_c_code = torch.cat((noise, c_code), 1)
        h_code = self.fc(z_c_code)

        h_code = h_code.view(-1, self.gf_dim, 4, 4)
        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)
        # state size 3 x 64 x 64
        fake_img = self.img(h_code)
        return None, fake_img, mu, logvar
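
# Shape trace through STAGE1_G (batch size B):
#   text_embedding (B, 768) --CA_NET--> c_code (B, 128)
#   cat(noise (B, 100), c_code)       -> (B, 228)
#   fc                                -> (B, 1024*4*4) -> view (B, 1024, 4, 4)
#   upsample1..4 (each doubles H, W)  -> (B, 64, 64, 64)
#   conv3x3 + Tanh                    -> (B, 3, 64, 64) in [-1, 1]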

class STAGE1_D(nn.Module):
    def __init__(self):
        super(STAGE1_D, self).__init__()
        self.df_dim = GAN_DF_DIM
        self.ef_dim = GAN_CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size (ndf*8) x 4 x 4
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.get_cond_logits = D_GET_LOGITS(ndf, nef)
        self.get_uncond_logits = None

    def forward(self, image):
        img_embedding = self.encode_img(image)

        return img_embedding


# ############# Networks for stageII GAN #############
class STAGE2_G(nn.Module):
    def __init__(self, STAGE1_G):
        super(STAGE2_G, self).__init__()
        self.gf_dim = GAN_GF_DIM
        self.ef_dim = GAN_CONDITION_DIM
        self.z_dim = Z_DIM
        self.STAGE1_G = STAGE1_G
        # fix parameters of stageI GAN
        for param in self.STAGE1_G.parameters():
            param.requires_grad = False
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(GAN_R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()
        # --> 4ngf x 16 x 16
        self.encoder = nn.Sequential(
            conv3x3(3, ngf),
            nn.ReLU(True),
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True))
        self.hr_joint = nn.Sequential(
            conv3x3(self.ef_dim + ngf * 4, ngf * 4),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True))
        self.residual = self._make_layer(ResBlock, ngf * 4)
        # --> 2ngf x 32 x 32
        self.upsample1 = upBlock(ngf * 4, ngf * 2)
        # --> ngf x 64 x 64
        self.upsample2 = upBlock(ngf * 2, ngf)
        # --> ngf // 2 x 128 x 128
        self.upsample3 = upBlock(ngf, ngf // 2)
        # --> ngf // 4 x 256 x 256
        self.upsample4 = upBlock(ngf // 2, ngf // 4)
        # --> 3 x 256 x 256
        self.img = nn.Sequential(
            conv3x3(ngf // 4, 3),
            nn.Tanh())

    def forward(self, text_embedding, noise):
        # Stage-1 output is treated as a fixed input; gradients do not
        # flow back into the frozen Stage-1 generator.
        _, stage1_img, _, _ = self.STAGE1_G(text_embedding, noise)
        stage1_img = stage1_img.detach()
        encoded_img = self.encoder(stage1_img)

        c_code, mu, logvar = self.ca_net(text_embedding)
        c_code = c_code.view(-1, self.ef_dim, 1, 1)
        c_code = c_code.repeat(1, 1, 16, 16)
        i_c_code = torch.cat([encoded_img, c_code], 1)
        h_code = self.hr_joint(i_c_code)
        h_code = self.residual(h_code)

        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)

        fake_img = self.img(h_code)
        return stage1_img, fake_img, mu, logvar
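
# STAGE2_G design note: the 64x64 Stage-1 image is downsampled to a 16x16
# feature map, the 128-d condition code is tiled to 16x16 and concatenated
# channel-wise, and four residual blocks fuse image and text features
# before four upBlocks bring the result up to 256x256.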

class STAGE2_D(nn.Module):
    def __init__(self):
        super(STAGE2_D, self).__init__()
        self.df_dim = GAN_DF_DIM
        self.ef_dim = GAN_CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),  # 128 x 128 x ndf
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),  # 64 x 64 x ndf*2
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),  # 32 x 32 x ndf*4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),  # 16 x 16 x ndf*8
            nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),  # 8 x 8 x ndf*16
            nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 32),
            nn.LeakyReLU(0.2, inplace=True),  # 4 x 4 x ndf*32
            conv3x3(ndf * 32, ndf * 16),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),  # 4 x 4 x ndf*16
            conv3x3(ndf * 16, ndf * 8),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)  # 4 x 4 x ndf*8
        )

        self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True)
        self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False)

    def forward(self, image):
        img_embedding = self.encode_img(image)

        return img_embedding
--------------------------------------------------------------------------------