├── .gitignore
├── CONFIG
├── LICENSE
├── README.md
├── data.py
├── examples
│   └── .gitignore
├── images
│   ├── architecture.png
│   ├── results_birds.png
│   └── results_flowers.png
├── model.py
├── scripts
│   ├── test_birds.sh
│   ├── test_flowers.sh
│   ├── train_birds.sh
│   ├── train_flowers.sh
│   ├── train_text_embedding_birds.sh
│   └── train_text_embedding_flowers.sh
├── test.py
├── test
│   ├── birds
│   │   ├── Black_Billed_Cuckoo_0055_26223.jpg
│   │   ├── Black_Footed_Albatross_0060_796076.jpg
│   │   ├── Brewer_Blackbird_0112_2340.jpg
│   │   ├── Horned_Grebe_0088_35023.jpg
│   │   ├── Northern_Flicker_0100_28898.jpg
│   │   ├── Red_Legged_Kittiwake_0060_795414.jpg
│   │   └── Tree_Swallow_0107_136223.jpg
│   ├── flowers
│   │   ├── image_03828.jpg
│   │   ├── image_03996.jpg
│   │   ├── image_04899.jpg
│   │   ├── image_05151.jpg
│   │   ├── image_06615.jpg
│   │   ├── image_06734.jpg
│   │   └── image_07200.jpg
│   ├── result_birds
│   │   └── .gitignore
│   ├── result_flowers
│   │   └── .gitignore
│   ├── text_birds.txt
│   └── text_flowers.txt
├── train.py
└── train_text_embedding.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .DS_Store
--------------------------------------------------------------------------------
/CONFIG:
--------------------------------------------------------------------------------
1 | BIRDS_IMG_ROOT=/Users/seonghyeon/data/data/CUB_200_2011/CUB_200_2011/images
2 | BIRDS_CAPTION_ROOT=/Users/seonghyeon/data/data/CUB_200_2011/cub_icml
3 | 
4 | FLOWERS_IMG_ROOT=/Users/seonghyeon/data/data/oxford102
5 | FLOWERS_CAPTION_ROOT=/Users/seonghyeon/data/data/oxford102
6 | 
7 | FASTTEXT_MODEL=/Users/seonghyeon/data/data/fastText/wiki.en/wiki.en.bin
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Seonghyeon Nam
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic Image Synthesis via Adversarial Learning
2 | 
3 | This is a PyTorch implementation of the paper [Semantic Image Synthesis via Adversarial Learning](https://arxiv.org/abs/1707.06873).
4 | 
5 | ![Model architecture](images/architecture.png)
6 | 
7 | ## Requirements
8 | - [PyTorch](https://github.com/pytorch/pytorch) 0.2
9 | - [Torchvision](https://github.com/pytorch/vision)
10 | - [Pillow](https://pillow.readthedocs.io/en/4.2.x/)
11 | - [fastText.py](https://github.com/salestock/fastText.py) (Note: if you have a problem loading a pretrained model, try [my fixed code](https://github.com/woozzu/fastText.py/tree/feature/udpate-fasttext-to-f24a781-fix))
12 | - [NLTK](http://www.nltk.org)
13 | 
14 | ## Pretrained word vectors for fastText
15 | Download the pretrained [English](https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.zip) word vectors. The full list of pretrained vectors is available on [this page](https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md).
16 | 
17 | ## Datasets
18 | - Oxford-102 flowers: [images](http://www.robots.ox.ac.uk/~vgg/data/flowers/102) and [captions](https://drive.google.com/file/d/0B0ywwgffWnLLMl9uOU91MV80cVU/view?usp=sharing)
19 | - Caltech-UCSD Birds-200-2011: [images](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and [captions](https://drive.google.com/file/d/0B0ywwgffWnLLLUc2WHYzM0Q2eWc/view?usp=sharing)
20 | 
21 | The caption data comes from [this repository](https://github.com/reedscot/icml2016). After downloading, edit the `CONFIG` file so that all dataset paths point to the data you downloaded.
22 | 
23 | ## Run
24 | - `scripts/train_text_embedding_[birds/flowers].sh`
25 | Train a visual-semantic embedding model using the method of [Kiros et al.](https://arxiv.org/abs/1411.2539) (a sketch of the ranking loss is given in the notes at the end of this README).
26 | - `scripts/train_[birds/flowers].sh`
27 | Train a GAN using the pretrained text embedding model (see the note on the conditioning-augmentation term below).
28 | - `scripts/test_[birds/flowers].sh`
29 | Generate example images from the original test images and semantically relevant texts.
30 | 
31 | ## Results
32 | ![Results on birds](images/results_birds.png)
33 | 
34 | ![Results on flowers](images/results_flowers.png)
35 | 
36 | ## Acknowledgements
37 | - [Text to image synthesis](https://github.com/reedscot/icml2016)
38 | - [StackGAN](https://github.com/hanzhanggit/StackGAN)
39 | 
40 | We would like to thank Hao Dong, one of the first authors of the paper [Semantic Image Synthesis via Adversarial Learning](https://arxiv.org/abs/1707.06873), for his helpful advice on the implementation.
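The first training step (`train_text_embedding_*.sh`) learns the joint image-text embedding with a bidirectional max-margin ranking loss in the spirit of Kiros et al. The snippet below is a minimal sketch of that objective, not the repository's `train_text_embedding.py` (which is not reproduced here); the margin value and the mean reduction are assumptions, and it is written against a recent PyTorch API rather than the 0.2 API this code targets.

```python
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(img_feat, txt_feat, margin=0.2):
    """Bidirectional max-margin ranking loss over a batch of matching pairs.

    img_feat, txt_feat: (batch, embed_ndim) tensors, e.g. the two outputs of
    VisualSemanticEmbedding, where row i of each tensor comes from the same pair.
    """
    img_feat = F.normalize(img_feat, dim=1)
    txt_feat = F.normalize(txt_feat, dim=1)
    scores = img_feat @ txt_feat.t()              # (batch, batch) cosine similarities
    pos = scores.diag().view(-1, 1)               # scores of the matching pairs

    cost_txt = (margin - pos + scores).clamp(min=0)      # image anchored against wrong sentences
    cost_img = (margin - pos.t() + scores).clamp(min=0)  # sentence anchored against wrong images

    # do not penalize the matching pairs themselves
    mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
    return cost_txt.masked_fill(mask, 0).mean() + cost_img.masked_fill(mask, 0).mean()
```

In training, this loss would be computed on the `(img_feat, txt_feat)` pair returned by `VisualSemanticEmbedding.forward` for each minibatch.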
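The generator in `model.py` also returns `(z_mean, z_log_stddev)` from its conditioning-augmentation branch. The GAN training step (`train.py`, not included above) presumably regularizes this distribution toward a standard normal, as in StackGAN. Below is a hedged sketch of that term, assuming `z_log_stddev` is the log of the standard deviation exactly as used in `Generator.forward`.

```python
import torch

def conditioning_kl_loss(z_mean, z_log_stddev):
    # KL( N(mu, sigma^2) || N(0, I) ), averaged over batch and latent dimensions:
    # 0.5 * (sigma^2 + mu^2 - 1) - log(sigma), with log(sigma) = z_log_stddev.
    return torch.mean(0.5 * (torch.exp(2 * z_log_stddev) + z_mean ** 2 - 1) - z_log_stddev)
```

Such a term is typically added to the generator loss with a small weight.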
-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import nltk 6 | from nltk.tokenize import RegexpTokenizer 7 | 8 | import torch 9 | import torch.utils.data as data 10 | from torch.utils.serialization import load_lua 11 | import torchvision.transforms as transforms 12 | 13 | 14 | def split_sentence_into_words(sentence): 15 | tokenizer = RegexpTokenizer(r'\w+') 16 | return tokenizer.tokenize(sentence.lower()) 17 | 18 | 19 | class ReedICML2016(data.Dataset): 20 | def __init__(self, img_root, caption_root, classes_fllename, 21 | word_embedding, max_word_length, img_transform=None): 22 | super(ReedICML2016, self).__init__() 23 | self.alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} " 24 | 25 | self.max_word_length = max_word_length 26 | self.img_transform = img_transform 27 | 28 | if self.img_transform == None: 29 | self.img_transform = transforms.ToTensor() 30 | 31 | self.data = self._load_dataset(img_root, caption_root, classes_fllename, word_embedding) 32 | 33 | def _load_dataset(self, img_root, caption_root, classes_filename, word_embedding): 34 | output = [] 35 | with open(os.path.join(caption_root, classes_filename)) as f: 36 | lines = f.readlines() 37 | for line in lines: 38 | cls = line.replace('\n', '') 39 | filenames = os.listdir(os.path.join(caption_root, cls)) 40 | for filename in filenames: 41 | datum = load_lua(os.path.join(caption_root, cls, filename)) 42 | raw_desc = datum['char'].numpy() 43 | desc, len_desc = self._get_word_vectors(raw_desc, word_embedding) 44 | output.append({ 45 | 'img': os.path.join(img_root, datum['img']), 46 | 'desc': desc, 47 | 'len_desc': len_desc 48 | }) 49 | return output 50 | 51 | def _get_word_vectors(self, desc, word_embedding): 52 | output = [] 53 | len_desc = [] 54 | for i in range(desc.shape[1]): 55 | words = self._nums2chars(desc[:, i]) 56 | words = split_sentence_into_words(words) 57 | word_vecs = torch.Tensor([word_embedding[w] for w in words]) 58 | # zero padding 59 | if len(words) < self.max_word_length: 60 | word_vecs = torch.cat(( 61 | word_vecs, 62 | torch.zeros(self.max_word_length - len(words), word_vecs.size(1)) 63 | )) 64 | output.append(word_vecs) 65 | len_desc.append(len(words)) 66 | return torch.stack(output), len_desc 67 | 68 | def _nums2chars(self, nums): 69 | chars = '' 70 | for num in nums: 71 | chars += self.alphabet[num - 1] 72 | return chars 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | def __getitem__(self, index): 78 | datum = self.data[index] 79 | img = Image.open(datum['img']) 80 | img = self.img_transform(img) 81 | if img.size(0) == 1: 82 | img = img.repeat(3, 1, 1) 83 | desc = datum['desc'] 84 | len_desc = datum['len_desc'] 85 | # randomly select one sentence 86 | selected = np.random.choice(desc.size(0)) 87 | desc = desc[selected, ...] 
88 | len_desc = len_desc[selected] 89 | return img, desc, len_desc 90 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/images/architecture.png -------------------------------------------------------------------------------- /images/results_birds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/images/results_birds.png -------------------------------------------------------------------------------- /images/results_flowers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/images/results_flowers.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | import torchvision.models as models 7 | 8 | 9 | def init_weights(m): 10 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 11 | if m.weight.requires_grad: 12 | m.weight.data.normal_(std=0.02) 13 | if m.bias is not None and m.bias.requires_grad: 14 | m.bias.data.fill_(0) 15 | elif isinstance(m, nn.BatchNorm2d) and m.affine: 16 | if m.weight.requires_grad: 17 | m.weight.data.normal_(1, 0.02) 18 | if m.bias.requires_grad: 19 | m.bias.data.fill_(0) 20 | 21 | 22 | class VisualSemanticEmbedding(nn.Module): 23 | def __init__(self, embed_ndim): 24 | super(VisualSemanticEmbedding, self).__init__() 25 | self.embed_ndim = embed_ndim 26 | 27 | # image feature 28 | self.img_encoder = models.vgg16(pretrained=True) 29 | for param in self.img_encoder.parameters(): 30 | param.requires_grad = False 31 | self.feat_extractor = nn.Sequential(*(self.img_encoder.classifier[i] for i in range(6))) 32 | self.W = nn.Linear(4096, embed_ndim, False) 33 | 34 | # text feature 35 | self.txt_encoder = nn.GRU(embed_ndim, embed_ndim, 1) 36 | 37 | def forward(self, img, txt): 38 | # image feature 39 | img_feat = self.img_encoder.features(img) 40 | img_feat = img_feat.view(img_feat.size(0), -1) 41 | img_feat = self.feat_extractor(img_feat) 42 | img_feat = self.W(img_feat) 43 | 44 | # text feature 45 | h0 = torch.zeros(1, img.size(0), self.embed_ndim) 46 | h0 = Variable(h0.cuda() if txt.data.is_cuda else h0) 47 | _, txt_feat = self.txt_encoder(txt, h0) 48 | txt_feat = txt_feat.squeeze() 49 | 50 | return img_feat, txt_feat 51 | 52 | 53 | class ResidualBlock(nn.Module): 54 | def __init__(self): 55 | super(ResidualBlock, self).__init__() 56 | 57 | self.encoder = nn.Sequential( 58 | nn.Conv2d(512, 512, 3, padding=1, bias=False), 59 | nn.BatchNorm2d(512), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(512, 512, 3, padding=1, bias=False), 62 | nn.BatchNorm2d(512) 63 | ) 64 | 65 | def forward(self, x): 66 | return F.relu(x + self.encoder(x)) 67 | 68 | 69 | class Generator(nn.Module): 70 | def 
__init__(self, use_vgg=True): 71 | super(Generator, self).__init__() 72 | 73 | # encoder 74 | if use_vgg: 75 | self.encoder = models.vgg16_bn(pretrained=True) 76 | self.encoder = \ 77 | nn.Sequential(*(self.encoder.features[i] for i in range(23) + range(24, 33))) 78 | self.encoder[24].dilation = (2, 2) 79 | self.encoder[24].padding = (2, 2) 80 | self.encoder[27].dilation = (2, 2) 81 | self.encoder[27].padding = (2, 2) 82 | self.encoder[30].dilation = (2, 2) 83 | self.encoder[30].padding = (2, 2) 84 | for param in self.encoder.parameters(): 85 | param.requires_grad = False 86 | self.encoder.eval() 87 | else: 88 | self.encoder = nn.Sequential( 89 | nn.Conv2d(3, 128, 3, padding=1), 90 | nn.ReLU(inplace=True), 91 | nn.Conv2d(128, 256, 4, 2, padding=1, bias=False), 92 | nn.BatchNorm2d(256), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(256, 512, 4, 2, padding=1, bias=False), 95 | nn.BatchNorm2d(512), 96 | nn.ReLU(inplace=True) 97 | ) 98 | 99 | # residual blocks 100 | self.residual_blocks = nn.Sequential( 101 | nn.Conv2d(512 + 128, 512, 3, padding=1, bias=False), 102 | nn.BatchNorm2d(512), 103 | ResidualBlock(), 104 | ResidualBlock(), 105 | ResidualBlock(), 106 | ResidualBlock() 107 | ) 108 | 109 | # decoder 110 | self.decoder = nn.Sequential( 111 | nn.Upsample(scale_factor=2, mode='nearest'), 112 | nn.Conv2d(512, 256, 3, padding=1, bias=False), 113 | nn.BatchNorm2d(256), 114 | nn.ReLU(inplace=True), 115 | nn.Upsample(scale_factor=2, mode='nearest'), 116 | nn.Conv2d(256, 128, 3, padding=1, bias=False), 117 | nn.BatchNorm2d(128), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(128, 3, 3, padding=1), 120 | nn.Tanh() 121 | ) 122 | 123 | # conditioning augmentation 124 | self.mu = nn.Sequential( 125 | nn.Linear(300, 128, bias=False), 126 | nn.LeakyReLU(0.2, inplace=True) 127 | ) 128 | self.log_sigma = nn.Sequential( 129 | nn.Linear(300, 128, bias=False), 130 | nn.LeakyReLU(0.2, inplace=True) 131 | ) 132 | 133 | self.apply(init_weights) 134 | 135 | def forward(self, img, txt_feat, z=None): 136 | # encoder 137 | img_feat = self.encoder(img) 138 | z_mean = self.mu(txt_feat) 139 | z_log_stddev = self.log_sigma(txt_feat) 140 | z = torch.randn(txt_feat.size(0), 128) 141 | if next(self.parameters()).is_cuda: 142 | z = z.cuda() 143 | txt_feat = z_mean + z_log_stddev.exp() * Variable(z) 144 | 145 | # residual blocks 146 | txt_feat = txt_feat.unsqueeze(-1).unsqueeze(-1) 147 | txt_feat = txt_feat.repeat(1, 1, img_feat.size(2), img_feat.size(3)) 148 | fusion = torch.cat((img_feat, txt_feat), dim=1) 149 | fusion = self.residual_blocks(fusion) 150 | 151 | # decoder 152 | output = self.decoder(fusion) 153 | return output, (z_mean, z_log_stddev) 154 | 155 | 156 | class Discriminator(nn.Module): 157 | def __init__(self): 158 | super(Discriminator, self).__init__() 159 | 160 | self.encoder = nn.Sequential( 161 | nn.Conv2d(3, 64, 4, 2, padding=1), 162 | nn.LeakyReLU(0.2, inplace=True), 163 | nn.Conv2d(64, 128, 4, 2, padding=1, bias=False), 164 | nn.BatchNorm2d(128), 165 | nn.LeakyReLU(0.2, inplace=True), 166 | nn.Conv2d(128, 256, 4, 2, padding=1, bias=False), 167 | nn.BatchNorm2d(256), 168 | nn.LeakyReLU(0.2, inplace=True), 169 | nn.Conv2d(256, 512, 4, 2, padding=1, bias=False), 170 | nn.BatchNorm2d(512) 171 | ) 172 | 173 | self.residual_branch = nn.Sequential( 174 | nn.Conv2d(512, 128, 1, bias=False), 175 | nn.BatchNorm2d(128), 176 | nn.LeakyReLU(0.2, inplace=True), 177 | nn.Conv2d(128, 128, 3, padding=1, bias=False), 178 | nn.BatchNorm2d(128), 179 | nn.LeakyReLU(0.2, inplace=True), 180 | nn.Conv2d(128, 512, 3, padding=1, 
bias=False), 181 | nn.BatchNorm2d(512) 182 | ) 183 | 184 | self.classifier = nn.Sequential( 185 | nn.Conv2d(512 + 128, 512, 1, bias=False), 186 | nn.BatchNorm2d(512), 187 | nn.LeakyReLU(0.2, inplace=True), 188 | nn.Conv2d(512, 1, 4) 189 | ) 190 | 191 | self.compression = nn.Sequential( 192 | nn.Linear(300, 128), 193 | nn.LeakyReLU(0.2, inplace=True) 194 | ) 195 | 196 | self.apply(init_weights) 197 | 198 | def forward(self, img, txt_feat): 199 | img_feat = self.encoder(img) 200 | img_feat = F.leaky_relu(img_feat + self.residual_branch(img_feat), 0.2) 201 | txt_feat = self.compression(txt_feat) 202 | 203 | txt_feat = txt_feat.unsqueeze(-1).unsqueeze(-1) 204 | txt_feat = txt_feat.repeat(1, 1, img_feat.size(2), img_feat.size(3)) 205 | fusion = torch.cat((img_feat, txt_feat), dim=1) 206 | output = self.classifier(fusion) 207 | return output.squeeze() 208 | -------------------------------------------------------------------------------- /scripts/test_birds.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python test.py \ 4 | --img_root ./test/birds \ 5 | --text_file ./test/text_birds.txt \ 6 | --fasttext_model ${FASTTEXT_MODEL} \ 7 | --text_embedding_model ./models/text_embedding_birds.pth \ 8 | --generator_model ./models/birds.pth \ 9 | --output_root ./test/result_birds \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/test_flowers.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python test.py \ 4 | --img_root ./test/flowers \ 5 | --text_file ./test/text_flowers.txt \ 6 | --fasttext_model ${FASTTEXT_MODEL} \ 7 | --text_embedding_model ./models/text_embedding_flowers.pth \ 8 | --generator_model ./models/flowers.pth \ 9 | --output_root ./test/result_flowers \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/train_birds.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python train.py \ 4 | --img_root ${BIRDS_IMG_ROOT} \ 5 | --caption_root ${BIRDS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --text_embedding_model ./models/text_embedding_birds.pth \ 9 | --save_filename ./models/birds.pth \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/train_flowers.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python train.py \ 4 | --img_root ${FLOWERS_IMG_ROOT} \ 5 | --caption_root ${FLOWERS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --text_embedding_model ./models/text_embedding_flowers.pth \ 9 | --save_filename ./models/flowers.pth \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/train_text_embedding_birds.sh: -------------------------------------------------------------------------------- 1 | . 
CONFIG 2 | 3 | python train_text_embedding.py \ 4 | --img_root ${BIRDS_IMG_ROOT} \ 5 | --caption_root ${BIRDS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --save_filename ./models/text_embedding_birds.pth -------------------------------------------------------------------------------- /scripts/train_text_embedding_flowers.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python train_text_embedding.py \ 4 | --img_root ${FLOWERS_IMG_ROOT} \ 5 | --caption_root ${FLOWERS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --save_filename ./models/text_embedding_flowers.pth -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import fasttext 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | import torchvision.transforms as transforms 9 | from torchvision.utils import save_image 10 | 11 | from model import VisualSemanticEmbedding 12 | from model import Generator 13 | from data import split_sentence_into_words 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--img_root', type=str, required=True, 18 | help='root directory that contains images') 19 | parser.add_argument('--text_file', type=str, required=True, 20 | help='text file that contains descriptions') 21 | parser.add_argument('--fasttext_model', type=str, required=True, 22 | help='pretrained fastText model (binary file)') 23 | parser.add_argument('--text_embedding_model', type=str, required=True, 24 | help='pretrained text embedding model') 25 | parser.add_argument('--embed_ndim', type=int, default=300, 26 | help='dimension of embedded vector (default: 300)') 27 | parser.add_argument('--generator_model', type=str, required=True, 28 | help='pretrained generator model') 29 | parser.add_argument('--use_vgg', action='store_true', 30 | help='use pretrained VGG network for image encoder') 31 | parser.add_argument('--output_root', type=str, required=True, 32 | help='root directory of output') 33 | parser.add_argument('--no_cuda', action='store_true', 34 | help='do not use cuda') 35 | args = parser.parse_args() 36 | 37 | if not args.no_cuda and not torch.cuda.is_available(): 38 | print('Warning: cuda is not available on this machine.') 39 | args.no_cuda = True 40 | 41 | 42 | if __name__ == '__main__': 43 | print('Loading a pretrained fastText model...') 44 | word_embedding = fasttext.load_model(args.fasttext_model) 45 | 46 | print('Loading a pretrained model...') 47 | 48 | txt_encoder = VisualSemanticEmbedding(args.embed_ndim) 49 | txt_encoder.load_state_dict(torch.load(args.text_embedding_model)) 50 | txt_encoder = txt_encoder.txt_encoder 51 | 52 | G = Generator(use_vgg=args.use_vgg) 53 | G.load_state_dict(torch.load(args.generator_model)) 54 | G.train(False) 55 | 56 | if not args.no_cuda: 57 | txt_encoder.cuda() 58 | G.cuda() 59 | 60 | transform = transforms.Compose([ 61 | transforms.Scale(74), 62 | transforms.CenterCrop(64), 63 | transforms.ToTensor() 64 | ]) 65 | 66 | vgg_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 67 | std=[0.229, 0.224, 0.225]) 68 | 69 | print('Loading test data...') 70 | filenames = os.listdir(args.img_root) 71 | img = [] 72 | for fn in filenames: 73 | im = Image.open(os.path.join(args.img_root, fn)) 74 | im = 
transform(im)
75 | img.append(im)
76 | img = torch.stack(img)
77 | save_image(img, os.path.join(args.output_root, 'original.jpg'))
78 | img = vgg_normalize(img) if args.use_vgg else img * 2 - 1
79 | img = Variable(img.cuda() if not args.no_cuda else img, volatile=True)
80 | 
81 | html = '<html><body><table><tr><td>Description</td><td>Image</td></tr>'
82 | html += '<tr><td>ORIGINAL</td><td><img src="original.jpg"></td></tr>'