├── data
│   ├── images
│   │   ├── sample
│   │   │   ├── s1_001.png
│   │   │   ├── s2_001.png
│   │   │   └── real_001.jpg
│   │   ├── Network Description
│   │   │   └── network description.jpg
│   │   ├── Stage-1 (102 flowers dataset)
│   │   │   ├── fake_samples_epoch_027.png
│   │   │   └── fake_samples_epoch_102.png
│   │   └── Stage-2 (102 flowers dataset)
│   │       ├── fake_samples_epoch_003.png
│   │       ├── fake_samples_epoch_058.png
│   │       └── fake_samples_epoch_160.png
│   ├── 102-flower dataset
│   │   └── README.MD
│   ├── Bert Embeddings
│   │   └── README.MD
│   └── pre-trained models
│       └── README.MD
├── bert_embeddings.py
├── README.md
├── dataset.py
├── utils.py
├── trainer.py
└── model.py
/data/images/sample/s1_001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/sample/s1_001.png
--------------------------------------------------------------------------------
/data/images/sample/s2_001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/sample/s2_001.png
--------------------------------------------------------------------------------
/data/images/sample/real_001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/sample/real_001.jpg
--------------------------------------------------------------------------------
/data/images/Network Description/network description.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Network Description/network description.jpg
--------------------------------------------------------------------------------
/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_027.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_027.png
--------------------------------------------------------------------------------
/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_102.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-1 (102 flowers dataset)/fake_samples_epoch_102.png
--------------------------------------------------------------------------------
/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_003.png
--------------------------------------------------------------------------------
/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_058.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_058.png
--------------------------------------------------------------------------------
/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_160.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-khanna/stackGAN-text-to-image-synthesis-/HEAD/data/images/Stage-2 (102 flowers dataset)/fake_samples_epoch_160.png
--------------------------------------------------------------------------------
/data/102-flower dataset/README.MD:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | Download the 102-Flowers dataset from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/, or from https://drive.google.com/drive/folders/1-9I293J77J40IpUCLtoycjC1T_1QWmVR?usp=sharing.
4 |
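5 | If you download from the VGG page, here is a minimal Python sketch for fetching and unpacking the images (this assumes the archive is named `102flowers.tgz`, as on that page; the extracted `jpg/` folder matches the split-directory layout `dataset.py` expects):
6 |
7 | ```python
8 | import tarfile
9 | import urllib.request
10 |
11 | URL = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz"
12 |
13 | # Download the archive and extract it; the images land in a jpg/ sub-folder.
14 | urllib.request.urlretrieve(URL, "102flowers.tgz")
15 | with tarfile.open("102flowers.tgz") as tar:
16 |     tar.extractall("data/102-flower dataset")
17 | ```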
--------------------------------------------------------------------------------
/data/Bert Embeddings/README.MD:
--------------------------------------------------------------------------------
1 | # BERT Embeddings
2 | Run our bert_embeddings.py file on your captions to get the embeddings, or download our pre-processed BERT embeddings for the 102-Flowers dataset from https://drive.google.com/file/d/1XiNtxey51c3V03Xe4ELrMbqoQ_gMpvGe/view?usp=sharing
3 |
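4 | For reference, this is the per-caption computation bert_embeddings.py performs: one 768-dimensional vector per caption, taken as the mean of the last-layer BERT token embeddings. A minimal sketch (the caption string is just an example):
5 |
6 | ```python
7 | import torch
8 | from transformers import BertModel, BertTokenizer
9 |
10 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11 | model = BertModel.from_pretrained('bert-base-uncased')
12 | model.eval()
13 |
14 | caption = "a flower with long pink petals and a yellow center"  # example caption
15 | with torch.no_grad():
16 |     input_ids = torch.tensor(tokenizer.encode(caption)).unsqueeze(0)
17 |     last_hidden = model(input_ids)[0][0]   # (seq_len, 768)
18 |     embedding = last_hidden.mean(dim=0)    # (768,)
19 | ```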
--------------------------------------------------------------------------------
/data/pre-trained models/README.MD:
--------------------------------------------------------------------------------
1 | # Pre Trained Models
2 |
3 | ## Stage 1 Generator
4 | Download from https://drive.google.com/file/d/1-F0IymmrNWoM33Fb2IbZhf4o41n5FbJN/view?usp=sharing
5 |
6 | ## Stage 1 Discriminator
7 | Download from https://drive.google.com/file/d/1-KfgdzLwfMdVvA1HvEHslroNFFlJrdpo/view?usp=sharing
8 |
9 | ## Stage 2 Generator
10 | Download from https://drive.google.com/file/d/1-YjOU7ALKcg6KZpOxPhPETUzcvWJduIj/view?usp=sharing
11 |
12 | ## Stage 2 Discriminator
13 | Download from https://drive.google.com/file/d/1-ckPuRtTMKsdVMX6KdyLESVDxigXzHl_/view?usp=sharing
14 |
15 | Note: the pre-trained models are trained for 100 epochs each; results can easily be improved by training for more epochs.
16 |
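17 | A minimal loading sketch (the class name comes from model.py; the checkpoint filename below is a placeholder for whichever file you downloaded):
18 |
19 | ```python
20 | import torch
21 | from model import STAGE1_G
22 |
23 | netG = STAGE1_G()
24 | # map_location lets GPU-trained weights load on a CPU-only machine
25 | state_dict = torch.load('netG_epoch_100.pth', map_location='cpu')
26 | netG.load_state_dict(state_dict)
27 | netG.eval()
28 | ```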
--------------------------------------------------------------------------------
/bert_embeddings.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """BERT_Embeddings.ipynb
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/13FXGWgWuqa4l3Dx7WM5wXPmZmShD2ffl
8 | """
9 |
10 | # Requires the transformers package (in Colab: !pip install transformers).
11 |
12 | import torch
13 | import pandas as pd
14 | import numpy as np
15 | from transformers import BertModel, BertTokenizer
16 |
17 | df = pd.read_csv("caption_id.csv")
18 | sentences = df.Caption.values
19 |
20 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
21 | model = BertModel.from_pretrained('bert-base-uncased')
22 | model.eval()
23 |
24 | # One 768-dim embedding per caption: the mean of the last-layer token embeddings.
25 | embeddings = np.empty((sentences.shape[0], 768))
26 | with torch.no_grad():
27 |     for i in range(sentences.shape[0]):
28 |         input_ids = torch.tensor(tokenizer.encode(sentences[i])).unsqueeze(0)
29 |         last_hidden = model(input_ids)[0][0]  # (seq_len, 768)
30 |         embeddings[i] = last_hidden.numpy().mean(axis=0)
31 |
32 | # dataset.py loads this exact (misspelled) filename, so keep it as-is.
33 | pd.DataFrame(embeddings).to_csv('embbedings1.csv')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StackGAN with BERT-Embeddings
2 |
3 | Synthesizing high-quality images from text descriptions is a challenging problem in computer vision with many practical applications. Samples generated by existing text-to-image approaches can roughly reflect the meaning of the given descriptions, but they fail to capture necessary details and vivid object parts. In this project, we improve upon the existing Stacked Generative Adversarial Networks (StackGAN) by introducing BERT embeddings to generate 256×256 photo-realistic images conditioned on captions.
4 | We divide the problem into two stages. The Stage-I GAN sketches the primitive shape and colours of the object from the given text description, yielding low-resolution images. The Stage-II GAN takes the Stage-I results and the text descriptions as inputs and generates high-resolution images with photo-realistic details: it rectifies defects in the Stage-I results and adds compelling details through its refinement process (see the inference sketch at the end of this README).
5 |
6 |
7 |
8 | ### Dependencies
9 | - Python 3.6+ (PyTorch 1.6 requires at least Python 3.6)
10 | - PyTorch 1.6.0
11 | - CUDA 10.1
12 |
13 | In addition, please `pip install` the following packages:
14 | - `numpy`
15 | - `pandas`
16 | - `torchfile`
17 | - `transformers` (used by bert_embeddings.py)
18 |
19 | ## Sample case
20 |
21 | ### Caption
22 |
23 | ### Stage-1 Image
24 |
25 |
26 | ### Stage-2 Image
27 |
28 |
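29 | ## Inference sketch
30 |
31 | A minimal sketch of how the two stages chain together at inference time. Class names, tensor shapes, and the CUDA requirement come from model.py; the random embedding below is only a stand-in for a real 768-dim BERT caption embedding:
32 |
33 | ```python
34 | import torch
35 | from model import STAGE1_G, STAGE2_G
36 |
37 | stage1 = STAGE1_G()
38 | stage2 = STAGE2_G(stage1).cuda()  # Stage-II wraps (and freezes) Stage-I
39 | stage2.eval()
40 |
41 | txt_embedding = torch.randn(1, 768).cuda()  # stand-in for a BERT embedding
42 | noise = torch.randn(1, 100).cuda()
43 | with torch.no_grad():
44 |     # returns the 64x64 Stage-I image, the 256x256 Stage-II image, mu, logvar
45 |     stage1_img, fake_img, mu, logvar = stage2(txt_embedding, noise)
46 | ```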
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | from __future__ import unicode_literals
6 |
7 |
8 | import torch.utils.data as data
9 | from PIL import Image
10 | import PIL
11 | import os
12 | import os.path
13 | import pickle
14 | import random
15 | import numpy as np
16 | import pandas as pd
17 |
18 | #from miscc.config import cfg
19 |
20 |
21 | class TextDataset(data.Dataset):
22 | def __init__(self, data_dir, split='jpg', embedding_type='embeddings1',
23 | imsize=64, transform=None, target_transform=None):
24 |
25 | self.transform = transform
26 | self.target_transform = target_transform
27 | self.imsize = imsize
28 | self.data = []
29 | self.data_dir = data_dir
30 | split_dir = os.path.join(data_dir, split)
31 |
32 | self.filenames = self.load_filenames(split_dir)
33 | self.embeddings = self.load_embedding(split_dir, embedding_type)
34 |
35 | def get_img(self, img_path):
36 | img = Image.open(img_path).convert('RGB')
37 | width, height = img.size
38 | # load_size = int(self.imsize * 76 / 64)
39 | load_size = int(self.imsize)
40 | img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
41 | if self.transform is not None:
42 | img = self.transform(img)
43 | return img
44 |
45 | def load_all_captions(self):
46 | caption_dict = {}
47 | filepath = os.path.join(self.data_dir, 'caption_id.csv')
48 | cap=pd.read_csv(filepath)
49 | for key in self.filenames:
50 | caption_dict[key] = cap['Caption'][cap['image_id']==key]
51 | return caption_dict
52 |
53 |     def load_embedding(self, data_dir, embedding_type):
54 |         # the CSV written by bert_embeddings.py; column 0 is the index
55 |         f = pd.read_csv(os.path.join(data_dir, 'embbedings1.csv'))
56 |         embeddings = np.array(f.iloc[:, 1:])
57 |         return embeddings
58 |
59 | def load_filenames(self, data_dir):
60 | filepath = os.path.join(data_dir, 'filenames.csv')
61 | filenames=np.array(pd.read_csv(filepath)['image_id'])
62 | return filenames
63 |
64 | def __getitem__(self, index):
65 | key = self.filenames[index]
66 | data_dir = '%s/jpg' % self.data_dir
67 | #captions = self.captions[key]
68 | embedding = self.embeddings[index,:]
69 | img_name = '%s/%s.jpg' % (data_dir, key)
70 | img = self.get_img(img_name)
71 | if self.target_transform is not None:
72 | embedding = self.target_transform(embedding)
73 | return img, embedding
74 |
75 | def __len__(self):
76 | return len(self.filenames)
77 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """utils.ipynb
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/1nSgFzAcLjQbj94Ow9wJnp4fZ_Y-vlqM1
8 | """
9 |
10 | import os
11 | import errno
12 | import numpy as np
13 |
14 | from copy import deepcopy
15 |
16 | from torch.nn import init
17 | import torch
18 | import torch.nn as nn
19 | import torchvision.utils as vutils
20 |
21 |
22 | #############################
23 | def KL_loss(mu, logvar):
24 |     # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)
25 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
26 | KLD = torch.mean(KLD_element).mul_(-0.5)
27 | return KLD
28 |
29 |
30 | def compute_discriminator_loss(netD, real_imgs, fake_imgs,
31 | real_labels, fake_labels,
32 | conditions, gpus):
33 | criterion = nn.BCELoss()
34 | batch_size = real_imgs.size(0)
35 | cond = conditions.detach()
36 | fake = fake_imgs.detach()
37 | real_features = nn.parallel.data_parallel(netD, (real_imgs), gpus)
38 | fake_features = nn.parallel.data_parallel(netD, (fake), gpus)
39 | # real pairs
40 | inputs = (real_features, cond)
41 | real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
42 | errD_real = criterion(real_logits, real_labels)
43 |     # wrong pairs: real images paired with mismatched (shifted-by-one) text conditions
44 | inputs = (real_features[:(batch_size-1)], cond[1:])
45 | wrong_logits = \
46 | nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
47 | errD_wrong = criterion(wrong_logits, fake_labels[1:])
48 | # fake pairs
49 | inputs = (fake_features, cond)
50 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
51 | errD_fake = criterion(fake_logits, fake_labels)
52 |
53 | if netD.get_uncond_logits is not None:
54 | real_logits = \
55 | nn.parallel.data_parallel(netD.get_uncond_logits,
56 | (real_features), gpus)
57 | fake_logits = \
58 | nn.parallel.data_parallel(netD.get_uncond_logits,
59 | (fake_features), gpus)
60 | uncond_errD_real = criterion(real_logits, real_labels)
61 | uncond_errD_fake = criterion(fake_logits, fake_labels)
62 | #
63 | errD = ((errD_real + uncond_errD_real) / 2. +
64 | (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
65 | errD_real = (errD_real + uncond_errD_real) / 2.
66 | errD_fake = (errD_fake + uncond_errD_fake) / 2.
67 | else:
68 | errD = errD_real + (errD_fake + errD_wrong) * 0.5
69 |
70 | return errD, errD_real, errD_wrong, errD_fake
71 | # return errD, errD_real.data[0], errD_wrong.data[0], errD_fake.data[0]
72 |
73 |
74 |
75 |
76 | def compute_generator_loss(netD, fake_imgs, real_labels, conditions, gpus):
77 | criterion = nn.BCELoss()
78 | cond = conditions.detach()
79 | fake_features = nn.parallel.data_parallel(netD, (fake_imgs), gpus)
80 | # fake pairs
81 | inputs = (fake_features, cond)
82 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
83 | errD_fake = criterion(fake_logits, real_labels)
84 | if netD.get_uncond_logits is not None:
85 | fake_logits = \
86 | nn.parallel.data_parallel(netD.get_uncond_logits,
87 | (fake_features), gpus)
88 | uncond_errD_fake = criterion(fake_logits, real_labels)
89 | errD_fake += uncond_errD_fake
90 | return errD_fake
91 |
92 |
93 | #############################
94 | def weights_init(m):
95 | classname = m.__class__.__name__
96 | if classname.find('Conv') != -1:
97 | m.weight.data.normal_(0.0, 0.02)
98 | elif classname.find('BatchNorm') != -1:
99 | m.weight.data.normal_(1.0, 0.02)
100 | m.bias.data.fill_(0)
101 | elif classname.find('Linear') != -1:
102 | m.weight.data.normal_(0.0, 0.02)
103 | if m.bias is not None:
104 | m.bias.data.fill_(0.0)
105 |
106 |
107 | VIS_COUNT = 64
108 | #############################
109 | def save_img_results(data_img, fake, epoch, image_dir):
110 | num = VIS_COUNT
111 |
112 | fake = fake[0:num]
113 | # data_img is changed to [0,1]
114 | if data_img is not None:
115 | data_img = data_img[0:num]
116 | vutils.save_image(
117 | data_img, '%s/real_samples.png' % image_dir,
118 | normalize=True)
119 | # fake.data is still [-1, 1]
120 | vutils.save_image(
121 | fake.data, '%s/fake_samples_epoch_%03d.png' %
122 | (image_dir, epoch), normalize=True)
123 | else:
124 | vutils.save_image(
125 | fake.data, '%s/lr_fake_samples_epoch_%03d.png' %
126 | (image_dir, epoch), normalize=True)
127 |
128 |
129 | def save_model(netG, netD, epoch, model_dir):
130 | torch.save(
131 | netG.state_dict(),
132 | '%s/netG_epoch_%d.pth' % (model_dir, epoch))
133 | torch.save(
134 | netD.state_dict(),
135 | '%s/netD_epoch_last.pth' % (model_dir))
136 | print('Save G/D models')
137 |
138 |
139 | def mkdir_p(path):
140 | try:
141 | os.makedirs(path)
142 | except OSError as exc: # Python >2.5
143 | if exc.errno == errno.EEXIST and os.path.isdir(path):
144 | pass
145 | else:
146 | raise
147 |
148 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """trainer.ipynb
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/1vRvijAMJSVqUZyjifZGktpAlcUBfr0XP
8 | """
9 |
10 | NET_G=''
11 | TRAIN_FLAG = True
12 | TRAIN_MAX_EPOCH=500
13 | TRAIN_SNAPSHOT_INTERVAL=50
14 | TRAIN_BATCH_SIZE=32
15 | GPU_ID='0'
16 | NET_D=''
17 | CUDA=True
18 | gen = '/content/drive/My Drive/Model/netG_epoch_150.pth'
19 | Z_DIM=100
20 |
21 | TRAIN_PRETRAINED_MODEL = ''
22 | TRAIN_PRETRAINED_EPOCH = 600
23 | TRAIN_LR_DECAY_EPOCH = 600
24 | TRAIN_DISCRIMINATOR_LR = 2e-4
25 | TRAIN_GENERATOR_LR = 2e-4
26 | TRAIN_COEFF_KL=2.0
27 |
28 | # Commented out IPython magic to ensure Python compatibility.
29 | # from __future__ import print_function
30 | from six.moves import range
31 | from PIL import Image
32 |
33 | import torch.backends.cudnn as cudnn
34 | import torch
35 | import torch.nn as nn
36 | from torch.autograd import Variable
37 | import torch.optim as optim
38 | import os
39 | import time
40 |
41 | import numpy as np
42 | import torchfile
43 |
44 | from utils import mkdir_p
45 | from utils import weights_init
46 | from utils import save_img_results, save_model
47 | from utils import KL_loss
48 | from utils import compute_discriminator_loss, compute_generator_loss
49 |
50 | # from torch.utils.tensorboard import summary
51 | # from torch.utils.tensorboard import FileWriter
52 |
53 |
54 | class GANTrainer(object):
55 | def __init__(self, output_dir):
56 | if TRAIN_FLAG:
57 | self.model_dir = os.path.join(output_dir, 'Model')
58 | self.image_dir = os.path.join(output_dir, 'Image')
59 | self.log_dir = os.path.join(output_dir, 'Log')
60 | mkdir_p(self.model_dir)
61 | mkdir_p(self.image_dir)
62 | mkdir_p(self.log_dir)
63 | # self.summary_writer = FileWriter(self.log_dir)
64 |
65 | self.max_epoch = TRAIN_MAX_EPOCH
66 | self.snapshot_interval = TRAIN_SNAPSHOT_INTERVAL
67 |
68 | s_gpus = GPU_ID.split(',')
69 | self.gpus = [int(ix) for ix in s_gpus]
70 | self.num_gpus = len(self.gpus)
71 | self.batch_size = TRAIN_BATCH_SIZE * self.num_gpus
72 | torch.cuda.set_device(self.gpus[0])
73 | cudnn.benchmark = True
74 |
75 | # ############# For training stageI GAN #############
76 | def load_network_stageI(self):
77 | from model import STAGE1_G, STAGE1_D
78 | netG = STAGE1_G()
79 | netG.apply(weights_init)
80 | print(netG)
81 | netD = STAGE1_D()
82 | netD.apply(weights_init)
83 | print(netD)
84 |
85 | if NET_G != '':
86 | state_dict = \
87 | torch.load(NET_G,
88 | map_location=lambda storage, loc: storage)
89 | netG.load_state_dict(state_dict)
90 | print('Load from: ', NET_G)
91 | if NET_D != '':
92 | state_dict = \
93 | torch.load(NET_D,
94 | map_location=lambda storage, loc: storage)
95 | netD.load_state_dict(state_dict)
96 | print('Load from: ', NET_D)
97 | if CUDA:
98 | netG.cuda()
99 | netD.cuda()
100 | return netG, netD
101 |
102 | # ############# For training stageII GAN #############
103 | def load_network_stageII(self):
104 | from model import STAGE1_G, STAGE2_G, STAGE2_D
105 |
106 | Stage1_G = STAGE1_G()
107 | netG = STAGE2_G(Stage1_G)
108 | netG.apply(weights_init)
109 | print(netG)
110 | if NET_G != '':
111 | state_dict = \
112 | torch.load(NET_G,
113 | map_location=lambda storage, loc: storage)
114 | netG.load_state_dict(state_dict)
115 | print('Load from: ', NET_G)
116 |         elif gen != '':
117 |             # load the pre-trained Stage-I generator weights from the path in `gen`
118 |             state_dict = torch.load(gen, map_location=lambda storage, loc: storage)
119 |             netG.STAGE1_G.load_state_dict(state_dict)
120 |             print('Load from: ', gen)
121 | else:
122 | print("Please give the Stage1_G path")
123 | return
124 |
125 | netD = STAGE2_D()
126 | netD.apply(weights_init)
127 | if NET_D != '':
128 | state_dict = \
129 | torch.load(NET_D,
130 | map_location=lambda storage, loc: storage)
131 | netD.load_state_dict(state_dict)
132 | print('Load from: ', NET_D)
133 | print(netD)
134 |
135 | if CUDA:
136 | netG.cuda()
137 | netD.cuda()
138 | return netG, netD
139 |
140 | def train(self, data_loader, stage=1):
141 | if stage == 1:
142 | netG, netD = self.load_network_stageI()
143 | else:
144 | netG, netD = self.load_network_stageII()
145 |
146 | nz = Z_DIM
147 | batch_size = self.batch_size
148 | noise = Variable(torch.FloatTensor(batch_size, nz))
149 |         fixed_noise = \
150 |             Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
151 |                      requires_grad=False)
152 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
153 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
154 | if CUDA:
155 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
156 | real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
157 |
158 | generator_lr = TRAIN_GENERATOR_LR
159 | discriminator_lr = TRAIN_DISCRIMINATOR_LR
160 | lr_decay_step = TRAIN_LR_DECAY_EPOCH
161 | optimizerD = \
162 | optim.Adam(netD.parameters(),
163 | lr=TRAIN_DISCRIMINATOR_LR, betas=(0.5, 0.999))
164 | netG_para = []
165 | for p in netG.parameters():
166 | if p.requires_grad:
167 | netG_para.append(p)
168 | optimizerG = optim.Adam(netG_para,
169 | lr=TRAIN_GENERATOR_LR,
170 | betas=(0.5, 0.999))
171 | count = 0
172 |
173 | for epoch in range(self.max_epoch):
174 | start_t = time.time()
175 | if epoch % lr_decay_step == 0 and epoch > 0:
176 | generator_lr *= 0.5
177 | for param_group in optimizerG.param_groups:
178 | param_group['lr'] = generator_lr
179 | discriminator_lr *= 0.5
180 | for param_group in optimizerD.param_groups:
181 | param_group['lr'] = discriminator_lr
182 |
183 | for i, data in enumerate(data_loader, 0):
184 | ######################################################
185 | # (1) Prepare training data
186 | ######################################################
187 | real_img_cpu, txt_embedding = data
188 | real_imgs = Variable(real_img_cpu)
189 | txt_embedding = Variable(txt_embedding)
190 | txt_embedding=txt_embedding.type(torch.FloatTensor)
191 | real_imgs=real_imgs.type(torch.FloatTensor)
192 | if CUDA:
193 | real_imgs = real_imgs.cuda()
194 | txt_embedding = txt_embedding.cuda()
195 |
196 | #######################################################
197 | # (2) Generate fake images
198 | ######################################################
199 | noise.data.normal_(0, 1)
200 | inputs = (txt_embedding, noise)
201 | _, fake_imgs, mu, logvar = \
202 | nn.parallel.data_parallel(netG, inputs, self.gpus)
203 |
204 | ############################
205 | # (3) Update D network
206 | ###########################
207 | netD.zero_grad()
208 | errD, errD_real, errD_wrong, errD_fake = \
209 | compute_discriminator_loss(netD, real_imgs, fake_imgs,
210 | real_labels, fake_labels,
211 | mu, self.gpus)
212 | errD.backward()
213 | optimizerD.step()
214 | ############################
215 | # (2) Update G network
216 | ###########################
217 | netG.zero_grad()
218 | errG = compute_generator_loss(netD, fake_imgs,
219 | real_labels, mu, self.gpus)
220 | kl_loss = KL_loss(mu, logvar)
221 | errG_total = errG + kl_loss * TRAIN_COEFF_KL
222 | errG_total.backward()
223 | optimizerG.step()
224 |
225 | count = count + 1
226 | if i % 100 == 0:
227 |                     print('D_loss', errD.item())
228 |                     print('G_loss', errG.item())
229 |                     print('KL_loss', kl_loss.item())
230 | # summary_D = summary.scalar('D_loss', errD.data[0])
231 | # summary_D_r = summary.scalar('D_loss_real', errD_real)
232 | # summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
233 | # summary_D_f = summary.scalar('D_loss_fake', errD_fake)
234 | # summary_G = summary.scalar('G_loss', errG.data[0])
235 | # summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
236 |
237 | # self.summary_writer.add_summary(summary_D, count)
238 | # self.summary_writer.add_summary(summary_D_r, count)
239 | # self.summary_writer.add_summary(summary_D_w, count)
240 | # self.summary_writer.add_summary(summary_D_f, count)
241 | # self.summary_writer.add_summary(summary_G, count)
242 | # self.summary_writer.add_summary(summary_KL, count)
243 |
244 | # save the image result for each epoch
245 | inputs = (txt_embedding, fixed_noise)
246 | lr_fake, fake, _, _ = \
247 | nn.parallel.data_parallel(netG, inputs, self.gpus)
248 | save_img_results(real_img_cpu, fake, epoch, self.image_dir)
249 | if lr_fake is not None:
250 | save_img_results(None, lr_fake, epoch, self.image_dir)
251 | end_t = time.time()
252 | print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
253 | Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
254 | Total Time: %.2fsec
255 | ''' % (epoch, self.max_epoch, i, len(data_loader),
256 | errD, errG, kl_loss,
257 | errD_real, errD_wrong, errD_fake, (end_t - start_t)))
258 | if epoch % self.snapshot_interval == 0:
259 | save_model(netG, netD, epoch, self.model_dir)
260 |
261 | save_model(netG, netD, self.max_epoch, self.model_dir)
262 |
263 |         # self.summary_writer.close()  (re-enable together with the summary writer above)
264 |
265 | def sample(self, datapath, stage=1):
266 | if stage == 1:
267 | netG, _ = self.load_network_stageI()
268 | else:
269 | netG, _ = self.load_network_stageII()
270 | netG.eval()
271 |
272 | # Load text embeddings generated from the encoder
273 | t_file = torchfile.load(datapath)
274 | captions_list = t_file.raw_txt
275 | embeddings = np.concatenate(t_file.fea_txt, axis=0)
276 | num_embeddings = len(captions_list)
277 | print('Successfully load sentences from: ', datapath)
278 | print('Total number of sentences:', num_embeddings)
279 | print('num_embeddings:', num_embeddings, embeddings.shape)
280 | # path to save generated samples
281 | save_dir = NET_G[:NET_G.find('.pth')]
282 | mkdir_p(save_dir)
283 |
284 | batch_size = np.minimum(num_embeddings, self.batch_size)
285 | nz = Z_DIM
286 | noise = Variable(torch.FloatTensor(batch_size, nz))
287 | if CUDA:
288 | noise = noise.cuda()
289 | count = 0
290 | while count < num_embeddings:
291 | if count > 3000:
292 | break
293 | iend = count + batch_size
294 | if iend > num_embeddings:
295 | iend = num_embeddings
296 | count = num_embeddings - batch_size
297 | embeddings_batch = embeddings[count:iend]
298 | # captions_batch = captions_list[count:iend]
299 | txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
300 | if CUDA:
301 | txt_embedding = txt_embedding.cuda()
302 |
303 | #######################################################
304 | # (2) Generate fake images
305 | ######################################################
306 | noise.data.normal_(0, 1)
307 | inputs = (txt_embedding, noise)
308 | _, fake_imgs, mu, logvar = \
309 | nn.parallel.data_parallel(netG, inputs, self.gpus)
310 | for i in range(batch_size):
311 | save_name = '%s/%d.png' % (save_dir, count + i)
312 | im = fake_imgs[i].data.cpu().numpy()
313 | im = (im + 1.0) * 127.5
314 | im = im.astype(np.uint8)
315 | # print('im', im.shape)
316 | im = np.transpose(im, (1, 2, 0))
317 | # print('im', im.shape)
318 | im = Image.fromarray(im)
319 | im.save(save_name)
320 | count += batch_size
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | """model.ipynb
4 |
5 | Automatically generated by Colaboratory.
6 |
7 | Original file is located at
8 | https://colab.research.google.com/drive/1vqN158R5XSGjVnSSwN51okw90X9pk0VJ
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | from torch.autograd import Variable
15 |
16 | TEXT_DIMENSION = 768
17 | GAN_CONDITION_DIM = 128
18 | CUDA=True
19 | GAN_GF_DIM = 128
20 | GAN_DF_DIM = 64
21 | Z_DIM=100
22 | GAN_R_NUM = 4
23 |
24 | def conv3x3(in_planes, out_planes, stride=1):
25 | "3x3 convolution with padding"
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=1, bias=False)
28 |
29 |
30 | # Upscale the spatial size by a factor of 2
31 | def upBlock(in_planes, out_planes):
32 | block = nn.Sequential(
33 | nn.Upsample(scale_factor=2, mode='nearest'),
34 | conv3x3(in_planes, out_planes),
35 | nn.BatchNorm2d(out_planes),
36 | nn.ReLU(True))
37 | return block
38 |
39 |
40 | class ResBlock(nn.Module):
41 | def __init__(self, channel_num):
42 | super(ResBlock, self).__init__()
43 | self.block = nn.Sequential(
44 | conv3x3(channel_num, channel_num),
45 | nn.BatchNorm2d(channel_num),
46 | nn.ReLU(True),
47 | conv3x3(channel_num, channel_num),
48 | nn.BatchNorm2d(channel_num))
49 | self.relu = nn.ReLU(inplace=True)
50 |
51 | def forward(self, x):
52 | residual = x
53 | out = self.block(x)
54 | out += residual
55 | out = self.relu(out)
56 | return out
57 |
58 |
59 | class CA_NET(nn.Module):
60 | # some code is modified from vae examples
61 | # (https://github.com/pytorch/examples/blob/master/vae/main.py)
62 | def __init__(self):
63 | super(CA_NET, self).__init__()
64 | self.t_dim = TEXT_DIMENSION
65 | self.c_dim = GAN_CONDITION_DIM
66 | self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
67 | self.relu = nn.ReLU()
68 |
69 | def encode(self, text_embedding):
70 | x = self.relu(self.fc(text_embedding))
71 | mu = x[:, :self.c_dim]
72 | logvar = x[:, self.c_dim:]
73 | return mu, logvar
74 |
75 | def reparametrize(self, mu, logvar):
76 | std = logvar.mul(0.5).exp_()
77 | if CUDA:
78 | eps = torch.cuda.FloatTensor(std.size()).normal_()
79 | else:
80 | eps = torch.FloatTensor(std.size()).normal_()
81 | eps = Variable(eps)
82 | return eps.mul(std).add_(mu)
83 |
84 | def forward(self, text_embedding):
85 | mu, logvar = self.encode(text_embedding)
86 | c_code = self.reparametrize(mu, logvar)
87 | return c_code, mu, logvar
88 |
89 |
90 | class D_GET_LOGITS(nn.Module):
91 | def __init__(self, ndf, nef, bcondition=True):
92 | super(D_GET_LOGITS, self).__init__()
93 | self.df_dim = ndf
94 | self.ef_dim = nef
95 | self.bcondition = bcondition
96 | if bcondition:
97 | self.outlogits = nn.Sequential(
98 | conv3x3(ndf * 8 + nef, ndf * 8),
99 | nn.BatchNorm2d(ndf * 8),
100 | nn.LeakyReLU(0.2, inplace=True),
101 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
102 | nn.Sigmoid())
103 | else:
104 | self.outlogits = nn.Sequential(
105 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
106 | nn.Sigmoid())
107 |
108 | def forward(self, h_code, c_code=None):
109 | # conditioning output
110 | if self.bcondition and c_code is not None:
111 | c_code = c_code.view(-1, self.ef_dim, 1, 1)
112 | c_code = c_code.repeat(1, 1, 4, 4)
113 | # state size (ngf+egf) x 4 x 4
114 | h_c_code = torch.cat((h_code, c_code), 1)
115 | else:
116 | h_c_code = h_code
117 |
118 | output = self.outlogits(h_c_code)
119 | return output.view(-1)
120 |
121 |
122 | # ############# Networks for stageI GAN #############
123 | class STAGE1_G(nn.Module):
124 | def __init__(self):
125 | super(STAGE1_G, self).__init__()
126 | self.gf_dim = GAN_GF_DIM * 8
127 | self.ef_dim = GAN_CONDITION_DIM
128 | self.z_dim = Z_DIM
129 | self.define_module()
130 |
131 | def define_module(self):
132 | ninput = self.z_dim + self.ef_dim
133 | ngf = self.gf_dim
134 | # TEXT.DIMENSION -> GAN.CONDITION_DIM
135 | self.ca_net = CA_NET()
136 |
137 | # -> ngf x 4 x 4
138 | self.fc = nn.Sequential(
139 | nn.Linear(ninput, ngf * 4 * 4, bias=False),
140 | nn.BatchNorm1d(ngf * 4 * 4),
141 | nn.ReLU(True))
142 |
143 | # ngf x 4 x 4 -> ngf/2 x 8 x 8
144 | self.upsample1 = upBlock(ngf, ngf // 2)
145 | # -> ngf/4 x 16 x 16
146 | self.upsample2 = upBlock(ngf // 2, ngf // 4)
147 | # -> ngf/8 x 32 x 32
148 | self.upsample3 = upBlock(ngf // 4, ngf // 8)
149 | # -> ngf/16 x 64 x 64
150 | self.upsample4 = upBlock(ngf // 8, ngf // 16)
151 | # -> 3 x 64 x 64
152 | self.img = nn.Sequential(
153 | conv3x3(ngf // 16, 3),
154 | nn.Tanh())
155 |
156 | def forward(self, text_embedding, noise):
157 | c_code, mu, logvar = self.ca_net(text_embedding)
158 | z_c_code = torch.cat((noise, c_code), 1)
159 | h_code = self.fc(z_c_code)
160 |
161 | h_code = h_code.view(-1, self.gf_dim, 4, 4)
162 | h_code = self.upsample1(h_code)
163 | h_code = self.upsample2(h_code)
164 | h_code = self.upsample3(h_code)
165 | h_code = self.upsample4(h_code)
166 | # state size 3 x 64 x 64
167 | fake_img = self.img(h_code)
168 | return None, fake_img, mu, logvar
169 |
170 |
171 | class STAGE1_D(nn.Module):
172 | def __init__(self):
173 | super(STAGE1_D, self).__init__()
174 | self.df_dim = GAN_DF_DIM
175 | self.ef_dim = GAN_CONDITION_DIM
176 | self.define_module()
177 |
178 | def define_module(self):
179 | ndf, nef = self.df_dim, self.ef_dim
180 | self.encode_img = nn.Sequential(
181 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
182 | nn.LeakyReLU(0.2, inplace=True),
183 | # state size. (ndf) x 32 x 32
184 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
185 | nn.BatchNorm2d(ndf * 2),
186 | nn.LeakyReLU(0.2, inplace=True),
187 | # state size (ndf*2) x 16 x 16
188 | nn.Conv2d(ndf*2, ndf * 4, 4, 2, 1, bias=False),
189 | nn.BatchNorm2d(ndf * 4),
190 | nn.LeakyReLU(0.2, inplace=True),
191 | # state size (ndf*4) x 8 x 8
192 | nn.Conv2d(ndf*4, ndf * 8, 4, 2, 1, bias=False),
193 | nn.BatchNorm2d(ndf * 8),
194 |             # state size (ndf*8) x 4 x 4
195 | nn.LeakyReLU(0.2, inplace=True)
196 | )
197 |
198 | self.get_cond_logits = D_GET_LOGITS(ndf, nef)
199 | self.get_uncond_logits = None
200 |
201 | def forward(self, image):
202 | img_embedding = self.encode_img(image)
203 |
204 | return img_embedding
205 |
206 |
207 | # ############# Networks for stageII GAN #############
208 | class STAGE2_G(nn.Module):
209 | def __init__(self, STAGE1_G):
210 | super(STAGE2_G, self).__init__()
211 | self.gf_dim = GAN_GF_DIM
212 | self.ef_dim = GAN_CONDITION_DIM
213 | self.z_dim = Z_DIM
214 | self.STAGE1_G = STAGE1_G
215 | # fix parameters of stageI GAN
216 | for param in self.STAGE1_G.parameters():
217 | param.requires_grad = False
218 | self.define_module()
219 |
220 | def _make_layer(self, block, channel_num):
221 | layers = []
222 | for i in range(GAN_R_NUM):
223 | layers.append(block(channel_num))
224 | return nn.Sequential(*layers)
225 |
226 | def define_module(self):
227 | ngf = self.gf_dim
228 | # TEXT.DIMENSION -> GAN.CONDITION_DIM
229 | self.ca_net = CA_NET()
230 | # --> 4ngf x 16 x 16
231 | self.encoder = nn.Sequential(
232 | conv3x3(3, ngf),
233 | nn.ReLU(True),
234 | nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False),
235 | nn.BatchNorm2d(ngf * 2),
236 | nn.ReLU(True),
237 | nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False),
238 | nn.BatchNorm2d(ngf * 4),
239 | nn.ReLU(True))
240 | self.hr_joint = nn.Sequential(
241 | conv3x3(self.ef_dim + ngf * 4, ngf * 4),
242 | nn.BatchNorm2d(ngf * 4),
243 | nn.ReLU(True))
244 | self.residual = self._make_layer(ResBlock, ngf * 4)
245 | # --> 2ngf x 32 x 32
246 | self.upsample1 = upBlock(ngf * 4, ngf * 2)
247 | # --> ngf x 64 x 64
248 | self.upsample2 = upBlock(ngf * 2, ngf)
249 | # --> ngf // 2 x 128 x 128
250 | self.upsample3 = upBlock(ngf, ngf // 2)
251 | # --> ngf // 4 x 256 x 256
252 | self.upsample4 = upBlock(ngf // 2, ngf // 4)
253 | # --> 3 x 256 x 256
254 | self.img = nn.Sequential(
255 | conv3x3(ngf // 4, 3),
256 | nn.Tanh())
257 |
258 | def forward(self, text_embedding, noise):
259 | _, stage1_img, _, _ = self.STAGE1_G(text_embedding, noise)
260 | stage1_img = stage1_img.detach()
261 | encoded_img = self.encoder(stage1_img)
262 |
263 | c_code, mu, logvar = self.ca_net(text_embedding)
264 | c_code = c_code.view(-1, self.ef_dim, 1, 1)
265 | c_code = c_code.repeat(1, 1, 16, 16)
266 | i_c_code = torch.cat([encoded_img, c_code], 1)
267 | h_code = self.hr_joint(i_c_code)
268 | h_code = self.residual(h_code)
269 |
270 | h_code = self.upsample1(h_code)
271 | h_code = self.upsample2(h_code)
272 | h_code = self.upsample3(h_code)
273 | h_code = self.upsample4(h_code)
274 |
275 | fake_img = self.img(h_code)
276 | return stage1_img, fake_img, mu, logvar
277 |
278 |
279 | class STAGE2_D(nn.Module):
280 | def __init__(self):
281 | super(STAGE2_D, self).__init__()
282 | self.df_dim = GAN_DF_DIM
283 | self.ef_dim = GAN_CONDITION_DIM
284 | self.define_module()
285 |
286 | def define_module(self):
287 | ndf, nef = self.df_dim, self.ef_dim
288 | self.encode_img = nn.Sequential(
289 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), # 128 * 128 * ndf
290 | nn.LeakyReLU(0.2, inplace=True),
291 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
292 | nn.BatchNorm2d(ndf * 2),
293 | nn.LeakyReLU(0.2, inplace=True), # 64 * 64 * ndf * 2
294 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
295 | nn.BatchNorm2d(ndf * 4),
296 | nn.LeakyReLU(0.2, inplace=True), # 32 * 32 * ndf * 4
297 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
298 | nn.BatchNorm2d(ndf * 8),
299 | nn.LeakyReLU(0.2, inplace=True), # 16 * 16 * ndf * 8
300 | nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
301 | nn.BatchNorm2d(ndf * 16),
302 | nn.LeakyReLU(0.2, inplace=True), # 8 * 8 * ndf * 16
303 | nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
304 | nn.BatchNorm2d(ndf * 32),
305 | nn.LeakyReLU(0.2, inplace=True), # 4 * 4 * ndf * 32
306 | conv3x3(ndf * 32, ndf * 16),
307 | nn.BatchNorm2d(ndf * 16),
308 | nn.LeakyReLU(0.2, inplace=True), # 4 * 4 * ndf * 16
309 | conv3x3(ndf * 16, ndf * 8),
310 | nn.BatchNorm2d(ndf * 8),
311 | nn.LeakyReLU(0.2, inplace=True) # 4 * 4 * ndf * 8
312 | )
313 |
314 | self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True)
315 | self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False)
316 |
317 | def forward(self, image):
318 | img_embedding = self.encode_img(image)
319 |
320 | return img_embedding
321 | # ===== Alternative implementation of the same networks (not imported by trainer.py) =====
322 | #!/usr/bin/env python
323 | # coding: utf-8
324 |
325 |
326 | import torch
327 | import torch.nn as nn
328 |
329 |
330 | ############## Configurations
331 |
332 |
333 | dim_text_embedding = 1000
334 | dim_conditioning_var = 128
335 | dim_noise = 100
336 | channels_gen = 128
337 | channels_discr = 64
338 | upscale_factor = 2
339 |
340 |
341 | # upscales the image by a factor of 2 and changes the number of channels
342 |
343 | def upscale(in_channels,out_channels):
344 | return nn.Sequential(
345 | nn.Upsample(scale_factor=upscale_factor, mode='nearest'),
346 | nn.Conv2d(in_channels,out_channels,3,1,1,bias = False),
347 | nn.BatchNorm2d(out_channels),
348 | nn.ReLU(True))
349 |
350 |
351 |
352 | # convolutional residual block, keeps number of channels constant
353 |
354 | class ResBlock(nn.Module):
355 | def __init__(self,channels):
356 | super().__init__()
357 | self.channels = channels
358 | self.block = nn.Sequential(
359 | nn.Conv2d(channels,channels,3,1,1,bias = False),
360 | nn.BatchNorm2d(channels),
361 | nn.ReLU(True),
362 | nn.Conv2d(channels,channels,3,1,1,bias = False),
363 | nn.BatchNorm2d(channels)
364 | )
365 | self.ReLU = nn.ReLU(True)
366 |
367 | def forward(self,x):
368 | residue = x
369 | x = self.block(x)
370 | x = x + residue
371 | x = self.ReLU(x)
372 | return x
373 |
374 |
375 |
376 | class Conditional_augmentation(nn.Module):
377 | def __init__(self):
378 | super().__init__()
379 | self.dim_fc_inp = dim_text_embedding
380 | self.dim_fc_out = dim_conditioning_var
381 | self.fc = nn.Linear(self.dim_fc_inp, self.dim_fc_out*2, bias= True)
382 | self.relu = nn.ReLU()
383 |
384 | def get_mu_logvar(self,textEmbedding):
385 | x = self.relu(self.fc(textEmbedding))
386 |
387 | mu = x[:,:dim_conditioning_var]
388 | logvar = x[:,dim_conditioning_var:]
389 | return mu,logvar
390 |
391 |
392 |     def get_conditioning_variable(self, mu, logvar):
393 |         # sample eps on the same device/dtype as mu to avoid a CPU/GPU mismatch
394 |         epsilon = torch.randn_like(mu)
395 |         std = torch.exp(0.5 * logvar)
396 |         return mu + epsilon * std
397 |
398 | def forward(self,textEmbedding):
399 | mu, logvar = self.get_mu_logvar(textEmbedding)
400 | return self.get_conditioning_variable(mu, logvar)
401 |
402 |
403 | class Discriminator_logit(nn.Module):
404 | def __init__(self,dim_discr,dim_condVar,concat=False):
405 | super().__init__()
406 | self.dim_discr = dim_discr
407 | self.dim_condVar = dim_condVar
408 | self.concat = concat
409 |         if concat:
410 | self.logits = nn.Sequential(
411 | nn.Conv2d(dim_discr*8 + dim_condVar,dim_discr*8,3,1,1, bias = False),
412 | nn.BatchNorm2d(dim_discr*8),
413 | nn.LeakyReLU(.2, True),
414 | nn.Conv2d(dim_discr*8, 1, kernel_size=4, stride=4),
415 | nn.Sigmoid()
416 | )
417 |
418 | else :
419 | self.logits = nn.Sequential(
420 | nn.Conv2d(dim_discr*8, 1, kernel_size=4, stride=4),
421 | nn.Sigmoid()
422 | )
423 |
424 | def forward(self, hidden_vec, cond_aug=None):
425 | if self.concat is True and cond_aug is not None:
426 | cond_aug = cond_aug.view(-1, self.dim_condVar, 1, 1)
427 | cond_aug = cond_aug.repeat(1, 1, 4, 4)
428 | hidden_vec = torch.cat((hidden_vec,cond_aug),1)
429 |
430 | return self.logits(hidden_vec).view(-1)
431 |
432 |
433 | class Stage1_Generator(nn.Module):
434 | def __init__(self):
435 | super().__init__()
436 | self.dim_noise = dim_noise
437 | self.dim_cond_aug = dim_conditioning_var
438 | self.channels_fc = channels_gen * 8
439 | self.cond_aug_net = Conditional_augmentation()
440 |
441 | self.fc = nn.Sequential(
442 | nn.Linear(self.dim_noise + self.dim_cond_aug, self.channels_fc * 4 * 4, bias = False),
443 | nn.BatchNorm1d(self.channels_fc * 4 * 4),
444 | nn.ReLU(True)
445 | )
446 |
447 | self.upsample = nn.Sequential(
448 | upscale(self.channels_fc,self.channels_fc//2),
449 | upscale(self.channels_fc//2,self.channels_fc//4),
450 | upscale(self.channels_fc//4,self.channels_fc//8),
451 | upscale(self.channels_fc//8,self.channels_fc//16)
452 | )
453 |
454 | self.generated_image = nn.Sequential(
455 | nn.Conv2d(self.channels_fc//16,3,3,1,1,bias = False),
456 | nn.Tanh())
457 |
458 |
459 | def forward(self,noise,text_embedding):
460 | cond_aug = self.cond_aug_net(text_embedding)
461 | x = torch.cat((noise,cond_aug),1)
462 |
463 | x = self.fc(x)
464 | x = x.view(-1,self.channels_fc, 4, 4)
465 | x = self.upsample(x)
466 |
467 | image = self.generated_image(x)
468 |
469 | return image
470 |
471 |
472 |
473 | class Stage1_Discriminator(nn.Module):
474 | def __init__(self):
475 | super().__init__()
476 | self.channels_initial = channels_discr
477 |
478 | self.downsample = nn.Sequential(
479 | nn.Conv2d(3, self.channels_initial, kernel_size=4, stride=2, padding=1),
480 | nn.LeakyReLU(0.2,inplace=True),
481 |
482 | nn.Conv2d(self.channels_initial , self.channels_initial*2, kernel_size=4, stride=2, padding=1),
483 | nn.BatchNorm2d(self.channels_initial*2),
484 | nn.LeakyReLU(0.2,inplace=True),
485 |
486 | nn.Conv2d(self.channels_initial*2, self.channels_initial*4, kernel_size=4, stride=2, padding=1),
487 | nn.BatchNorm2d(self.channels_initial*4),
488 | nn.LeakyReLU(0.2,inplace=True),
489 |
490 | nn.Conv2d(self.channels_initial*4, self.channels_initial*8, kernel_size=4, stride=2, padding=1),
491 | nn.BatchNorm2d(self.channels_initial*8),
492 | nn.LeakyReLU(0.2,inplace=True),
493 | )
494 |
495 | self.cond_logit = Discriminator_logit(self.channels_initial,dim_conditioning_var,True)
496 | self.uncond_logit = Discriminator_logit(self.channels_initial,dim_conditioning_var,False)
497 |
498 | def forward(self,img):
499 | return self.downsample(img)
500 |
501 |
502 | class Stage2_Generator(nn.Module):
503 | def __init__(self):
504 | super().__init__()
505 | self.downsample_channels = channels_gen
506 | self.dim_embedding = dim_conditioning_var
507 | self.cond_aug_net = Conditional_augmentation()
508 | self.Stage1_G = Stage1_Generator()
509 | self.downsample = nn.Sequential(
510 | nn.Conv2d(3, self.downsample_channels, kernel_size=3, stride=1, padding=1),
511 | nn.ReLU(inplace=True),
512 |
513 | nn.Conv2d(self.downsample_channels, self.downsample_channels*2, kernel_size=4, stride=2, padding=1),
514 | nn.BatchNorm2d(self.downsample_channels*2),
515 | nn.ReLU(inplace=True),
516 |
517 | nn.Conv2d(self.downsample_channels*2, self.downsample_channels*4, kernel_size=4, stride=2, padding=1),
518 | nn.BatchNorm2d(self.downsample_channels*4),
519 | nn.ReLU(inplace=True),
520 | )
521 | self.hidden = nn.Sequential(
522 | nn.Conv2d(self.downsample_channels*4 + self.dim_embedding, self.downsample_channels*4, 3, 1, 1, bias=False),
523 | nn.BatchNorm2d(self.downsample_channels*4),
524 | nn.ReLU(True)
525 | )
526 | self.residual = nn.Sequential(
527 | ResBlock(self.downsample_channels*4),
528 | ResBlock(self.downsample_channels*4),
529 | ResBlock(self.downsample_channels*4),
530 | ResBlock(self.downsample_channels*4)
531 | )
532 | self.upsample = nn.Sequential(
533 | upscale(self.downsample_channels*4,self.downsample_channels*2),
534 | upscale(self.downsample_channels*2,self.downsample_channels),
535 | upscale(self.downsample_channels,self.downsample_channels//2),
536 | upscale(self.downsample_channels//2,self.downsample_channels//4)
537 | )
538 | self.image = nn.Sequential(
539 | nn.Conv2d(self.downsample_channels//4, 3, 3, 1, 1, bias = False),
540 | nn.Tanh()
541 | )
542 |
543 | def forward(self,noise, text_embedding):
544 | image = self.Stage1_G(noise, text_embedding)
545 | image = image.detach()
546 | enc_img = self.downsample(image)
547 |
548 | cond_aug = self.cond_aug_net(text_embedding)
549 | cond_aug = cond_aug.view(-1, self.dim_embedding, 1, 1)
550 | cond_aug = cond_aug.repeat(1, 1, 16, 16)
551 |
552 | x = torch.cat((enc_img, cond_aug),1)
553 | x = self.hidden(x)
554 | x = self.residual(x)
555 | x = self.upsample(x)
556 | enlarged_img = self.image(x)
557 |
558 | return enlarged_img
559 |
560 |
561 | class Stage2_Discriminator(nn.Module):
562 | def __init__(self):
563 | super().__init__()
564 | self.channels_initial = channels_discr
565 | self.downsample = nn.Sequential(
566 | nn.Conv2d(3, self.channels_initial, 4, 2, 1, bias = False),
567 | nn.LeakyReLU(0.2, inplace = True),
568 |
569 | nn.Conv2d(self.channels_initial, self.channels_initial*2, 4, 2, 1, bias = False),
570 | nn.BatchNorm2d(self.channels_initial*2),
571 | nn.LeakyReLU(0.2, inplace = True),
572 |
573 | nn.Conv2d(self.channels_initial*2, self.channels_initial*4, 4, 2, 1, bias = False),
574 | nn.BatchNorm2d(self.channels_initial*4),
575 | nn.LeakyReLU(0.2, inplace = True),
576 |
577 | nn.Conv2d(self.channels_initial*4, self.channels_initial*8, 4, 2, 1, bias = False),
578 | nn.BatchNorm2d(self.channels_initial*8),
579 | nn.LeakyReLU(0.2, inplace = True),
580 |
581 | nn.Conv2d(self.channels_initial*8, self.channels_initial*16, 4, 2, 1, bias = False),
582 | nn.BatchNorm2d(self.channels_initial*16),
583 | nn.LeakyReLU(0.2, inplace = True),
584 |
585 | nn.Conv2d(self.channels_initial*16, self.channels_initial*32, 4, 2, 1, bias = False),
586 | nn.BatchNorm2d(self.channels_initial*32),
587 | nn.LeakyReLU(0.2, inplace = True),
588 |
589 | nn.Conv2d(self.channels_initial*32, self.channels_initial*16, 3, 1, 1, bias = False),
590 | nn.BatchNorm2d(self.channels_initial*16),
591 | nn.LeakyReLU(0.2, inplace = True),
592 |
593 | nn.Conv2d(self.channels_initial*16, self.channels_initial*8, 3, 1, 1, bias = False),
594 | nn.BatchNorm2d(self.channels_initial*8),
595 | nn.LeakyReLU(0.2, inplace = True)
596 | )
597 |
598 | self.cond_logit = Discriminator_logit(self.channels_initial,dim_conditioning_var,True)
599 | self.uncond_logit = Discriminator_logit(self.channels_initial,dim_conditioning_var,False)
600 |
601 | def forward(self,image):
602 | return self.downsample(image)
603 |
604 |
605 |
--------------------------------------------------------------------------------