├── .gitignore
├── CONFIG
├── LICENSE
├── README.md
├── data.py
├── examples
│   └── .gitignore
├── images
│   ├── architecture.png
│   ├── results_birds.png
│   └── results_flowers.png
├── model.py
├── scripts
│   ├── test_birds.sh
│   ├── test_flowers.sh
│   ├── train_birds.sh
│   ├── train_flowers.sh
│   ├── train_text_embedding_birds.sh
│   └── train_text_embedding_flowers.sh
├── test.py
├── test
│   ├── birds
│   │   ├── Black_Billed_Cuckoo_0055_26223.jpg
│   │   ├── Black_Footed_Albatross_0060_796076.jpg
│   │   ├── Brewer_Blackbird_0112_2340.jpg
│   │   ├── Horned_Grebe_0088_35023.jpg
│   │   ├── Northern_Flicker_0100_28898.jpg
│   │   ├── Red_Legged_Kittiwake_0060_795414.jpg
│   │   └── Tree_Swallow_0107_136223.jpg
│   ├── flowers
│   │   ├── image_03828.jpg
│   │   ├── image_03996.jpg
│   │   ├── image_04899.jpg
│   │   ├── image_05151.jpg
│   │   ├── image_06615.jpg
│   │   ├── image_06734.jpg
│   │   └── image_07200.jpg
│   ├── result_birds
│   │   └── .gitignore
│   ├── result_flowers
│   │   └── .gitignore
│   ├── text_birds.txt
│   └── text_flowers.txt
├── train.py
└── train_text_embedding.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .DS_Store -------------------------------------------------------------------------------- /CONFIG: -------------------------------------------------------------------------------- 1 | BIRDS_IMG_ROOT=/Users/seonghyeon/data/data/CUB_200_2011/CUB_200_2011/images 2 | BIRDS_CAPTION_ROOT=/Users/seonghyeon/data/data/CUB_200_2011/cub_icml 3 | 4 | FLOWERS_IMG_ROOT=/Users/seonghyeon/data/data/oxford102 5 | FLOWERS_CAPTION_ROOT=/Users/seonghyeon/data/data/oxford102 6 | 7 | FASTTEXT_MODEL=/Users/seonghyeon/data/data/fastText/wiki.en/wiki.en.bin -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Seonghyeon Nam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Image Synthesis via Adversarial Learning 2 | 3 | This is a PyTorch implementation of the paper [Semantic Image Synthesis via Adversarial Learning](https://arxiv.org/abs/1707.06873). 
4 | 5 | ![Model architecture](images/architecture.png) 6 | 7 | ## Requirements 8 | - [PyTorch](https://github.com/pytorch/pytorch) 0.2 9 | - [Torchvision](https://github.com/pytorch/vision) 10 | - [Pillow](https://pillow.readthedocs.io/en/4.2.x/) 11 | - [fastText.py](https://github.com/salestock/fastText.py) (Note: if you have a problem when loading a pretrained model, try [my fixed code](https://github.com/woozzu/fastText.py/tree/feature/udpate-fasttext-to-f24a781-fix)) 12 | - [NLTK](http://www.nltk.org) 13 | 14 | ## Pretrained word vectors for fastText 15 | Download the pretrained [English](https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.zip) word vectors. You can see the list of available pretrained vectors on [this page](https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md). 16 | 17 | ## Datasets 18 | - Oxford-102 flowers: [images](http://www.robots.ox.ac.uk/~vgg/data/flowers/102) and [captions](https://drive.google.com/file/d/0B0ywwgffWnLLMl9uOU91MV80cVU/view?usp=sharing) 19 | - Caltech-200 birds: [images](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and [captions](https://drive.google.com/file/d/0B0ywwgffWnLLLUc2WHYzM0Q2eWc/view?usp=sharing) 20 | 21 | The caption data comes from [this repository](https://github.com/reedscot/icml2016). After downloading, modify the `CONFIG` file so that all dataset paths point to the data you downloaded. 22 | 23 | ## Run 24 | - `scripts/train_text_embedding_[birds/flowers].sh` 25 | Train a visual-semantic embedding model using the method of [Kiros et al.](https://arxiv.org/abs/1411.2539). 26 | - `scripts/train_[birds/flowers].sh` 27 | Train a GAN using a pretrained text embedding model. 28 | - `scripts/test_[birds/flowers].sh` 29 | Generate some examples using original images and semantically relevant texts. 30 | 31 | ## Results 32 | ![Flowers](images/results_flowers.png) 33 | 34 | ![Birds](images/results_birds.png) 35 | 36 | ## Acknowledgements 37 | - [Text to image synthesis](https://github.com/reedscot/icml2016) 38 | - [StackGAN](https://github.com/hanzhanggit/StackGAN) 39 | 40 | We would like to thank Hao Dong, one of the first authors of [Semantic Image Synthesis via Adversarial Learning](https://arxiv.org/abs/1707.06873), for his helpful advice on the implementation. 
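
## Programmatic inference (sketch)

For reference, the following condenses what `test.py` does for a single image. It is a minimal, hypothetical sketch rather than a drop-in replacement for the test script: it assumes you run it from the repository root, that the checkpoints exist under `./models/` (the default paths written by the training scripts), that `wiki.en.bin` points at the fastText binary you downloaded, and it targets the same PyTorch 0.2-era API as the rest of the code.

```python
import fasttext
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image

from model import VisualSemanticEmbedding, Generator
from data import split_sentence_into_words

# Word vectors plus the pretrained text encoder (only the GRU branch is needed).
word_embedding = fasttext.load_model('wiki.en.bin')
vse = VisualSemanticEmbedding(300)
vse.load_state_dict(torch.load('./models/text_embedding_flowers.pth'))
txt_encoder = vse.txt_encoder

G = Generator(use_vgg=True)
G.load_state_dict(torch.load('./models/flowers.pth'))
G.train(False)

# Same preprocessing as test.py: resize, center-crop, VGG normalization.
transform = transforms.Compose([transforms.Scale(74),
                                transforms.CenterCrop(64),
                                transforms.ToTensor()])
vgg_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
img = vgg_normalize(transform(Image.open('./test/flowers/image_03828.jpg')))
img = Variable(img.unsqueeze(0), volatile=True)  # batch of one

# Embed the target description as a (num_words, batch, 300) sequence for the GRU.
words = split_sentence_into_words('The petals are white and the stamens are light yellow.')
desc = torch.Tensor([word_embedding[w] for w in words]).unsqueeze(1)
_, txt_feat = txt_encoder(Variable(desc, volatile=True))

# Synthesize; the generator outputs tanh values in [-1, 1], so rescale to [0, 1].
output, _ = G(img, txt_feat.squeeze(0))
save_image((output.data + 1) * 0.5, 'result.jpg')
```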
-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import nltk 6 | from nltk.tokenize import RegexpTokenizer 7 | 8 | import torch 9 | import torch.utils.data as data 10 | from torch.utils.serialization import load_lua 11 | import torchvision.transforms as transforms 12 | 13 | 14 | def split_sentence_into_words(sentence): 15 | tokenizer = RegexpTokenizer(r'\w+') 16 | return tokenizer.tokenize(sentence.lower()) 17 | 18 | 19 | class ReedICML2016(data.Dataset): 20 | def __init__(self, img_root, caption_root, classes_filename, 21 | word_embedding, max_word_length, img_transform=None): 22 | super(ReedICML2016, self).__init__() 23 | self.alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} " 24 | 25 | self.max_word_length = max_word_length 26 | self.img_transform = img_transform 27 | 28 | if self.img_transform is None: 29 | self.img_transform = transforms.ToTensor() 30 | 31 | self.data = self._load_dataset(img_root, caption_root, classes_filename, word_embedding) 32 | 33 | def _load_dataset(self, img_root, caption_root, classes_filename, word_embedding): 34 | output = [] 35 | with open(os.path.join(caption_root, classes_filename)) as f: 36 | lines = f.readlines() 37 | for line in lines: 38 | cls = line.replace('\n', '') 39 | filenames = os.listdir(os.path.join(caption_root, cls)) 40 | for filename in filenames: 41 | datum = load_lua(os.path.join(caption_root, cls, filename)) 42 | raw_desc = datum['char'].numpy() 43 | desc, len_desc = self._get_word_vectors(raw_desc, word_embedding) 44 | output.append({ 45 | 'img': os.path.join(img_root, datum['img']), 46 | 'desc': desc, 47 | 'len_desc': len_desc 48 | }) 49 | return output 50 | 51 | def _get_word_vectors(self, desc, word_embedding): 52 | output = [] 53 | len_desc = [] 54 | for i in range(desc.shape[1]): 55 | words = self._nums2chars(desc[:, i]) 56 | words = split_sentence_into_words(words)[:self.max_word_length]  # truncate overlong captions so torch.stack gets equal lengths 57 | word_vecs = torch.Tensor([word_embedding[w] for w in words]) 58 | # zero padding 59 | if len(words) < self.max_word_length: 60 | word_vecs = torch.cat(( 61 | word_vecs, 62 | torch.zeros(self.max_word_length - len(words), word_vecs.size(1)) 63 | )) 64 | output.append(word_vecs) 65 | len_desc.append(len(words)) 66 | return torch.stack(output), len_desc 67 | 68 | def _nums2chars(self, nums): 69 | chars = '' 70 | for num in nums: 71 | chars += self.alphabet[num - 1] 72 | return chars 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | def __getitem__(self, index): 78 | datum = self.data[index] 79 | img = Image.open(datum['img']) 80 | img = self.img_transform(img) 81 | if img.size(0) == 1: 82 | img = img.repeat(3, 1, 1) 83 | desc = datum['desc'] 84 | len_desc = datum['len_desc'] 85 | # randomly select one sentence 86 | selected = np.random.choice(desc.size(0)) 87 | desc = desc[selected, ...] 
88 | len_desc = len_desc[selected] 89 | return img, desc, len_desc 90 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/images/architecture.png -------------------------------------------------------------------------------- /images/results_birds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/images/results_birds.png -------------------------------------------------------------------------------- /images/results_flowers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/images/results_flowers.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | import torchvision.models as models 7 | 8 | 9 | def init_weights(m): 10 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 11 | if m.weight.requires_grad: 12 | m.weight.data.normal_(std=0.02) 13 | if m.bias is not None and m.bias.requires_grad: 14 | m.bias.data.fill_(0) 15 | elif isinstance(m, nn.BatchNorm2d) and m.affine: 16 | if m.weight.requires_grad: 17 | m.weight.data.normal_(1, 0.02) 18 | if m.bias.requires_grad: 19 | m.bias.data.fill_(0) 20 | 21 | 22 | class VisualSemanticEmbedding(nn.Module): 23 | def __init__(self, embed_ndim): 24 | super(VisualSemanticEmbedding, self).__init__() 25 | self.embed_ndim = embed_ndim 26 | 27 | # image feature 28 | self.img_encoder = models.vgg16(pretrained=True) 29 | for param in self.img_encoder.parameters(): 30 | param.requires_grad = False 31 | self.feat_extractor = nn.Sequential(*(self.img_encoder.classifier[i] for i in range(6))) 32 | self.W = nn.Linear(4096, embed_ndim, False) 33 | 34 | # text feature 35 | self.txt_encoder = nn.GRU(embed_ndim, embed_ndim, 1) 36 | 37 | def forward(self, img, txt): 38 | # image feature 39 | img_feat = self.img_encoder.features(img) 40 | img_feat = img_feat.view(img_feat.size(0), -1) 41 | img_feat = self.feat_extractor(img_feat) 42 | img_feat = self.W(img_feat) 43 | 44 | # text feature 45 | h0 = torch.zeros(1, img.size(0), self.embed_ndim) 46 | h0 = Variable(h0.cuda() if txt.data.is_cuda else h0) 47 | _, txt_feat = self.txt_encoder(txt, h0) 48 | txt_feat = txt_feat.squeeze() 49 | 50 | return img_feat, txt_feat 51 | 52 | 53 | class ResidualBlock(nn.Module): 54 | def __init__(self): 55 | super(ResidualBlock, self).__init__() 56 | 57 | self.encoder = nn.Sequential( 58 | nn.Conv2d(512, 512, 3, padding=1, bias=False), 59 | nn.BatchNorm2d(512), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(512, 512, 3, padding=1, bias=False), 62 | nn.BatchNorm2d(512) 63 | ) 64 | 65 | def forward(self, x): 66 | return F.relu(x + self.encoder(x)) 67 | 68 | 69 | class Generator(nn.Module): 70 | def 
__init__(self, use_vgg=True): 71 | super(Generator, self).__init__() 72 | 73 | # encoder 74 | if use_vgg: 75 | self.encoder = models.vgg16_bn(pretrained=True) 76 | self.encoder = \ 77 | nn.Sequential(*(self.encoder.features[i] for i in list(range(23)) + list(range(24, 33)))) 78 | self.encoder[24].dilation = (2, 2) 79 | self.encoder[24].padding = (2, 2) 80 | self.encoder[27].dilation = (2, 2) 81 | self.encoder[27].padding = (2, 2) 82 | self.encoder[30].dilation = (2, 2) 83 | self.encoder[30].padding = (2, 2) 84 | for param in self.encoder.parameters(): 85 | param.requires_grad = False 86 | self.encoder.eval() 87 | else: 88 | self.encoder = nn.Sequential( 89 | nn.Conv2d(3, 128, 3, padding=1), 90 | nn.ReLU(inplace=True), 91 | nn.Conv2d(128, 256, 4, 2, padding=1, bias=False), 92 | nn.BatchNorm2d(256), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(256, 512, 4, 2, padding=1, bias=False), 95 | nn.BatchNorm2d(512), 96 | nn.ReLU(inplace=True) 97 | ) 98 | 99 | # residual blocks 100 | self.residual_blocks = nn.Sequential( 101 | nn.Conv2d(512 + 128, 512, 3, padding=1, bias=False), 102 | nn.BatchNorm2d(512), 103 | ResidualBlock(), 104 | ResidualBlock(), 105 | ResidualBlock(), 106 | ResidualBlock() 107 | ) 108 | 109 | # decoder 110 | self.decoder = nn.Sequential( 111 | nn.Upsample(scale_factor=2, mode='nearest'), 112 | nn.Conv2d(512, 256, 3, padding=1, bias=False), 113 | nn.BatchNorm2d(256), 114 | nn.ReLU(inplace=True), 115 | nn.Upsample(scale_factor=2, mode='nearest'), 116 | nn.Conv2d(256, 128, 3, padding=1, bias=False), 117 | nn.BatchNorm2d(128), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(128, 3, 3, padding=1), 120 | nn.Tanh() 121 | ) 122 | 123 | # conditioning augmentation 124 | self.mu = nn.Sequential( 125 | nn.Linear(300, 128, bias=False), 126 | nn.LeakyReLU(0.2, inplace=True) 127 | ) 128 | self.log_sigma = nn.Sequential( 129 | nn.Linear(300, 128, bias=False), 130 | nn.LeakyReLU(0.2, inplace=True) 131 | ) 132 | 133 | self.apply(init_weights) 134 | 135 | def forward(self, img, txt_feat, z=None): 136 | # encoder 137 | img_feat = self.encoder(img) 138 | z_mean = self.mu(txt_feat) 139 | z_log_stddev = self.log_sigma(txt_feat) 140 | z = torch.randn(txt_feat.size(0), 128) if z is None else z  # sample noise only when no tensor is supplied 141 | if next(self.parameters()).is_cuda: 142 | z = z.cuda() 143 | txt_feat = z_mean + z_log_stddev.exp() * Variable(z) 144 | 145 | # residual blocks 146 | txt_feat = txt_feat.unsqueeze(-1).unsqueeze(-1) 147 | txt_feat = txt_feat.repeat(1, 1, img_feat.size(2), img_feat.size(3)) 148 | fusion = torch.cat((img_feat, txt_feat), dim=1) 149 | fusion = self.residual_blocks(fusion) 150 | 151 | # decoder 152 | output = self.decoder(fusion) 153 | return output, (z_mean, z_log_stddev) 154 | 155 | 156 | class Discriminator(nn.Module): 157 | def __init__(self): 158 | super(Discriminator, self).__init__() 159 | 160 | self.encoder = nn.Sequential( 161 | nn.Conv2d(3, 64, 4, 2, padding=1), 162 | nn.LeakyReLU(0.2, inplace=True), 163 | nn.Conv2d(64, 128, 4, 2, padding=1, bias=False), 164 | nn.BatchNorm2d(128), 165 | nn.LeakyReLU(0.2, inplace=True), 166 | nn.Conv2d(128, 256, 4, 2, padding=1, bias=False), 167 | nn.BatchNorm2d(256), 168 | nn.LeakyReLU(0.2, inplace=True), 169 | nn.Conv2d(256, 512, 4, 2, padding=1, bias=False), 170 | nn.BatchNorm2d(512) 171 | ) 172 | 173 | self.residual_branch = nn.Sequential( 174 | nn.Conv2d(512, 128, 1, bias=False), 175 | nn.BatchNorm2d(128), 176 | nn.LeakyReLU(0.2, inplace=True), 177 | nn.Conv2d(128, 128, 3, padding=1, bias=False), 178 | nn.BatchNorm2d(128), 179 | nn.LeakyReLU(0.2, inplace=True), 180 | nn.Conv2d(128, 512, 3, padding=1, 
bias=False), 181 | nn.BatchNorm2d(512) 182 | ) 183 | 184 | self.classifier = nn.Sequential( 185 | nn.Conv2d(512 + 128, 512, 1, bias=False), 186 | nn.BatchNorm2d(512), 187 | nn.LeakyReLU(0.2, inplace=True), 188 | nn.Conv2d(512, 1, 4) 189 | ) 190 | 191 | self.compression = nn.Sequential( 192 | nn.Linear(300, 128), 193 | nn.LeakyReLU(0.2, inplace=True) 194 | ) 195 | 196 | self.apply(init_weights) 197 | 198 | def forward(self, img, txt_feat): 199 | img_feat = self.encoder(img) 200 | img_feat = F.leaky_relu(img_feat + self.residual_branch(img_feat), 0.2) 201 | txt_feat = self.compression(txt_feat) 202 | 203 | txt_feat = txt_feat.unsqueeze(-1).unsqueeze(-1) 204 | txt_feat = txt_feat.repeat(1, 1, img_feat.size(2), img_feat.size(3)) 205 | fusion = torch.cat((img_feat, txt_feat), dim=1) 206 | output = self.classifier(fusion) 207 | return output.squeeze() 208 | -------------------------------------------------------------------------------- /scripts/test_birds.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python test.py \ 4 | --img_root ./test/birds \ 5 | --text_file ./test/text_birds.txt \ 6 | --fasttext_model ${FASTTEXT_MODEL} \ 7 | --text_embedding_model ./models/text_embedding_birds.pth \ 8 | --generator_model ./models/birds.pth \ 9 | --output_root ./test/result_birds \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/test_flowers.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python test.py \ 4 | --img_root ./test/flowers \ 5 | --text_file ./test/text_flowers.txt \ 6 | --fasttext_model ${FASTTEXT_MODEL} \ 7 | --text_embedding_model ./models/text_embedding_flowers.pth \ 8 | --generator_model ./models/flowers.pth \ 9 | --output_root ./test/result_flowers \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/train_birds.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python train.py \ 4 | --img_root ${BIRDS_IMG_ROOT} \ 5 | --caption_root ${BIRDS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --text_embedding_model ./models/text_embedding_birds.pth \ 9 | --save_filename ./models/birds.pth \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/train_flowers.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python train.py \ 4 | --img_root ${FLOWERS_IMG_ROOT} \ 5 | --caption_root ${FLOWERS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --text_embedding_model ./models/text_embedding_flowers.pth \ 9 | --save_filename ./models/flowers.pth \ 10 | --use_vgg -------------------------------------------------------------------------------- /scripts/train_text_embedding_birds.sh: -------------------------------------------------------------------------------- 1 | . 
CONFIG 2 | 3 | python train_text_embedding.py \ 4 | --img_root ${BIRDS_IMG_ROOT} \ 5 | --caption_root ${BIRDS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --save_filename ./models/text_embedding_birds.pth -------------------------------------------------------------------------------- /scripts/train_text_embedding_flowers.sh: -------------------------------------------------------------------------------- 1 | . CONFIG 2 | 3 | python train_text_embedding.py \ 4 | --img_root ${FLOWERS_IMG_ROOT} \ 5 | --caption_root ${FLOWERS_CAPTION_ROOT} \ 6 | --trainclasses_file trainvalclasses.txt \ 7 | --fasttext_model ${FASTTEXT_MODEL} \ 8 | --save_filename ./models/text_embedding_flowers.pth -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import fasttext 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | import torchvision.transforms as transforms 9 | from torchvision.utils import save_image 10 | 11 | from model import VisualSemanticEmbedding 12 | from model import Generator 13 | from data import split_sentence_into_words 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--img_root', type=str, required=True, 18 | help='root directory that contains images') 19 | parser.add_argument('--text_file', type=str, required=True, 20 | help='text file that contains descriptions') 21 | parser.add_argument('--fasttext_model', type=str, required=True, 22 | help='pretrained fastText model (binary file)') 23 | parser.add_argument('--text_embedding_model', type=str, required=True, 24 | help='pretrained text embedding model') 25 | parser.add_argument('--embed_ndim', type=int, default=300, 26 | help='dimension of embedded vector (default: 300)') 27 | parser.add_argument('--generator_model', type=str, required=True, 28 | help='pretrained generator model') 29 | parser.add_argument('--use_vgg', action='store_true', 30 | help='use pretrained VGG network for image encoder') 31 | parser.add_argument('--output_root', type=str, required=True, 32 | help='root directory of output') 33 | parser.add_argument('--no_cuda', action='store_true', 34 | help='do not use cuda') 35 | args = parser.parse_args() 36 | 37 | if not args.no_cuda and not torch.cuda.is_available(): 38 | print('Warning: cuda is not available on this machine.') 39 | args.no_cuda = True 40 | 41 | 42 | if __name__ == '__main__': 43 | print('Loading a pretrained fastText model...') 44 | word_embedding = fasttext.load_model(args.fasttext_model) 45 | 46 | print('Loading a pretrained model...') 47 | 48 | txt_encoder = VisualSemanticEmbedding(args.embed_ndim) 49 | txt_encoder.load_state_dict(torch.load(args.text_embedding_model)) 50 | txt_encoder = txt_encoder.txt_encoder 51 | 52 | G = Generator(use_vgg=args.use_vgg) 53 | G.load_state_dict(torch.load(args.generator_model)) 54 | G.train(False) 55 | 56 | if not args.no_cuda: 57 | txt_encoder.cuda() 58 | G.cuda() 59 | 60 | transform = transforms.Compose([ 61 | transforms.Scale(74), 62 | transforms.CenterCrop(64), 63 | transforms.ToTensor() 64 | ]) 65 | 66 | vgg_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 67 | std=[0.229, 0.224, 0.225]) 68 | 69 | print('Loading test data...') 70 | filenames = os.listdir(args.img_root) 71 | img = [] 72 | for fn in filenames: 73 | im = Image.open(os.path.join(args.img_root, fn)) 74 | im = 
transform(im) 75 | img.append(im) 76 | img = torch.stack(img) 77 | save_image(img, os.path.join(args.output_root, 'original.jpg')) 78 | img = vgg_normalize(img) if args.use_vgg else img * 2 - 1 79 | img = Variable(img.cuda() if not args.no_cuda else img, volatile=True) 80 | 81 | html = '<html><body><h1>Manipulated Images</h1><table border="1"><tr><th>Description</th><th>Image</th></tr>' 82 | html += '<tr><td>ORIGINAL</td><td><img src="{}"></td></tr>\n'.format('original.jpg') 83 | with open(args.text_file, 'r') as f: 84 | texts = f.readlines() 85 | 86 | for i, txt in enumerate(texts): 87 | txt = txt.replace('\n', '') 88 | desc = split_sentence_into_words(txt) 89 | desc = torch.Tensor([word_embedding[w] for w in desc]) 90 | desc = desc.unsqueeze(1) 91 | desc = desc.repeat(1, img.size(0), 1) 92 | desc = Variable(desc.cuda() if not args.no_cuda else desc, volatile=True) 93 | 94 | _, txt_feat = txt_encoder(desc) 95 | txt_feat = txt_feat.squeeze(0) 96 | output, _ = G(img, txt_feat) 97 | 98 | out_filename = 'output_%d.jpg' % i 99 | save_image((output.data + 1) * 0.5, os.path.join(args.output_root, out_filename)) 100 | html += '<tr><td>{}</td><td><img src="{}"></td></tr>\n'.format(txt, out_filename) 101 | 102 | with open(os.path.join(args.output_root, 'index.html'), 'w') as f: 103 | f.write(html) 104 | print('Done. The results were saved in %s.' % args.output_root) 105 | -------------------------------------------------------------------------------- /test/birds/Black_Billed_Cuckoo_0055_26223.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Black_Billed_Cuckoo_0055_26223.jpg -------------------------------------------------------------------------------- /test/birds/Black_Footed_Albatross_0060_796076.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Black_Footed_Albatross_0060_796076.jpg -------------------------------------------------------------------------------- /test/birds/Brewer_Blackbird_0112_2340.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Brewer_Blackbird_0112_2340.jpg -------------------------------------------------------------------------------- /test/birds/Horned_Grebe_0088_35023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Horned_Grebe_0088_35023.jpg -------------------------------------------------------------------------------- /test/birds/Northern_Flicker_0100_28898.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Northern_Flicker_0100_28898.jpg -------------------------------------------------------------------------------- /test/birds/Red_Legged_Kittiwake_0060_795414.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Red_Legged_Kittiwake_0060_795414.jpg -------------------------------------------------------------------------------- /test/birds/Tree_Swallow_0107_136223.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/birds/Tree_Swallow_0107_136223.jpg --------------------------------------------------------------------------------
/test/flowers/image_03828.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_03828.jpg -------------------------------------------------------------------------------- /test/flowers/image_03996.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_03996.jpg -------------------------------------------------------------------------------- /test/flowers/image_04899.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_04899.jpg -------------------------------------------------------------------------------- /test/flowers/image_05151.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_05151.jpg -------------------------------------------------------------------------------- /test/flowers/image_06615.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_06615.jpg -------------------------------------------------------------------------------- /test/flowers/image_06734.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_06734.jpg -------------------------------------------------------------------------------- /test/flowers/image_07200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackeyZz/NewImage-From-TextAndImage/652dd839f895b3a293de16d6b430069844c33f83/test/flowers/image_07200.jpg -------------------------------------------------------------------------------- /test/result_birds/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /test/result_flowers/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /test/text_birds.txt: -------------------------------------------------------------------------------- 1 | This small bird has a blue crown and white belly. 2 | This small yellow bird has grey wings, and a black bill. 3 | A small brown bird with a brown crown has a white belly. 4 | This black bird has no other colors with a short bill. 5 | An orange bird with green wings and blue head. 6 | A black bird with a red head. 7 | This particular bird has a red head and breast and features grey wings. -------------------------------------------------------------------------------- /test/text_flowers.txt: -------------------------------------------------------------------------------- 1 | The petals are white and the stamens are light yellow. 2 | The red flower has no visible stamens. 
3 | The petals of the flower have yellow and red stripes. 4 | The light purple flower has a large number of small petals. 5 | The petals of the flower have mixed colors of bright yellow and light green. 6 | The flower shown has reddish petals with yellow edges. 7 | This flower has petals of pink and white color with yellow stamens. -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import fasttext 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.data as data 9 | import torch.optim.lr_scheduler as lr_scheduler 10 | from torch.autograd import Variable 11 | import torchvision.transforms as transforms 12 | from torchvision.utils import save_image 13 | 14 | from model import VisualSemanticEmbedding 15 | from model import Generator, Discriminator 16 | from data import ReedICML2016 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--img_root', type=str, required=True, 21 | help='root directory that contains images') 22 | parser.add_argument('--caption_root', type=str, required=True, 23 | help='root directory that contains captions') 24 | parser.add_argument('--trainclasses_file', type=str, required=True, 25 | help='text file that contains training classes') 26 | parser.add_argument('--fasttext_model', type=str, required=True, 27 | help='pretrained fastText model (binary file)') 28 | parser.add_argument('--text_embedding_model', type=str, required=True, 29 | help='pretrained text embedding model') 30 | parser.add_argument('--save_filename', type=str, required=True, 31 | help='checkpoint file') 32 | parser.add_argument('--num_threads', type=int, default=4, 33 | help='number of threads for fetching data (default: 4)') 34 | parser.add_argument('--num_epochs', type=int, default=600, 35 | help='number of epochs (default: 600)') 36 | parser.add_argument('--batch_size', type=int, default=64, 37 | help='batch size (default: 64)') 38 | parser.add_argument('--learning_rate', type=float, default=0.0002, 39 | help='learning rate (default: 0.0002)') 40 | parser.add_argument('--lr_decay', type=float, default=0.5, 41 | help='learning rate decay (default: 0.5)') 42 | parser.add_argument('--momentum', type=float, default=0.5, 43 | help='beta1 for Adam optimizer (default: 0.5)') 44 | parser.add_argument('--embed_ndim', type=int, default=300, 45 | help='dimension of embedded vector (default: 300)') 46 | parser.add_argument('--max_nwords', type=int, default=50, 47 | help='maximum number of words (default: 50)') 48 | parser.add_argument('--use_vgg', action='store_true', 49 | help='use pretrained VGG network for image encoder') 50 | parser.add_argument('--no_cuda', action='store_true', 51 | help='do not use cuda') 52 | args = parser.parse_args() 53 | 54 | if not args.no_cuda and not torch.cuda.is_available(): 55 | print('Warning: cuda is not available on this machine.') 56 | args.no_cuda = True 57 | 58 | 59 | def preprocess(img, desc, len_desc, txt_encoder): 60 | img = Variable(img.cuda() if not args.no_cuda else img) 61 | desc = Variable(desc.cuda() if not args.no_cuda else desc) 62 | 63 | len_desc = len_desc.numpy() 64 | sorted_indices = np.argsort(len_desc)[::-1] 65 | original_indices = np.argsort(sorted_indices) 66 | packed_desc = nn.utils.rnn.pack_padded_sequence( 67 | desc[sorted_indices, ...].transpose(0, 1), 68 | len_desc[sorted_indices] 69 | ) 
70 | _, txt_feat = txt_encoder(packed_desc) 71 | txt_feat = txt_feat.squeeze() 72 | txt_feat = txt_feat[original_indices, ...] 73 | 74 | txt_feat_np = txt_feat.data.cpu().numpy() if not args.no_cuda else txt_feat.data.numpy() 75 | txt_feat_mismatch = torch.Tensor(np.roll(txt_feat_np, 1, axis=0)) 76 | txt_feat_mismatch = Variable(txt_feat_mismatch.cuda() if not args.no_cuda else txt_feat_mismatch) 77 | txt_feat_np_split = np.split(txt_feat_np, [txt_feat_np.shape[0] // 2]) 78 | txt_feat_relevant = torch.Tensor(np.concatenate([ 79 | np.roll(txt_feat_np_split[0], -1, axis=0), 80 | txt_feat_np_split[1] 81 | ])) 82 | txt_feat_relevant = Variable(txt_feat_relevant.cuda() if not args.no_cuda else txt_feat_relevant) 83 | return img, txt_feat, txt_feat_mismatch, txt_feat_relevant 84 | 85 | 86 | if __name__ == '__main__': 87 | print('Loading a pretrained fastText model...') 88 | word_embedding = fasttext.load_model(args.fasttext_model) 89 | 90 | print('Loading a dataset...') 91 | train_data = ReedICML2016(args.img_root, 92 | args.caption_root, 93 | args.trainclasses_file, 94 | word_embedding, 95 | args.max_nwords, 96 | transforms.Compose([ 97 | transforms.Scale(74), 98 | transforms.RandomCrop(64), 99 | transforms.RandomHorizontalFlip(), 100 | transforms.ToTensor() 101 | ])) 102 | vgg_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 103 | std=[0.229, 0.224, 0.225]) 104 | 105 | train_loader = data.DataLoader(train_data, 106 | batch_size=args.batch_size, 107 | shuffle=True, 108 | num_workers=args.num_threads) 109 | 110 | word_embedding = None 111 | 112 | # pretrained text embedding model 113 | print('Loading a pretrained text embedding model...') 114 | txt_encoder = VisualSemanticEmbedding(args.embed_ndim) 115 | txt_encoder.load_state_dict(torch.load(args.text_embedding_model)) 116 | txt_encoder = txt_encoder.txt_encoder 117 | for param in txt_encoder.parameters(): 118 | param.requires_grad = False 119 | 120 | G = Generator(use_vgg=args.use_vgg) 121 | D = Discriminator() 122 | 123 | if not args.no_cuda: 124 | txt_encoder.cuda() 125 | G.cuda() 126 | D.cuda() 127 | 128 | g_optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, G.parameters()), 129 | lr=args.learning_rate, betas=(args.momentum, 0.999)) 130 | d_optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, D.parameters()), 131 | lr=args.learning_rate, betas=(args.momentum, 0.999)) 132 | g_lr_scheduler = lr_scheduler.StepLR(g_optimizer, 100, args.lr_decay) 133 | d_lr_scheduler = lr_scheduler.StepLR(d_optimizer, 100, args.lr_decay) 134 | 135 | for epoch in range(args.num_epochs): 136 | d_lr_scheduler.step() 137 | g_lr_scheduler.step() 138 | 139 | # training loop 140 | avg_D_real_loss = 0 141 | avg_D_real_m_loss = 0 142 | avg_D_fake_loss = 0 143 | avg_G_fake_loss = 0 144 | avg_kld = 0 145 | for i, (img, desc, len_desc) in enumerate(train_loader): 146 | img, txt_feat, txt_feat_mismatch, txt_feat_relevant = \ 147 | preprocess(img, desc, len_desc, txt_encoder) 148 | img_norm = img * 2 - 1 149 | img_G = Variable(vgg_normalize(img.data)) if args.use_vgg else img_norm 150 | 151 | ONES = Variable(torch.ones(img.size(0))) 152 | ZEROS = Variable(torch.zeros(img.size(0))) 153 | if not args.no_cuda: 154 | ONES, ZEROS = ONES.cuda(), ZEROS.cuda() 155 | 156 | # UPDATE DISCRIMINATOR 157 | D.zero_grad() 158 | # real image with matching text 159 | real_logit = D(img_norm, txt_feat) 160 | real_loss = F.binary_cross_entropy_with_logits(real_logit, ONES) 161 | avg_D_real_loss += real_loss.data[0] 162 | real_loss.backward() 163 | # real image 
with mismatching text 164 | real_m_logit = D(img_norm, txt_feat_mismatch) 165 | real_m_loss = 0.5 * F.binary_cross_entropy_with_logits(real_m_logit, ZEROS) 166 | avg_D_real_m_loss += real_m_loss.data[0] 167 | real_m_loss.backward() 168 | # synthesized image with semantically relevant text 169 | fake, _ = G(img_G, txt_feat_relevant) 170 | fake_logit = D(fake.detach(), txt_feat_relevant) 171 | fake_loss = 0.5 * F.binary_cross_entropy_with_logits(fake_logit, ZEROS) 172 | avg_D_fake_loss += fake_loss.data[0] 173 | fake_loss.backward() 174 | d_optimizer.step() 175 | 176 | # UPDATE GENERATOR 177 | G.zero_grad() 178 | fake, (z_mean, z_log_stddev) = G(img_G, txt_feat_relevant) 179 | kld = torch.mean(-z_log_stddev + 0.5 * (torch.exp(2 * z_log_stddev) + torch.pow(z_mean, 2) - 1)) 180 | avg_kld += kld.data[0] 181 | fake_logit = D(fake, txt_feat_relevant) 182 | fake_loss = F.binary_cross_entropy_with_logits(fake_logit, ONES) 183 | avg_G_fake_loss += fake_loss.data[0] 184 | G_loss = fake_loss + kld 185 | G_loss.backward() 186 | g_optimizer.step() 187 | 188 | if i % 10 == 0: 189 | print('Epoch [%d/%d], Iter [%d/%d], D_real: %.4f, D_mis: %.4f, D_fake: %.4f, G_fake: %.4f, KLD: %.4f' 190 | % (epoch + 1, args.num_epochs, i + 1, len(train_loader), avg_D_real_loss / (i + 1), 191 | avg_D_real_m_loss / (i + 1), avg_D_fake_loss / (i + 1), avg_G_fake_loss / (i + 1), avg_kld / (i + 1))) 192 | 193 | save_image((fake.data + 1) * 0.5, './examples/epoch_%d.png' % (epoch + 1)) 194 | torch.save(G.state_dict(), args.save_filename) 195 | -------------------------------------------------------------------------------- /train_text_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import fasttext 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | from torch.autograd import Variable 8 | import torchvision.transforms as transforms 9 | 10 | from model import VisualSemanticEmbedding 11 | from data import ReedICML2016 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--img_root', type=str, required=True, 16 | help='root directory that contains images') 17 | parser.add_argument('--caption_root', type=str, required=True, 18 | help='root directory that contains captions') 19 | parser.add_argument('--trainclasses_file', type=str, required=True, 20 | help='text file that contains training classes') 21 | parser.add_argument('--fasttext_model', type=str, required=True, 22 | help='pretrained fastText model (binary file)') 23 | parser.add_argument('--save_filename', type=str, required=True, 24 | help='checkpoint file') 25 | parser.add_argument('--num_threads', type=int, default=4, 26 | help='number of threads for fetching data (default: 4)') 27 | parser.add_argument('--num_epochs', type=int, default=300, 28 | help='number of epochs (default: 300)') 29 | parser.add_argument('--batch_size', type=int, default=64, 30 | help='batch size (default: 64)') 31 | parser.add_argument('--learning_rate', type=float, default=0.0002, 32 | help='learning rate (default: 0.0002)') 33 | parser.add_argument('--margin', type=float, default=0.2, 34 | help='margin for pairwise ranking loss (default: 0.2)') 35 | parser.add_argument('--embed_ndim', type=int, default=300, 36 | help='dimension of embedded vector (default: 300)') 37 | parser.add_argument('--max_nwords', type=int, default=50, 38 | help='maximum number of words (default: 50)') 39 | parser.add_argument('--no_cuda', action='store_true', 40 | help='do not 
use cuda') 41 | args = parser.parse_args() 42 | 43 | if not args.no_cuda and not torch.cuda.is_available(): 44 | print('Warning: cuda is not available on this machine.') 45 | args.no_cuda = True 46 | 47 | 48 | def pairwise_ranking_loss(margin, x, v): 49 | zero = torch.zeros(1) 50 | diag_margin = margin * torch.eye(x.size(0)) 51 | if not args.no_cuda: 52 | zero, diag_margin = zero.cuda(), diag_margin.cuda() 53 | zero, diag_margin = Variable(zero), Variable(diag_margin) 54 | 55 | x = x / torch.norm(x, 2, 1, keepdim=True) 56 | v = v / torch.norm(v, 2, 1, keepdim=True) 57 | prod = torch.matmul(x, v.transpose(0, 1)) 58 | diag = torch.diag(prod) 59 | for_x = torch.max(zero, margin - torch.unsqueeze(diag, 1) + prod) - diag_margin 60 | for_v = torch.max(zero, margin - torch.unsqueeze(diag, 0) + prod) - diag_margin 61 | return (torch.sum(for_x) + torch.sum(for_v)) / x.size(0) 62 | 63 | 64 | if __name__ == '__main__': 65 | print('Loading a pretrained fastText model...') 66 | word_embedding = fasttext.load_model(args.fasttext_model) 67 | 68 | print('Loading a dataset...') 69 | train_data = ReedICML2016(args.img_root, 70 | args.caption_root, 71 | args.trainclasses_file, 72 | word_embedding, 73 | args.max_nwords, 74 | transforms.Compose([ 75 | transforms.Scale(256), 76 | transforms.RandomCrop(224), 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 80 | std=[0.229, 0.224, 0.225]) 81 | ])) 82 | 83 | word_embedding = None 84 | 85 | train_loader = data.DataLoader(train_data, 86 | batch_size=args.batch_size, 87 | shuffle=True, 88 | num_workers=args.num_threads) 89 | 90 | model = VisualSemanticEmbedding(args.embed_ndim) 91 | if not args.no_cuda: 92 | model.cuda() 93 | 94 | optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, model.parameters()), 95 | lr=args.learning_rate) 96 | 97 | for epoch in range(args.num_epochs): 98 | avg_loss = 0 99 | for i, (img, desc, len_desc) in enumerate(train_loader): 100 | img = Variable(img.cuda() if not args.no_cuda else img) 101 | desc = Variable(desc.cuda() if not args.no_cuda else desc) 102 | len_desc, indices = torch.sort(len_desc, 0, True) 103 | indices = indices.numpy() 104 | img = img[indices, ...] 105 | desc = desc[indices, ...].transpose(0, 1) 106 | desc = nn.utils.rnn.pack_padded_sequence(desc, len_desc.numpy()) 107 | 108 | optimizer.zero_grad() 109 | img_feat, txt_feat = model(img, desc) 110 | loss = pairwise_ranking_loss(args.margin, img_feat, txt_feat) 111 | avg_loss += loss.data[0] 112 | loss.backward() 113 | optimizer.step() 114 | 115 | if i % 10 == 0: 116 | print('Epoch [%d/%d], Iter [%d/%d], Loss: %.4f' 117 | % (epoch + 1, args.num_epochs, i + 1, len(train_loader), avg_loss / (i + 1))) 118 | 119 | torch.save(model.state_dict(), args.save_filename) 120 | --------------------------------------------------------------------------------
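
A note on `pairwise_ranking_loss` above: after L2-normalizing the image embeddings x and text embeddings v, it forms the cosine-similarity matrix S = x vᵀ and charges a hinge penalty max(0, margin - S_ii + S_ij) in both directions for every mismatched pair (i, j), with `diag_margin` cancelling the i = j terms — the bidirectional max-margin objective of Kiros et al. A quick, hypothetical sanity check (it must run where `args` is in scope, e.g. inside `train_text_embedding.py` with `args.no_cuda` set, since the function reads that global):

```python
import torch
from torch.autograd import Variable

# Four already-perfect pairs: matched similarity 1, mismatched similarity 0,
# so every hinge term max(0, 0.2 - 1 + 0) is zero and the loss vanishes.
x = Variable(torch.eye(4))  # image embeddings
v = Variable(torch.eye(4))  # text embeddings
print(pairwise_ranking_loss(0.2, x, v))  # -> Variable containing 0
```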