├── .gitignore ├── DAMSMencoders └── .gitignore ├── LICENSE ├── README.md ├── code ├── .gitignore ├── GlobalAttention.py ├── cfg │ ├── DAMSM │ │ ├── bird.yml │ │ └── coco.yml │ ├── bird_attn2.yml │ ├── bird_attnDCGAN2.yml │ ├── coco_attn2.yml │ ├── eval_bird.yml │ ├── eval_bird_attnDCGAN2.yml │ └── eval_coco.yml ├── datasets.py ├── main.py ├── miscc │ ├── __init__.py │ ├── config.py │ ├── losses.py │ └── utils.py ├── model.py ├── pretrain_DAMSM.py └── trainer.py ├── data └── .gitignore ├── eval ├── FreeMono.ttf ├── GlobalAttention.py ├── README.md ├── __init__.py ├── data │ ├── bird_AttnGAN2.pth │ ├── captions.pickle │ └── text_encoder200.pth ├── dockerfile.cpu ├── dockerfile.gpu ├── eval.py ├── main.py ├── miscc │ ├── __init__.py │ ├── config.py │ └── utils.py ├── model.py └── requirements.txt ├── example_bird.png ├── example_coco.png ├── framework.png └── models └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | backup 2 | output 3 | code/*.pyc 4 | code/miscc/*.pyc 5 | .DS_Store 6 | .idea/ 7 | -------------------------------------------------------------------------------- /DAMSMencoders/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tao Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AttnGAN 2 | 3 | Pytorch implementation for reproducing AttnGAN results in the paper [AttnGAN: Fine-Grained Text to Image Generation 4 | with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf) by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research). 
5 | 6 | 7 | 8 | 9 | ### Dependencies 10 | Python 2.7 11 | 12 | PyTorch 13 | 14 | In addition, please add the project folder to PYTHONPATH and `pip install` the following packages: 15 | - `python-dateutil` 16 | - `easydict` 17 | - `pandas` 18 | - `torchfile` 19 | - `nltk` 20 | - `scikit-image` 21 | 22 | 23 | 24 | **Data** 25 | 26 | 1. Download our preprocessed metadata for [birds](https://drive.google.com/open?id=1O_LtUP9sch09QH3s_EBAgLEctBQ5JBSJ) and [coco](https://drive.google.com/open?id=1rSnbIGNDGZeHlsUlLdahj0RJ9oo6lgH9) and save them to `data/` 27 | 2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data. Extract it to `data/birds/` 28 | 3. Download the [coco](http://cocodataset.org/#download) dataset and extract the images to `data/coco/` 29 | 30 | 31 | 32 | **Training** 33 | - Pre-train DAMSM models: 34 | - For the bird dataset: `python pretrain_DAMSM.py --cfg cfg/DAMSM/bird.yml --gpu 0` 35 | - For the coco dataset: `python pretrain_DAMSM.py --cfg cfg/DAMSM/coco.yml --gpu 1` 36 | 37 | - Train AttnGAN models: 38 | - For the bird dataset: `python main.py --cfg cfg/bird_attn2.yml --gpu 2` 39 | - For the coco dataset: `python main.py --cfg cfg/coco_attn2.yml --gpu 3` 40 | 41 | - The `*.yml` files are example configuration files for training and evaluating our models. 42 | 43 | 44 | 45 | **Pretrained Model** 46 | - [DAMSM for bird](https://drive.google.com/open?id=1GNUKjVeyWYBJ8hEU-yrfYQpDOkxEyP3V). Download and save it to `DAMSMencoders/` 47 | - [DAMSM for coco](https://drive.google.com/open?id=1zIrXCE9F6yfbEJIbNP5-YrEe2pZcPSGJ). Download and save it to `DAMSMencoders/` 48 | - [AttnGAN for bird](https://drive.google.com/open?id=1lqNG75suOuR_8gjoEPYNp8VyT_ufPPig). Download and save it to `models/` 49 | - [AttnGAN for coco](https://drive.google.com/open?id=1i9Xkg9nU74RAvkcqKE-rJYhjvzKAMnCi). Download and save it to `models/` 50 | 51 | - [AttnDCGAN for bird](https://drive.google.com/open?id=19TG0JUoXurxsmZLaJ82Yo6O0UJ6aDBpg). Download and save it to `models/` 52 | - This is a variant of AttnGAN which applies the proposed attention mechanisms to the DCGAN framework. 53 | 54 | **Sampling** 55 | - Run `python main.py --cfg cfg/eval_bird.yml --gpu 1` to generate examples from the captions in the files listed in "./data/birds/example_filenames.txt". Results are saved to `DAMSMencoders/`. 56 | - Change the `eval_*.yml` files to generate images from other pre-trained models. 57 | - Put your own sentences in "./data/birds/example_captions.txt" if you want to generate images from customized sentences. 58 | 59 | **Validation** 60 | - To generate images for all captions in the validation dataset, change B_VALIDATION to True in the `eval_*.yml` file and then run `python main.py --cfg cfg/eval_bird.yml --gpu 1` 61 | - We compute the inception score for models trained on birds using [StackGAN-inception-model](https://github.com/hanzhanggit/StackGAN-inception-model). 62 | - We compute the inception score for models trained on coco using [improved-gan/inception_score](https://github.com/openai/improved-gan/tree/master/inception_score). 63 | 64 | 65 | **Examples generated by AttnGAN [[Blog]](https://blogs.microsoft.com/ai/drawing-ai/)** 66 | 67 | bird example | coco example 68 | :-------------------------:|:-------------------------: 69 | ![](https://github.com/taoxugit/AttnGAN/blob/master/example_bird.png) | ![](https://github.com/taoxugit/AttnGAN/blob/master/example_coco.png) 70 | 71 | 72 | ### Creating an API 73 | [Evaluation code](eval) embedded into a callable containerized API is included in the `eval/` folder.
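A minimal sketch of how that container might be built and started, assuming Docker is installed: the image name `attngan-eval` and port `8080` below are placeholders rather than values defined by this repo, so check `eval/README.md`, `eval/dockerfile.cpu`, and `eval/main.py` for the actual entrypoint, exposed port, and request format before calling the API.

```
# Hedged sketch only: build the CPU evaluation image from the provided
# dockerfile (run from inside eval/). The image name and published port
# are assumed placeholders, not values documented by this repo.
docker build -t attngan-eval -f dockerfile.cpu .
docker run --rm -p 8080:8080 attngan-eval
```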
74 | 75 | ### Citing AttnGAN 76 | If you find AttnGAN useful in your research, please consider citing: 77 | 78 | ``` 79 | @article{Tao18attngan, 80 | author = {Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He}, 81 | title = {AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks}, 82 | Year = {2018}, 83 | booktitle = {{CVPR}} 84 | } 85 | ``` 86 | 87 | **Reference** 88 | 89 | - [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916) [[code]](https://github.com/hanzhanggit/StackGAN-v2) 90 | - [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) [[code]](https://github.com/carpedm20/DCGAN-tensorflow) 91 | -------------------------------------------------------------------------------- /code/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | !.gitignore -------------------------------------------------------------------------------- /code/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query metrix. 3 | Based on each query vector q, it computes a parameterized convex combination of the matrix 4 | based. 5 | H_1 H_2 H_3 ... H_n 6 | q q q q 7 | | | | | 8 | \ | | / 9 | ..... 10 | \ | / 11 | a 12 | Constructs a unit mapping. 13 | $$(H_1 + H_n, q) => (a)$$ 14 | Where H is of `batch x n x dim` and q is of `batch x dim`. 15 | 16 | References: 17 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules 18 | http://www.aclweb.org/anthology/D15-1166 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | def conv1x1(in_planes, out_planes): 26 | "1x1 convolution with padding" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | 30 | 31 | def func_attention(query, context, gamma1): 32 | """ 33 | query: batch x ndf x queryL 34 | context: batch x ndf x ih x iw (sourceL=ihxiw) 35 | mask: batch_size x sourceL 36 | """ 37 | batch_size, queryL = query.size(0), query.size(2) 38 | ih, iw = context.size(2), context.size(3) 39 | sourceL = ih * iw 40 | 41 | # --> batch x sourceL x ndf 42 | context = context.view(batch_size, -1, sourceL) 43 | contextT = torch.transpose(context, 1, 2).contiguous() 44 | 45 | # Get attention 46 | # (batch x sourceL x ndf)(batch x ndf x queryL) 47 | # -->batch x sourceL x queryL 48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper 49 | # --> batch*sourceL x queryL 50 | attn = attn.view(batch_size*sourceL, queryL) 51 | attn = nn.Softmax()(attn) # Eq. (8) 52 | 53 | # --> batch x sourceL x queryL 54 | attn = attn.view(batch_size, sourceL, queryL) 55 | # --> batch*queryL x sourceL 56 | attn = torch.transpose(attn, 1, 2).contiguous() 57 | attn = attn.view(batch_size*queryL, sourceL) 58 | # Eq. 
(9) 59 | attn = attn * gamma1 60 | attn = nn.Softmax()(attn) 61 | attn = attn.view(batch_size, queryL, sourceL) 62 | # --> batch x sourceL x queryL 63 | attnT = torch.transpose(attn, 1, 2).contiguous() 64 | 65 | # (batch x ndf x sourceL)(batch x sourceL x queryL) 66 | # --> batch x ndf x queryL 67 | weightedContext = torch.bmm(context, attnT) 68 | 69 | return weightedContext, attn.view(batch_size, -1, ih, iw) 70 | 71 | 72 | class GlobalAttentionGeneral(nn.Module): 73 | def __init__(self, idf, cdf): 74 | super(GlobalAttentionGeneral, self).__init__() 75 | self.conv_context = conv1x1(cdf, idf) 76 | self.sm = nn.Softmax() 77 | self.mask = None 78 | 79 | def applyMask(self, mask): 80 | self.mask = mask # batch x sourceL 81 | 82 | def forward(self, input, context): 83 | """ 84 | input: batch x idf x ih x iw (queryL=ihxiw) 85 | context: batch x cdf x sourceL 86 | """ 87 | ih, iw = input.size(2), input.size(3) 88 | queryL = ih * iw 89 | batch_size, sourceL = context.size(0), context.size(2) 90 | 91 | # --> batch x queryL x idf 92 | target = input.view(batch_size, -1, queryL) 93 | targetT = torch.transpose(target, 1, 2).contiguous() 94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1 95 | sourceT = context.unsqueeze(3) 96 | # --> batch x idf x sourceL 97 | sourceT = self.conv_context(sourceT).squeeze(3) 98 | 99 | # Get attention 100 | # (batch x queryL x idf)(batch x idf x sourceL) 101 | # -->batch x queryL x sourceL 102 | attn = torch.bmm(targetT, sourceT) 103 | # --> batch*queryL x sourceL 104 | attn = attn.view(batch_size*queryL, sourceL) 105 | if self.mask is not None: 106 | # batch_size x sourceL --> batch_size*queryL x sourceL 107 | mask = self.mask.repeat(queryL, 1) 108 | attn.data.masked_fill_(mask.data, -float('inf')) 109 | attn = self.sm(attn) # Eq. 
(2) 110 | # --> batch x queryL x sourceL 111 | attn = attn.view(batch_size, queryL, sourceL) 112 | # --> batch x sourceL x queryL 113 | attn = torch.transpose(attn, 1, 2).contiguous() 114 | 115 | # (batch x idf x sourceL)(batch x sourceL x queryL) 116 | # --> batch x idf x queryL 117 | weightedContext = torch.bmm(sourceT, attn) 118 | weightedContext = weightedContext.view(batch_size, -1, ih, iw) 119 | attn = attn.view(batch_size, -1, ih, iw) 120 | 121 | return weightedContext, attn 122 | -------------------------------------------------------------------------------- /code/cfg/DAMSM/bird.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'DAMSM' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: '../data/birds' 5 | GPU_ID: 0 6 | WORKERS: 1 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 1 11 | BASE_SIZE: 299 12 | 13 | 14 | TRAIN: 15 | FLAG: True 16 | NET_E: '' # '../DAMSMencoders/bird/text_encoder200.pth' 17 | BATCH_SIZE: 48 18 | MAX_EPOCH: 600 19 | SNAPSHOT_INTERVAL: 50 20 | ENCODER_LR: 0.002 # 0.0002best; 0.002good; scott: 0.0007 with 0.98decay 21 | RNN_GRAD_CLIP: 0.25 22 | SMOOTH: 23 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 24 | GAMMA2: 5.0 25 | GAMMA3: 10.0 # 10good 1&100bad 26 | 27 | 28 | 29 | TEXT: 30 | EMBEDDING_DIM: 256 31 | CAPTIONS_PER_IMAGE: 10 32 | -------------------------------------------------------------------------------- /code/cfg/DAMSM/coco.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'DAMSM' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: '../data/coco' 5 | GPU_ID: 0 6 | WORKERS: 1 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 1 11 | BASE_SIZE: 299 12 | 13 | 14 | TRAIN: 15 | FLAG: True 16 | NET_E: '' # '../DAMSMencoders/coco/text_encoder100.pth' 17 | BATCH_SIZE: 48 18 | MAX_EPOCH: 600 19 | SNAPSHOT_INTERVAL: 5 20 | ENCODER_LR: 0.002 # 0.0002best; 0.002good 21 | RNN_GRAD_CLIP: 0.25 22 | SMOOTH: 23 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 24 | GAMMA2: 5.0 25 | GAMMA3: 10.0 # 10good 1&100bad 26 | 27 | 28 | TEXT: 29 | EMBEDDING_DIM: 256 30 | CAPTIONS_PER_IMAGE: 5 31 | WORDS_NUM: 15 32 | -------------------------------------------------------------------------------- /code/cfg/bird_attn2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: '../data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: True 15 | NET_G: '' # '../models/bird_AttnGAN2.pth' 16 | B_NET_D: True 17 | BATCH_SIZE: 20 # 22 18 | MAX_EPOCH: 600 19 | SNAPSHOT_INTERVAL: 50 20 | DISCRIMINATOR_LR: 0.0002 21 | GENERATOR_LR: 0.0002 22 | # 23 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth' 24 | SMOOTH: 25 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 26 | GAMMA2: 5.0 27 | GAMMA3: 10.0 # 10good 1&100bad 28 | LAMBDA: 5.0 29 | 30 | 31 | GAN: 32 | DF_DIM: 64 33 | GF_DIM: 32 34 | Z_DIM: 100 35 | R_NUM: 2 36 | 37 | TEXT: 38 | EMBEDDING_DIM: 256 39 | CAPTIONS_PER_IMAGE: 10 40 | -------------------------------------------------------------------------------- /code/cfg/bird_attnDCGAN2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2-dcgan' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: '../data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: True 15 | NET_G: '' # '../models/bird_AttnDCGAN2.pth' 16 | B_NET_D: True 17 | BATCH_SIZE: 30 18 | MAX_EPOCH: 400 19 | 
SNAPSHOT_INTERVAL: 50 20 | DISCRIMINATOR_LR: 0.0002 21 | GENERATOR_LR: 0.0002 22 | # 23 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth' 24 | SMOOTH: 25 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 26 | GAMMA2: 5.0 27 | GAMMA3: 10.0 # 10good 1&100bad 28 | LAMBDA: 1.0 29 | 30 | 31 | GAN: 32 | DF_DIM: 64 33 | GF_DIM: 32 34 | Z_DIM: 100 35 | R_NUM: 0 36 | B_DCGAN: True 37 | 38 | TEXT: 39 | EMBEDDING_DIM: 256 40 | CAPTIONS_PER_IMAGE: 10 41 | -------------------------------------------------------------------------------- /code/cfg/coco_attn2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'glu-gan2' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: '../data/coco' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: True 15 | NET_G: '' # '../models/coco_AttnGAN2.pth' 16 | B_NET_D: True 17 | BATCH_SIZE: 14 # 32 18 | MAX_EPOCH: 120 19 | SNAPSHOT_INTERVAL: 5 20 | DISCRIMINATOR_LR: 0.0002 21 | GENERATOR_LR: 0.0002 22 | # 23 | NET_E: '../DAMSMencoders/coco/text_encoder100.pth' 24 | SMOOTH: 25 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 26 | GAMMA2: 5.0 27 | GAMMA3: 10.0 # 10good 1&100bad 28 | LAMBDA: 50.0 29 | 30 | 31 | GAN: 32 | DF_DIM: 96 33 | GF_DIM: 48 34 | Z_DIM: 100 35 | R_NUM: 3 36 | 37 | TEXT: 38 | EMBEDDING_DIM: 256 39 | CAPTIONS_PER_IMAGE: 5 40 | WORDS_NUM: 12 41 | -------------------------------------------------------------------------------- /code/cfg/eval_bird.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: '../data/birds' 5 | GPU_ID: 3 6 | WORKERS: 1 7 | 8 | B_VALIDATION: False # True # False 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: False 15 | NET_G: '../models/bird_AttnGAN2.pth' 16 | B_NET_D: False 17 | BATCH_SIZE: 100 18 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth' 19 | 20 | 21 | GAN: 22 | DF_DIM: 64 23 | GF_DIM: 32 24 | Z_DIM: 100 25 | R_NUM: 2 26 | 27 | TEXT: 28 | EMBEDDING_DIM: 256 29 | CAPTIONS_PER_IMAGE: 10 30 | WORDS_NUM: 25 31 | -------------------------------------------------------------------------------- /code/cfg/eval_bird_attnDCGAN2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2-dcgan' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: '../data/birds' 5 | GPU_ID: 3 6 | WORKERS: 1 7 | 8 | B_VALIDATION: False 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: False 15 | NET_G: '../models/bird_AttnDCGAN2.pth' 16 | B_NET_D: False 17 | BATCH_SIZE: 100 18 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth' 19 | 20 | 21 | GAN: 22 | DF_DIM: 64 23 | GF_DIM: 32 24 | Z_DIM: 100 25 | R_NUM: 0 26 | B_DCGAN: True 27 | 28 | TEXT: 29 | EMBEDDING_DIM: 256 30 | CAPTIONS_PER_IMAGE: 10 31 | WORDS_NUM: 25 32 | -------------------------------------------------------------------------------- /code/cfg/eval_coco.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: '../data/coco' 5 | GPU_ID: 3 6 | WORKERS: 1 7 | 8 | B_VALIDATION: False 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: False 15 | NET_G: '../models/coco_AttnGAN2.pth' 16 | B_NET_D: False 17 | BATCH_SIZE: 100 18 | NET_E: '../DAMSMencoders/coco/text_encoder100.pth' 19 | 20 | 21 | GAN: 22 | DF_DIM: 96 23 | GF_DIM: 48 24 | Z_DIM: 100 25 | R_NUM: 3 26 | 27 | TEXT: 28 | EMBEDDING_DIM: 256 29 | CAPTIONS_PER_IMAGE: 5 30 | WORDS_NUM: 
20 31 | -------------------------------------------------------------------------------- /code/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | from nltk.tokenize import RegexpTokenizer 8 | from collections import defaultdict 9 | from miscc.config import cfg 10 | 11 | import torch 12 | import torch.utils.data as data 13 | from torch.autograd import Variable 14 | import torchvision.transforms as transforms 15 | 16 | import os 17 | import sys 18 | import numpy as np 19 | import pandas as pd 20 | from PIL import Image 21 | import numpy.random as random 22 | if sys.version_info[0] == 2: 23 | import cPickle as pickle 24 | else: 25 | import pickle 26 | 27 | 28 | def prepare_data(data): 29 | imgs, captions, captions_lens, class_ids, keys = data 30 | 31 | # sort data by the length in a decreasing order 32 | sorted_cap_lens, sorted_cap_indices = \ 33 | torch.sort(captions_lens, 0, True) 34 | 35 | real_imgs = [] 36 | for i in range(len(imgs)): 37 | imgs[i] = imgs[i][sorted_cap_indices] 38 | if cfg.CUDA: 39 | real_imgs.append(Variable(imgs[i]).cuda()) 40 | else: 41 | real_imgs.append(Variable(imgs[i])) 42 | 43 | captions = captions[sorted_cap_indices].squeeze() 44 | class_ids = class_ids[sorted_cap_indices].numpy() 45 | # sent_indices = sent_indices[sorted_cap_indices] 46 | keys = [keys[i] for i in sorted_cap_indices.numpy()] 47 | # print('keys', type(keys), keys[-1]) # list 48 | if cfg.CUDA: 49 | captions = Variable(captions).cuda() 50 | sorted_cap_lens = Variable(sorted_cap_lens).cuda() 51 | else: 52 | captions = Variable(captions) 53 | sorted_cap_lens = Variable(sorted_cap_lens) 54 | 55 | return [real_imgs, captions, sorted_cap_lens, 56 | class_ids, keys] 57 | 58 | 59 | def get_imgs(img_path, imsize, bbox=None, 60 | transform=None, normalize=None): 61 | img = Image.open(img_path).convert('RGB') 62 | width, height = img.size 63 | if bbox is not None: 64 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 65 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 66 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 67 | y1 = np.maximum(0, center_y - r) 68 | y2 = np.minimum(height, center_y + r) 69 | x1 = np.maximum(0, center_x - r) 70 | x2 = np.minimum(width, center_x + r) 71 | img = img.crop([x1, y1, x2, y2]) 72 | 73 | if transform is not None: 74 | img = transform(img) 75 | 76 | ret = [] 77 | if cfg.GAN.B_DCGAN: 78 | ret = [normalize(img)] 79 | else: 80 | for i in range(cfg.TREE.BRANCH_NUM): 81 | # print(imsize[i]) 82 | if i < (cfg.TREE.BRANCH_NUM - 1): 83 | re_img = transforms.Scale(imsize[i])(img) 84 | else: 85 | re_img = img 86 | ret.append(normalize(re_img)) 87 | 88 | return ret 89 | 90 | 91 | class TextDataset(data.Dataset): 92 | def __init__(self, data_dir, split='train', 93 | base_size=64, 94 | transform=None, target_transform=None): 95 | self.transform = transform 96 | self.norm = transforms.Compose([ 97 | transforms.ToTensor(), 98 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 99 | self.target_transform = target_transform 100 | self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE 101 | 102 | self.imsize = [] 103 | for i in range(cfg.TREE.BRANCH_NUM): 104 | self.imsize.append(base_size) 105 | base_size = base_size * 2 106 | 107 | self.data = [] 108 | self.data_dir = data_dir 109 | if data_dir.find('birds') != -1: 110 | self.bbox = self.load_bbox() 111 | else: 112 | self.bbox = None 
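# Next, load the tokenized captions together with the word<->index
# vocabularies for this split (building and pickling them on the first run),
# plus the per-image class ids used to mask same-class mismatches in the losses.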
113 | split_dir = os.path.join(data_dir, split) 114 | 115 | self.filenames, self.captions, self.ixtoword, \ 116 | self.wordtoix, self.n_words = self.load_text_data(data_dir, split) 117 | 118 | self.class_id = self.load_class_id(split_dir, len(self.filenames)) 119 | self.number_example = len(self.filenames) 120 | 121 | def load_bbox(self): 122 | data_dir = self.data_dir 123 | bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') 124 | df_bounding_boxes = pd.read_csv(bbox_path, 125 | delim_whitespace=True, 126 | header=None).astype(int) 127 | # 128 | filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') 129 | df_filenames = \ 130 | pd.read_csv(filepath, delim_whitespace=True, header=None) 131 | filenames = df_filenames[1].tolist() 132 | print('Total filenames: ', len(filenames), filenames[0]) 133 | # 134 | filename_bbox = {img_file[:-4]: [] for img_file in filenames} 135 | numImgs = len(filenames) 136 | for i in xrange(0, numImgs): 137 | # bbox = [x-left, y-top, width, height] 138 | bbox = df_bounding_boxes.iloc[i][1:].tolist() 139 | 140 | key = filenames[i][:-4] 141 | filename_bbox[key] = bbox 142 | # 143 | return filename_bbox 144 | 145 | def load_captions(self, data_dir, filenames): 146 | all_captions = [] 147 | for i in range(len(filenames)): 148 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) 149 | with open(cap_path, "r") as f: 150 | captions = f.read().decode('utf8').split('\n') 151 | cnt = 0 152 | for cap in captions: 153 | if len(cap) == 0: 154 | continue 155 | cap = cap.replace("\ufffd\ufffd", " ") 156 | # picks out sequences of alphanumeric characters as tokens 157 | # and drops everything else 158 | tokenizer = RegexpTokenizer(r'\w+') 159 | tokens = tokenizer.tokenize(cap.lower()) 160 | # print('tokens', tokens) 161 | if len(tokens) == 0: 162 | print('cap', cap) 163 | continue 164 | 165 | tokens_new = [] 166 | for t in tokens: 167 | t = t.encode('ascii', 'ignore').decode('ascii') 168 | if len(t) > 0: 169 | tokens_new.append(t) 170 | all_captions.append(tokens_new) 171 | cnt += 1 172 | if cnt == self.embeddings_num: 173 | break 174 | if cnt < self.embeddings_num: 175 | print('ERROR: the captions for %s less than %d' 176 | % (filenames[i], cnt)) 177 | return all_captions 178 | 179 | def build_dictionary(self, train_captions, test_captions): 180 | word_counts = defaultdict(float) 181 | captions = train_captions + test_captions 182 | for sent in captions: 183 | for word in sent: 184 | word_counts[word] += 1 185 | 186 | vocab = [w for w in word_counts if word_counts[w] >= 0] 187 | 188 | ixtoword = {} 189 | ixtoword[0] = '' 190 | wordtoix = {} 191 | wordtoix[''] = 0 192 | ix = 1 193 | for w in vocab: 194 | wordtoix[w] = ix 195 | ixtoword[ix] = w 196 | ix += 1 197 | 198 | train_captions_new = [] 199 | for t in train_captions: 200 | rev = [] 201 | for w in t: 202 | if w in wordtoix: 203 | rev.append(wordtoix[w]) 204 | # rev.append(0) # do not need '' token 205 | train_captions_new.append(rev) 206 | 207 | test_captions_new = [] 208 | for t in test_captions: 209 | rev = [] 210 | for w in t: 211 | if w in wordtoix: 212 | rev.append(wordtoix[w]) 213 | # rev.append(0) # do not need '' token 214 | test_captions_new.append(rev) 215 | 216 | return [train_captions_new, test_captions_new, 217 | ixtoword, wordtoix, len(ixtoword)] 218 | 219 | def load_text_data(self, data_dir, split): 220 | filepath = os.path.join(data_dir, 'captions.pickle') 221 | train_names = self.load_filenames(data_dir, 'train') 222 | test_names = self.load_filenames(data_dir, 'test') 223 | if 
not os.path.isfile(filepath): 224 | train_captions = self.load_captions(data_dir, train_names) 225 | test_captions = self.load_captions(data_dir, test_names) 226 | 227 | train_captions, test_captions, ixtoword, wordtoix, n_words = \ 228 | self.build_dictionary(train_captions, test_captions) 229 | with open(filepath, 'wb') as f: 230 | pickle.dump([train_captions, test_captions, 231 | ixtoword, wordtoix], f, protocol=2) 232 | print('Save to: ', filepath) 233 | else: 234 | with open(filepath, 'rb') as f: 235 | x = pickle.load(f) 236 | train_captions, test_captions = x[0], x[1] 237 | ixtoword, wordtoix = x[2], x[3] 238 | del x 239 | n_words = len(ixtoword) 240 | print('Load from: ', filepath) 241 | if split == 'train': 242 | # a list of list: each list contains 243 | # the indices of words in a sentence 244 | captions = train_captions 245 | filenames = train_names 246 | else: # split=='test' 247 | captions = test_captions 248 | filenames = test_names 249 | return filenames, captions, ixtoword, wordtoix, n_words 250 | 251 | def load_class_id(self, data_dir, total_num): 252 | if os.path.isfile(data_dir + '/class_info.pickle'): 253 | with open(data_dir + '/class_info.pickle', 'rb') as f: 254 | class_id = pickle.load(f) 255 | else: 256 | class_id = np.arange(total_num) 257 | return class_id 258 | 259 | def load_filenames(self, data_dir, split): 260 | filepath = '%s/%s/filenames.pickle' % (data_dir, split) 261 | if os.path.isfile(filepath): 262 | with open(filepath, 'rb') as f: 263 | filenames = pickle.load(f) 264 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 265 | else: 266 | filenames = [] 267 | return filenames 268 | 269 | def get_caption(self, sent_ix): 270 | # a list of indices for a sentence 271 | sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') 272 | if (sent_caption == 0).sum() > 0: 273 | print('ERROR: do not need END (0) token', sent_caption) 274 | num_words = len(sent_caption) 275 | # pad with 0s (i.e., '') 276 | x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64') 277 | x_len = num_words 278 | if num_words <= cfg.TEXT.WORDS_NUM: 279 | x[:num_words, 0] = sent_caption 280 | else: 281 | ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum 282 | np.random.shuffle(ix) 283 | ix = ix[:cfg.TEXT.WORDS_NUM] 284 | ix = np.sort(ix) 285 | x[:, 0] = sent_caption[ix] 286 | x_len = cfg.TEXT.WORDS_NUM 287 | return x, x_len 288 | 289 | def __getitem__(self, index): 290 | # 291 | key = self.filenames[index] 292 | cls_id = self.class_id[index] 293 | # 294 | if self.bbox is not None: 295 | bbox = self.bbox[key] 296 | data_dir = '%s/CUB_200_2011' % self.data_dir 297 | else: 298 | bbox = None 299 | data_dir = self.data_dir 300 | # 301 | img_name = '%s/images/%s.jpg' % (data_dir, key) 302 | imgs = get_imgs(img_name, self.imsize, 303 | bbox, self.transform, normalize=self.norm) 304 | # random select a sentence 305 | sent_ix = random.randint(0, self.embeddings_num) 306 | new_sent_ix = index * self.embeddings_num + sent_ix 307 | caps, cap_len = self.get_caption(new_sent_ix) 308 | return imgs, caps, cap_len, cls_id, key 309 | 310 | 311 | def __len__(self): 312 | return len(self.filenames) 313 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from miscc.config import cfg, cfg_from_file 4 | from datasets import TextDataset 5 | from trainer import condGANTrainer as trainer 6 | 7 | import os 8 | import 
sys 9 | import time 10 | import random 11 | import pprint 12 | import datetime 13 | import dateutil.tz 14 | import argparse 15 | import numpy as np 16 | 17 | import torch 18 | import torchvision.transforms as transforms 19 | 20 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 21 | sys.path.append(dir_path) 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='Train a AttnGAN network') 26 | parser.add_argument('--cfg', dest='cfg_file', 27 | help='optional config file', 28 | default='cfg/bird_attn2.yml', type=str) 29 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1) 30 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 31 | parser.add_argument('--manualSeed', type=int, help='manual seed') 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def gen_example(wordtoix, algo): 37 | '''generate images from example sentences''' 38 | from nltk.tokenize import RegexpTokenizer 39 | filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR) 40 | data_dic = {} 41 | with open(filepath, "r") as f: 42 | filenames = f.read().decode('utf8').split('\n') 43 | for name in filenames: 44 | if len(name) == 0: 45 | continue 46 | filepath = '%s/%s.txt' % (cfg.DATA_DIR, name) 47 | with open(filepath, "r") as f: 48 | print('Load from:', name) 49 | sentences = f.read().decode('utf8').split('\n') 50 | # a list of indices for a sentence 51 | captions = [] 52 | cap_lens = [] 53 | for sent in sentences: 54 | if len(sent) == 0: 55 | continue 56 | sent = sent.replace("\ufffd\ufffd", " ") 57 | tokenizer = RegexpTokenizer(r'\w+') 58 | tokens = tokenizer.tokenize(sent.lower()) 59 | if len(tokens) == 0: 60 | print('sent', sent) 61 | continue 62 | 63 | rev = [] 64 | for t in tokens: 65 | t = t.encode('ascii', 'ignore').decode('ascii') 66 | if len(t) > 0 and t in wordtoix: 67 | rev.append(wordtoix[t]) 68 | captions.append(rev) 69 | cap_lens.append(len(rev)) 70 | max_len = np.max(cap_lens) 71 | 72 | sorted_indices = np.argsort(cap_lens)[::-1] 73 | cap_lens = np.asarray(cap_lens) 74 | cap_lens = cap_lens[sorted_indices] 75 | cap_array = np.zeros((len(captions), max_len), dtype='int64') 76 | for i in range(len(captions)): 77 | idx = sorted_indices[i] 78 | cap = captions[idx] 79 | c_len = len(cap) 80 | cap_array[i, :c_len] = cap 81 | key = name[(name.rfind('/') + 1):] 82 | data_dic[key] = [cap_array, cap_lens, sorted_indices] 83 | algo.gen_example(data_dic) 84 | 85 | 86 | if __name__ == "__main__": 87 | args = parse_args() 88 | if args.cfg_file is not None: 89 | cfg_from_file(args.cfg_file) 90 | 91 | if args.gpu_id != -1: 92 | cfg.GPU_ID = args.gpu_id 93 | else: 94 | cfg.CUDA = False 95 | 96 | if args.data_dir != '': 97 | cfg.DATA_DIR = args.data_dir 98 | print('Using config:') 99 | pprint.pprint(cfg) 100 | 101 | if not cfg.TRAIN.FLAG: 102 | args.manualSeed = 100 103 | elif args.manualSeed is None: 104 | args.manualSeed = random.randint(1, 10000) 105 | random.seed(args.manualSeed) 106 | np.random.seed(args.manualSeed) 107 | torch.manual_seed(args.manualSeed) 108 | if cfg.CUDA: 109 | torch.cuda.manual_seed_all(args.manualSeed) 110 | 111 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 112 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 113 | output_dir = '../output/%s_%s_%s' % \ 114 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 115 | 116 | split_dir, bshuffle = 'train', True 117 | if not cfg.TRAIN.FLAG: 118 | # bshuffle = False 119 | split_dir = 'test' 120 | 121 | # Get data loader 122 | imsize = cfg.TREE.BASE_SIZE * (2 
** (cfg.TREE.BRANCH_NUM - 1)) 123 | image_transform = transforms.Compose([ 124 | transforms.Scale(int(imsize * 76 / 64)), 125 | transforms.RandomCrop(imsize), 126 | transforms.RandomHorizontalFlip()]) 127 | dataset = TextDataset(cfg.DATA_DIR, split_dir, 128 | base_size=cfg.TREE.BASE_SIZE, 129 | transform=image_transform) 130 | assert dataset 131 | dataloader = torch.utils.data.DataLoader( 132 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 133 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) 134 | 135 | # Define models and go to train/evaluate 136 | algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword) 137 | 138 | start_t = time.time() 139 | if cfg.TRAIN.FLAG: 140 | algo.train() 141 | else: 142 | '''generate images from pre-extracted embeddings''' 143 | if cfg.B_VALIDATION: 144 | algo.sampling(split_dir) # generate images for the whole valid dataset 145 | else: 146 | gen_example(dataset.wordtoix, algo) # generate images for customized captions 147 | end_t = time.time() 148 | print('Total time for training:', end_t - start_t) 149 | -------------------------------------------------------------------------------- /code/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /code/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'birds' 14 | __C.CONFIG_NAME = '' 15 | __C.DATA_DIR = '' 16 | __C.GPU_ID = 0 17 | __C.CUDA = True 18 | __C.WORKERS = 6 19 | 20 | __C.RNN_TYPE = 'LSTM' # 'GRU' 21 | __C.B_VALIDATION = False 22 | 23 | __C.TREE = edict() 24 | __C.TREE.BRANCH_NUM = 3 25 | __C.TREE.BASE_SIZE = 64 26 | 27 | 28 | # Training options 29 | __C.TRAIN = edict() 30 | __C.TRAIN.BATCH_SIZE = 64 31 | __C.TRAIN.MAX_EPOCH = 600 32 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000 33 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 34 | __C.TRAIN.GENERATOR_LR = 2e-4 35 | __C.TRAIN.ENCODER_LR = 2e-4 36 | __C.TRAIN.RNN_GRAD_CLIP = 0.25 37 | __C.TRAIN.FLAG = True 38 | __C.TRAIN.NET_E = '' 39 | __C.TRAIN.NET_G = '' 40 | __C.TRAIN.B_NET_D = True 41 | 42 | __C.TRAIN.SMOOTH = edict() 43 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0 44 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0 45 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0 46 | __C.TRAIN.SMOOTH.LAMBDA = 1.0 47 | 48 | 49 | # Modal options 50 | __C.GAN = edict() 51 | __C.GAN.DF_DIM = 64 52 | __C.GAN.GF_DIM = 128 53 | __C.GAN.Z_DIM = 100 54 | __C.GAN.CONDITION_DIM = 100 55 | __C.GAN.R_NUM = 2 56 | __C.GAN.B_ATTENTION = True 57 | __C.GAN.B_DCGAN = False 58 | 59 | 60 | __C.TEXT = edict() 61 | __C.TEXT.CAPTIONS_PER_IMAGE = 10 62 | __C.TEXT.EMBEDDING_DIM = 256 63 | __C.TEXT.WORDS_NUM = 18 64 | 65 | 66 | def _merge_a_into_b(a, b): 67 | """Merge config dictionary a into config dictionary b, clobbering the 68 | options in b whenever they are also specified in a. 
69 | """ 70 | if type(a) is not edict: 71 | return 72 | 73 | for k, v in a.iteritems(): 74 | # a must specify keys that are in b 75 | if not b.has_key(k): 76 | raise KeyError('{} is not a valid config key'.format(k)) 77 | 78 | # the types must match, too 79 | old_type = type(b[k]) 80 | if old_type is not type(v): 81 | if isinstance(b[k], np.ndarray): 82 | v = np.array(v, dtype=b[k].dtype) 83 | else: 84 | raise ValueError(('Type mismatch ({} vs. {}) ' 85 | 'for config key: {}').format(type(b[k]), 86 | type(v), k)) 87 | 88 | # recursively merge dicts 89 | if type(v) is edict: 90 | try: 91 | _merge_a_into_b(a[k], b[k]) 92 | except: 93 | print('Error under config key: {}'.format(k)) 94 | raise 95 | else: 96 | b[k] = v 97 | 98 | 99 | def cfg_from_file(filename): 100 | """Load a config file and merge it into the default options.""" 101 | import yaml 102 | with open(filename, 'r') as f: 103 | yaml_cfg = edict(yaml.load(f)) 104 | 105 | _merge_a_into_b(yaml_cfg, __C) 106 | -------------------------------------------------------------------------------- /code/miscc/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | from miscc.config import cfg 6 | 7 | from GlobalAttention import func_attention 8 | 9 | 10 | # ##################Loss for matching text-image################### 11 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 12 | """Returns cosine similarity between x1 and x2, computed along dim. 13 | """ 14 | w12 = torch.sum(x1 * x2, dim) 15 | w1 = torch.norm(x1, 2, dim) 16 | w2 = torch.norm(x2, 2, dim) 17 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 18 | 19 | 20 | def sent_loss(cnn_code, rnn_code, labels, class_ids, 21 | batch_size, eps=1e-8): 22 | # ### Mask mis-match samples ### 23 | # that come from the same class as the real sample ### 24 | masks = [] 25 | if class_ids is not None: 26 | for i in range(batch_size): 27 | mask = (class_ids == class_ids[i]).astype(np.uint8) 28 | mask[i] = 0 29 | masks.append(mask.reshape((1, -1))) 30 | masks = np.concatenate(masks, 0) 31 | # masks: batch_size x batch_size 32 | masks = torch.ByteTensor(masks) 33 | if cfg.CUDA: 34 | masks = masks.cuda() 35 | 36 | # --> seq_len x batch_size x nef 37 | if cnn_code.dim() == 2: 38 | cnn_code = cnn_code.unsqueeze(0) 39 | rnn_code = rnn_code.unsqueeze(0) 40 | 41 | # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1 42 | cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True) 43 | rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True) 44 | # scores* / norm*: seq_len x batch_size x batch_size 45 | scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2)) 46 | norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2)) 47 | scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3 48 | 49 | # --> batch_size x batch_size 50 | scores0 = scores0.squeeze() 51 | if class_ids is not None: 52 | scores0.data.masked_fill_(masks, -float('inf')) 53 | scores1 = scores0.transpose(0, 1) 54 | if labels is not None: 55 | loss0 = nn.CrossEntropyLoss()(scores0, labels) 56 | loss1 = nn.CrossEntropyLoss()(scores1, labels) 57 | else: 58 | loss0, loss1 = None, None 59 | return loss0, loss1 60 | 61 | 62 | def words_loss(img_features, words_emb, labels, 63 | cap_lens, class_ids, batch_size): 64 | """ 65 | words_emb(query): batch x nef x seq_len 66 | img_features(context): batch x nef x 17 x 17 67 | """ 68 | masks = [] 69 | att_maps = [] 70 | similarities = [] 71 | cap_lens = cap_lens.data.tolist() 72 
| for i in range(batch_size): 73 | if class_ids is not None: 74 | mask = (class_ids == class_ids[i]).astype(np.uint8) 75 | mask[i] = 0 76 | masks.append(mask.reshape((1, -1))) 77 | # Get the i-th text description 78 | words_num = cap_lens[i] 79 | # -> 1 x nef x words_num 80 | word = words_emb[i, :, :words_num].unsqueeze(0).contiguous() 81 | # -> batch_size x nef x words_num 82 | word = word.repeat(batch_size, 1, 1) 83 | # batch x nef x 17*17 84 | context = img_features 85 | """ 86 | word(query): batch x nef x words_num 87 | context: batch x nef x 17 x 17 88 | weiContext: batch x nef x words_num 89 | attn: batch x words_num x 17 x 17 90 | """ 91 | weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1) 92 | att_maps.append(attn[i].unsqueeze(0).contiguous()) 93 | # --> batch_size x words_num x nef 94 | word = word.transpose(1, 2).contiguous() 95 | weiContext = weiContext.transpose(1, 2).contiguous() 96 | # --> batch_size*words_num x nef 97 | word = word.view(batch_size * words_num, -1) 98 | weiContext = weiContext.view(batch_size * words_num, -1) 99 | # 100 | # -->batch_size*words_num 101 | row_sim = cosine_similarity(word, weiContext) 102 | # --> batch_size x words_num 103 | row_sim = row_sim.view(batch_size, words_num) 104 | 105 | # Eq. (10) 106 | row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_() 107 | row_sim = row_sim.sum(dim=1, keepdim=True) 108 | row_sim = torch.log(row_sim) 109 | 110 | # --> 1 x batch_size 111 | # similarities(i, j): the similarity between the i-th image and the j-th text description 112 | similarities.append(row_sim) 113 | 114 | # batch_size x batch_size 115 | similarities = torch.cat(similarities, 1) 116 | if class_ids is not None: 117 | masks = np.concatenate(masks, 0) 118 | # masks: batch_size x batch_size 119 | masks = torch.ByteTensor(masks) 120 | if cfg.CUDA: 121 | masks = masks.cuda() 122 | 123 | similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 124 | if class_ids is not None: 125 | similarities.data.masked_fill_(masks, -float('inf')) 126 | similarities1 = similarities.transpose(0, 1) 127 | if labels is not None: 128 | loss0 = nn.CrossEntropyLoss()(similarities, labels) 129 | loss1 = nn.CrossEntropyLoss()(similarities1, labels) 130 | else: 131 | loss0, loss1 = None, None 132 | return loss0, loss1, att_maps 133 | 134 | 135 | # ##################Loss for G and Ds############################## 136 | def discriminator_loss(netD, real_imgs, fake_imgs, conditions, 137 | real_labels, fake_labels): 138 | # Forward 139 | real_features = netD(real_imgs) 140 | fake_features = netD(fake_imgs.detach()) 141 | # loss 142 | # 143 | cond_real_logits = netD.COND_DNET(real_features, conditions) 144 | cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels) 145 | cond_fake_logits = netD.COND_DNET(fake_features, conditions) 146 | cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels) 147 | # 148 | batch_size = real_features.size(0) 149 | cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size]) 150 | cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size]) 151 | 152 | if netD.UNCOND_DNET is not None: 153 | real_logits = netD.UNCOND_DNET(real_features) 154 | fake_logits = netD.UNCOND_DNET(fake_features) 155 | real_errD = nn.BCELoss()(real_logits, real_labels) 156 | fake_errD = nn.BCELoss()(fake_logits, fake_labels) 157 | errD = ((real_errD + cond_real_errD) / 2. + 158 | (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.) 
159 | else: 160 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2. 161 | return errD 162 | 163 | 164 | def generator_loss(netsD, image_encoder, fake_imgs, real_labels, 165 | words_embs, sent_emb, match_labels, 166 | cap_lens, class_ids): 167 | numDs = len(netsD) 168 | batch_size = real_labels.size(0) 169 | logs = '' 170 | # Forward 171 | errG_total = 0 172 | for i in range(numDs): 173 | features = netsD[i](fake_imgs[i]) 174 | cond_logits = netsD[i].COND_DNET(features, sent_emb) 175 | cond_errG = nn.BCELoss()(cond_logits, real_labels) 176 | if netsD[i].UNCOND_DNET is not None: 177 | logits = netsD[i].UNCOND_DNET(features) 178 | errG = nn.BCELoss()(logits, real_labels) 179 | g_loss = errG + cond_errG 180 | else: 181 | g_loss = cond_errG 182 | errG_total += g_loss 183 | # err_img = errG_total.data[0] 184 | logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0]) 185 | 186 | # Ranking loss 187 | if i == (numDs - 1): 188 | # words_features: batch_size x nef x 17 x 17 189 | # sent_code: batch_size x nef 190 | region_features, cnn_code = image_encoder(fake_imgs[i]) 191 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs, 192 | match_labels, cap_lens, 193 | class_ids, batch_size) 194 | w_loss = (w_loss0 + w_loss1) * \ 195 | cfg.TRAIN.SMOOTH.LAMBDA 196 | # err_words = err_words + w_loss.data[0] 197 | 198 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, 199 | match_labels, class_ids, batch_size) 200 | s_loss = (s_loss0 + s_loss1) * \ 201 | cfg.TRAIN.SMOOTH.LAMBDA 202 | # err_sent = err_sent + s_loss.data[0] 203 | 204 | errG_total += w_loss + s_loss 205 | logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0]) 206 | return errG_total, logs 207 | 208 | 209 | ################################################################## 210 | def KL_loss(mu, logvar): 211 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 212 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 213 | KLD = torch.mean(KLD_element).mul_(-0.5) 214 | return KLD 215 | -------------------------------------------------------------------------------- /code/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from torch.nn import init 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | from copy import deepcopy 11 | import skimage.transform 12 | 13 | from miscc.config import cfg 14 | 15 | 16 | # For visualization ################################################ 17 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 18 | 2:[70, 70, 70], 3:[102,102,156], 19 | 4:[190,153,153], 5:[153,153,153], 20 | 6:[250,170, 30], 7:[220, 220, 0], 21 | 8:[107,142, 35], 9:[152,251,152], 22 | 10:[70,130,180], 11:[220,20, 60], 23 | 12:[255, 0, 0], 13:[0, 0, 142], 24 | 14:[119,11, 32], 15:[0, 60,100], 25 | 16:[0, 80, 100], 17:[0, 0, 230], 26 | 18:[0, 0, 70], 19:[0, 0, 0]} 27 | FONT_MAX = 50 28 | 29 | 30 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): 31 | num = captions.size(0) 32 | img_txt = Image.fromarray(convas) 33 | # get a font 34 | # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 35 | fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 36 | # get a drawing context 37 | d = ImageDraw.Draw(img_txt) 38 | sentence_list = [] 39 | for i in range(num): 40 | cap = captions[i].data.cpu().numpy() 41 | sentence = [] 42 | for j in range(len(cap)): 43 | if cap[j] == 0: 44 | break 45 | word = 
ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') 46 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), 47 | font=fnt, fill=(255, 255, 255, 255)) 48 | sentence.append(word) 49 | sentence_list.append(sentence) 50 | return img_txt, sentence_list 51 | 52 | 53 | def build_super_images(real_imgs, captions, ixtoword, 54 | attn_maps, att_sze, lr_imgs=None, 55 | batch_size=cfg.TRAIN.BATCH_SIZE, 56 | max_word_num=cfg.TEXT.WORDS_NUM): 57 | nvis = 8 58 | real_imgs = real_imgs[:nvis] 59 | if lr_imgs is not None: 60 | lr_imgs = lr_imgs[:nvis] 61 | if att_sze == 17: 62 | vis_size = att_sze * 16 63 | else: 64 | vis_size = real_imgs.size(2) 65 | 66 | text_convas = \ 67 | np.ones([batch_size * FONT_MAX, 68 | (max_word_num + 2) * (vis_size + 2), 3], 69 | dtype=np.uint8) 70 | 71 | for i in range(max_word_num): 72 | istart = (i + 2) * (vis_size + 2) 73 | iend = (i + 3) * (vis_size + 2) 74 | text_convas[:, istart:iend, :] = COLOR_DIC[i] 75 | 76 | 77 | real_imgs = \ 78 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 79 | # [-1, 1] --> [0, 1] 80 | real_imgs.add_(1).div_(2).mul_(255) 81 | real_imgs = real_imgs.data.numpy() 82 | # b x c x h x w --> b x h x w x c 83 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 84 | pad_sze = real_imgs.shape 85 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 86 | post_pad = np.zeros([pad_sze[1], pad_sze[2], 3]) 87 | if lr_imgs is not None: 88 | lr_imgs = \ 89 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs) 90 | # [-1, 1] --> [0, 1] 91 | lr_imgs.add_(1).div_(2).mul_(255) 92 | lr_imgs = lr_imgs.data.numpy() 93 | # b x c x h x w --> b x h x w x c 94 | lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1)) 95 | 96 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 97 | seq_len = max_word_num 98 | img_set = [] 99 | num = nvis # len(attn_maps) 100 | 101 | text_map, sentences = \ 102 | drawCaption(text_convas, captions, ixtoword, vis_size) 103 | text_map = np.asarray(text_map).astype(np.uint8) 104 | 105 | bUpdate = 1 106 | for i in range(num): 107 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 108 | # --> 1 x 1 x 17 x 17 109 | attn_max = attn.max(dim=1, keepdim=True) 110 | attn = torch.cat([attn_max[0], attn], 1) 111 | # 112 | attn = attn.view(-1, 1, att_sze, att_sze) 113 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 114 | # n x c x h x w --> n x h x w x c 115 | attn = np.transpose(attn, (0, 2, 3, 1)) 116 | num_attn = attn.shape[0] 117 | # 118 | img = real_imgs[i] 119 | if lr_imgs is None: 120 | lrI = img 121 | else: 122 | lrI = lr_imgs[i] 123 | row = [lrI, middle_pad] 124 | row_merge = [img, middle_pad] 125 | row_beforeNorm = [] 126 | minVglobal, maxVglobal = 1, 0 127 | for j in range(num_attn): 128 | one_map = attn[j] 129 | if (vis_size // att_sze) > 1: 130 | one_map = \ 131 | skimage.transform.pyramid_expand(one_map, sigma=20, 132 | upscale=vis_size // att_sze) 133 | row_beforeNorm.append(one_map) 134 | minV = one_map.min() 135 | maxV = one_map.max() 136 | if minVglobal > minV: 137 | minVglobal = minV 138 | if maxVglobal < maxV: 139 | maxVglobal = maxV 140 | for j in range(seq_len + 1): 141 | if j < num_attn: 142 | one_map = row_beforeNorm[j] 143 | one_map = (one_map - minVglobal) / (maxVglobal - minVglobal) 144 | one_map *= 255 145 | # 146 | PIL_im = Image.fromarray(np.uint8(img)) 147 | PIL_att = Image.fromarray(np.uint8(one_map)) 148 | merged = \ 149 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 150 | mask = Image.new('L', (vis_size, vis_size), (210)) 151 | merged.paste(PIL_im, (0, 0)) 
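# The 'L'-mode mask (constant value 210) acts as a fixed alpha, so the attention
# map is blended on top while the underlying image stays faintly visible.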
152 | merged.paste(PIL_att, (0, 0), mask) 153 | merged = np.array(merged)[:, :, :3] 154 | else: 155 | one_map = post_pad 156 | merged = post_pad 157 | row.append(one_map) 158 | row.append(middle_pad) 159 | # 160 | row_merge.append(merged) 161 | row_merge.append(middle_pad) 162 | row = np.concatenate(row, 1) 163 | row_merge = np.concatenate(row_merge, 1) 164 | txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX] 165 | if txt.shape[1] != row.shape[1]: 166 | print('txt', txt.shape, 'row', row.shape) 167 | bUpdate = 0 168 | break 169 | row = np.concatenate([txt, row, row_merge], 0) 170 | img_set.append(row) 171 | if bUpdate: 172 | img_set = np.concatenate(img_set, 0) 173 | img_set = img_set.astype(np.uint8) 174 | return img_set, sentences 175 | else: 176 | return None 177 | 178 | 179 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword, 180 | attn_maps, att_sze, vis_size=256, topK=5): 181 | batch_size = real_imgs.size(0) 182 | max_word_num = np.max(cap_lens) 183 | text_convas = np.ones([batch_size * FONT_MAX, 184 | max_word_num * (vis_size + 2), 3], 185 | dtype=np.uint8) 186 | 187 | real_imgs = \ 188 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 189 | # [-1, 1] --> [0, 1] 190 | real_imgs.add_(1).div_(2).mul_(255) 191 | real_imgs = real_imgs.data.numpy() 192 | # b x c x h x w --> b x h x w x c 193 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 194 | pad_sze = real_imgs.shape 195 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 196 | 197 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 198 | img_set = [] 199 | num = len(attn_maps) 200 | 201 | text_map, sentences = \ 202 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) 203 | text_map = np.asarray(text_map).astype(np.uint8) 204 | 205 | bUpdate = 1 206 | for i in range(num): 207 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 208 | # 209 | attn = attn.view(-1, 1, att_sze, att_sze) 210 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 211 | # n x c x h x w --> n x h x w x c 212 | attn = np.transpose(attn, (0, 2, 3, 1)) 213 | num_attn = cap_lens[i] 214 | thresh = 2./float(num_attn) 215 | # 216 | img = real_imgs[i] 217 | row = [] 218 | row_merge = [] 219 | row_txt = [] 220 | row_beforeNorm = [] 221 | conf_score = [] 222 | for j in range(num_attn): 223 | one_map = attn[j] 224 | mask0 = one_map > (2. 
* thresh) 225 | conf_score.append(np.sum(one_map * mask0)) 226 | mask = one_map > thresh 227 | one_map = one_map * mask 228 | if (vis_size // att_sze) > 1: 229 | one_map = \ 230 | skimage.transform.pyramid_expand(one_map, sigma=20, 231 | upscale=vis_size // att_sze) 232 | minV = one_map.min() 233 | maxV = one_map.max() 234 | one_map = (one_map - minV) / (maxV - minV) 235 | row_beforeNorm.append(one_map) 236 | sorted_indices = np.argsort(conf_score)[::-1] 237 | 238 | for j in range(num_attn): 239 | one_map = row_beforeNorm[j] 240 | one_map *= 255 241 | # 242 | PIL_im = Image.fromarray(np.uint8(img)) 243 | PIL_att = Image.fromarray(np.uint8(one_map)) 244 | merged = \ 245 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 246 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210) 247 | merged.paste(PIL_im, (0, 0)) 248 | merged.paste(PIL_att, (0, 0), mask) 249 | merged = np.array(merged)[:, :, :3] 250 | 251 | row.append(np.concatenate([one_map, middle_pad], 1)) 252 | # 253 | row_merge.append(np.concatenate([merged, middle_pad], 1)) 254 | # 255 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, 256 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :] 257 | row_txt.append(txt) 258 | # reorder 259 | row_new = [] 260 | row_merge_new = [] 261 | txt_new = [] 262 | for j in range(num_attn): 263 | idx = sorted_indices[j] 264 | row_new.append(row[idx]) 265 | row_merge_new.append(row_merge[idx]) 266 | txt_new.append(row_txt[idx]) 267 | row = np.concatenate(row_new[:topK], 1) 268 | row_merge = np.concatenate(row_merge_new[:topK], 1) 269 | txt = np.concatenate(txt_new[:topK], 1) 270 | if txt.shape[1] != row.shape[1]: 271 | print('Warnings: txt', txt.shape, 'row', row.shape, 272 | 'row_merge_new', row_merge_new.shape) 273 | bUpdate = 0 274 | break 275 | row = np.concatenate([txt, row_merge], 0) 276 | img_set.append(row) 277 | if bUpdate: 278 | img_set = np.concatenate(img_set, 0) 279 | img_set = img_set.astype(np.uint8) 280 | return img_set, sentences 281 | else: 282 | return None 283 | 284 | 285 | #################################################################### 286 | def weights_init(m): 287 | classname = m.__class__.__name__ 288 | if classname.find('Conv') != -1: 289 | nn.init.orthogonal(m.weight.data, 1.0) 290 | elif classname.find('BatchNorm') != -1: 291 | m.weight.data.normal_(1.0, 0.02) 292 | m.bias.data.fill_(0) 293 | elif classname.find('Linear') != -1: 294 | nn.init.orthogonal(m.weight.data, 1.0) 295 | if m.bias is not None: 296 | m.bias.data.fill_(0.0) 297 | 298 | 299 | def load_params(model, new_param): 300 | for p, new_p in zip(model.parameters(), new_param): 301 | p.data.copy_(new_p) 302 | 303 | 304 | def copy_G_params(model): 305 | flatten = deepcopy(list(p.data for p in model.parameters())) 306 | return flatten 307 | 308 | 309 | def mkdir_p(path): 310 | try: 311 | os.makedirs(path) 312 | except OSError as exc: # Python >2.5 313 | if exc.errno == errno.EEXIST and os.path.isdir(path): 314 | pass 315 | else: 316 | raise 317 | -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from torch.autograd import Variable 5 | from torchvision import models 6 | import torch.utils.model_zoo as model_zoo 7 | import torch.nn.functional as F 8 | 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | 11 | from miscc.config import cfg 12 | from GlobalAttention import 
GlobalAttentionGeneral as ATT_NET 13 | 14 | 15 | class GLU(nn.Module): 16 | def __init__(self): 17 | super(GLU, self).__init__() 18 | 19 | def forward(self, x): 20 | nc = x.size(1) 21 | assert nc % 2 == 0, 'channels dont divide 2!' 22 | nc = int(nc/2) 23 | return x[:, :nc] * F.sigmoid(x[:, nc:]) 24 | 25 | 26 | def conv1x1(in_planes, out_planes, bias=False): 27 | "1x1 convolution with padding" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 29 | padding=0, bias=bias) 30 | 31 | 32 | def conv3x3(in_planes, out_planes): 33 | "3x3 convolution with padding" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 35 | padding=1, bias=False) 36 | 37 | 38 | # Upsale the spatial size by a factor of 2 39 | def upBlock(in_planes, out_planes): 40 | block = nn.Sequential( 41 | nn.Upsample(scale_factor=2, mode='nearest'), 42 | conv3x3(in_planes, out_planes * 2), 43 | nn.BatchNorm2d(out_planes * 2), 44 | GLU()) 45 | return block 46 | 47 | 48 | # Keep the spatial size 49 | def Block3x3_relu(in_planes, out_planes): 50 | block = nn.Sequential( 51 | conv3x3(in_planes, out_planes * 2), 52 | nn.BatchNorm2d(out_planes * 2), 53 | GLU()) 54 | return block 55 | 56 | 57 | class ResBlock(nn.Module): 58 | def __init__(self, channel_num): 59 | super(ResBlock, self).__init__() 60 | self.block = nn.Sequential( 61 | conv3x3(channel_num, channel_num * 2), 62 | nn.BatchNorm2d(channel_num * 2), 63 | GLU(), 64 | conv3x3(channel_num, channel_num), 65 | nn.BatchNorm2d(channel_num)) 66 | 67 | def forward(self, x): 68 | residual = x 69 | out = self.block(x) 70 | out += residual 71 | return out 72 | 73 | 74 | # ############## Text2Image Encoder-Decoder ####### 75 | class RNN_ENCODER(nn.Module): 76 | def __init__(self, ntoken, ninput=300, drop_prob=0.5, 77 | nhidden=128, nlayers=1, bidirectional=True): 78 | super(RNN_ENCODER, self).__init__() 79 | self.n_steps = cfg.TEXT.WORDS_NUM 80 | self.ntoken = ntoken # size of the dictionary 81 | self.ninput = ninput # size of each embedding vector 82 | self.drop_prob = drop_prob # probability of an element to be zeroed 83 | self.nlayers = nlayers # Number of recurrent layers 84 | self.bidirectional = bidirectional 85 | self.rnn_type = cfg.RNN_TYPE 86 | if bidirectional: 87 | self.num_directions = 2 88 | else: 89 | self.num_directions = 1 90 | # number of features in the hidden state 91 | self.nhidden = nhidden // self.num_directions 92 | 93 | self.define_module() 94 | self.init_weights() 95 | 96 | def define_module(self): 97 | self.encoder = nn.Embedding(self.ntoken, self.ninput) 98 | self.drop = nn.Dropout(self.drop_prob) 99 | if self.rnn_type == 'LSTM': 100 | # dropout: If non-zero, introduces a dropout layer on 101 | # the outputs of each RNN layer except the last layer 102 | self.rnn = nn.LSTM(self.ninput, self.nhidden, 103 | self.nlayers, batch_first=True, 104 | dropout=self.drop_prob, 105 | bidirectional=self.bidirectional) 106 | elif self.rnn_type == 'GRU': 107 | self.rnn = nn.GRU(self.ninput, self.nhidden, 108 | self.nlayers, batch_first=True, 109 | dropout=self.drop_prob, 110 | bidirectional=self.bidirectional) 111 | else: 112 | raise NotImplementedError 113 | 114 | def init_weights(self): 115 | initrange = 0.1 116 | self.encoder.weight.data.uniform_(-initrange, initrange) 117 | # Do not need to initialize RNN parameters, which have been initialized 118 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM 119 | # self.decoder.weight.data.uniform_(-initrange, initrange) 120 | # self.decoder.bias.data.fill_(0) 121 | 122 | def 
init_hidden(self, bsz): 123 | weight = next(self.parameters()).data 124 | if self.rnn_type == 'LSTM': 125 | return (Variable(weight.new(self.nlayers * self.num_directions, 126 | bsz, self.nhidden).zero_()), 127 | Variable(weight.new(self.nlayers * self.num_directions, 128 | bsz, self.nhidden).zero_())) 129 | else: 130 | return Variable(weight.new(self.nlayers * self.num_directions, 131 | bsz, self.nhidden).zero_()) 132 | 133 | def forward(self, captions, cap_lens, hidden, mask=None): 134 | # input: torch.LongTensor of size batch x n_steps 135 | # --> emb: batch x n_steps x ninput 136 | emb = self.drop(self.encoder(captions)) 137 | # 138 | # Returns: a PackedSequence object 139 | cap_lens = cap_lens.data.tolist() 140 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True) 141 | # #hidden and memory (num_layers * num_directions, batch, hidden_size): 142 | # tensor containing the initial hidden state for each element in batch. 143 | # #output (batch, seq_len, hidden_size * num_directions) 144 | # #or a PackedSequence object: 145 | # tensor containing output features (h_t) from the last layer of RNN 146 | output, hidden = self.rnn(emb, hidden) 147 | # PackedSequence object 148 | # --> (batch, seq_len, hidden_size * num_directions) 149 | output = pad_packed_sequence(output, batch_first=True)[0] 150 | # output = self.drop(output) 151 | # --> batch x hidden_size*num_directions x seq_len 152 | words_emb = output.transpose(1, 2) 153 | # --> batch x num_directions*hidden_size 154 | if self.rnn_type == 'LSTM': 155 | sent_emb = hidden[0].transpose(0, 1).contiguous() 156 | else: 157 | sent_emb = hidden.transpose(0, 1).contiguous() 158 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) 159 | return words_emb, sent_emb 160 | 161 | 162 | class CNN_ENCODER(nn.Module): 163 | def __init__(self, nef): 164 | super(CNN_ENCODER, self).__init__() 165 | if cfg.TRAIN.FLAG: 166 | self.nef = nef 167 | else: 168 | self.nef = 256 # define a uniform ranker 169 | 170 | model = models.inception_v3() 171 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' 172 | model.load_state_dict(model_zoo.load_url(url)) 173 | for param in model.parameters(): 174 | param.requires_grad = False 175 | print('Load pretrained model from ', url) 176 | # print(model) 177 | 178 | self.define_module(model) 179 | self.init_trainable_weights() 180 | 181 | def define_module(self, model): 182 | self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3 183 | self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3 184 | self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3 185 | self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1 186 | self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3 187 | self.Mixed_5b = model.Mixed_5b 188 | self.Mixed_5c = model.Mixed_5c 189 | self.Mixed_5d = model.Mixed_5d 190 | self.Mixed_6a = model.Mixed_6a 191 | self.Mixed_6b = model.Mixed_6b 192 | self.Mixed_6c = model.Mixed_6c 193 | self.Mixed_6d = model.Mixed_6d 194 | self.Mixed_6e = model.Mixed_6e 195 | self.Mixed_7a = model.Mixed_7a 196 | self.Mixed_7b = model.Mixed_7b 197 | self.Mixed_7c = model.Mixed_7c 198 | 199 | self.emb_features = conv1x1(768, self.nef) 200 | self.emb_cnn_code = nn.Linear(2048, self.nef) 201 | 202 | def init_trainable_weights(self): 203 | initrange = 0.1 204 | self.emb_features.weight.data.uniform_(-initrange, initrange) 205 | self.emb_cnn_code.weight.data.uniform_(-initrange, initrange) 206 | 207 | def forward(self, x): 208 | features = None 209 | # --> fixed-size input: batch x 3 x 299 x 299 210 | x = nn.Upsample(size=(299, 299), mode='bilinear')(x) 211 | # 299 x 
299 x 3 212 | x = self.Conv2d_1a_3x3(x) 213 | # 149 x 149 x 32 214 | x = self.Conv2d_2a_3x3(x) 215 | # 147 x 147 x 32 216 | x = self.Conv2d_2b_3x3(x) 217 | # 147 x 147 x 64 218 | x = F.max_pool2d(x, kernel_size=3, stride=2) 219 | # 73 x 73 x 64 220 | x = self.Conv2d_3b_1x1(x) 221 | # 73 x 73 x 80 222 | x = self.Conv2d_4a_3x3(x) 223 | # 71 x 71 x 192 224 | 225 | x = F.max_pool2d(x, kernel_size=3, stride=2) 226 | # 35 x 35 x 192 227 | x = self.Mixed_5b(x) 228 | # 35 x 35 x 256 229 | x = self.Mixed_5c(x) 230 | # 35 x 35 x 288 231 | x = self.Mixed_5d(x) 232 | # 35 x 35 x 288 233 | 234 | x = self.Mixed_6a(x) 235 | # 17 x 17 x 768 236 | x = self.Mixed_6b(x) 237 | # 17 x 17 x 768 238 | x = self.Mixed_6c(x) 239 | # 17 x 17 x 768 240 | x = self.Mixed_6d(x) 241 | # 17 x 17 x 768 242 | x = self.Mixed_6e(x) 243 | # 17 x 17 x 768 244 | 245 | # image region features 246 | features = x 247 | # 17 x 17 x 768 248 | 249 | x = self.Mixed_7a(x) 250 | # 8 x 8 x 1280 251 | x = self.Mixed_7b(x) 252 | # 8 x 8 x 2048 253 | x = self.Mixed_7c(x) 254 | # 8 x 8 x 2048 255 | x = F.avg_pool2d(x, kernel_size=8) 256 | # 1 x 1 x 2048 257 | # x = F.dropout(x, training=self.training) 258 | # 1 x 1 x 2048 259 | x = x.view(x.size(0), -1) 260 | # 2048 261 | 262 | # global image features 263 | cnn_code = self.emb_cnn_code(x) 264 | # 512 265 | if features is not None: 266 | features = self.emb_features(features) 267 | return features, cnn_code 268 | 269 | 270 | # ############## G networks ################### 271 | class CA_NET(nn.Module): 272 | # some code is modified from vae examples 273 | # (https://github.com/pytorch/examples/blob/master/vae/main.py) 274 | def __init__(self): 275 | super(CA_NET, self).__init__() 276 | self.t_dim = cfg.TEXT.EMBEDDING_DIM 277 | self.c_dim = cfg.GAN.CONDITION_DIM 278 | self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True) 279 | self.relu = GLU() 280 | 281 | def encode(self, text_embedding): 282 | x = self.relu(self.fc(text_embedding)) 283 | mu = x[:, :self.c_dim] 284 | logvar = x[:, self.c_dim:] 285 | return mu, logvar 286 | 287 | def reparametrize(self, mu, logvar): 288 | std = logvar.mul(0.5).exp_() 289 | if cfg.CUDA: 290 | eps = torch.cuda.FloatTensor(std.size()).normal_() 291 | else: 292 | eps = torch.FloatTensor(std.size()).normal_() 293 | eps = Variable(eps) 294 | return eps.mul(std).add_(mu) 295 | 296 | def forward(self, text_embedding): 297 | mu, logvar = self.encode(text_embedding) 298 | c_code = self.reparametrize(mu, logvar) 299 | return c_code, mu, logvar 300 | 301 | 302 | class INIT_STAGE_G(nn.Module): 303 | def __init__(self, ngf, ncf): 304 | super(INIT_STAGE_G, self).__init__() 305 | self.gf_dim = ngf 306 | self.in_dim = cfg.GAN.Z_DIM + ncf # cfg.TEXT.EMBEDDING_DIM 307 | 308 | self.define_module() 309 | 310 | def define_module(self): 311 | nz, ngf = self.in_dim, self.gf_dim 312 | self.fc = nn.Sequential( 313 | nn.Linear(nz, ngf * 4 * 4 * 2, bias=False), 314 | nn.BatchNorm1d(ngf * 4 * 4 * 2), 315 | GLU()) 316 | 317 | self.upsample1 = upBlock(ngf, ngf // 2) 318 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 319 | self.upsample3 = upBlock(ngf // 4, ngf // 8) 320 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 321 | 322 | def forward(self, z_code, c_code): 323 | """ 324 | :param z_code: batch x cfg.GAN.Z_DIM 325 | :param c_code: batch x cfg.TEXT.EMBEDDING_DIM 326 | :return: batch x ngf/16 x 64 x 64 327 | """ 328 | c_z_code = torch.cat((c_code, z_code), 1) 329 | # state size ngf x 4 x 4 330 | out_code = self.fc(c_z_code) 331 | out_code = out_code.view(-1, self.gf_dim, 4, 4) 332 | 
# state size ngf/3 x 8 x 8 333 | out_code = self.upsample1(out_code) 334 | # state size ngf/4 x 16 x 16 335 | out_code = self.upsample2(out_code) 336 | # state size ngf/8 x 32 x 32 337 | out_code32 = self.upsample3(out_code) 338 | # state size ngf/16 x 64 x 64 339 | out_code64 = self.upsample4(out_code32) 340 | 341 | return out_code64 342 | 343 | 344 | class NEXT_STAGE_G(nn.Module): 345 | def __init__(self, ngf, nef, ncf): 346 | super(NEXT_STAGE_G, self).__init__() 347 | self.gf_dim = ngf 348 | self.ef_dim = nef 349 | self.cf_dim = ncf 350 | self.num_residual = cfg.GAN.R_NUM 351 | self.define_module() 352 | 353 | def _make_layer(self, block, channel_num): 354 | layers = [] 355 | for i in range(cfg.GAN.R_NUM): 356 | layers.append(block(channel_num)) 357 | return nn.Sequential(*layers) 358 | 359 | def define_module(self): 360 | ngf = self.gf_dim 361 | self.att = ATT_NET(ngf, self.ef_dim) 362 | self.residual = self._make_layer(ResBlock, ngf * 2) 363 | self.upsample = upBlock(ngf * 2, ngf) 364 | 365 | def forward(self, h_code, c_code, word_embs, mask): 366 | """ 367 | h_code1(query): batch x idf x ih x iw (queryL=ihxiw) 368 | word_embs(context): batch x cdf x sourceL (sourceL=seq_len) 369 | c_code1: batch x idf x queryL 370 | att1: batch x sourceL x queryL 371 | """ 372 | self.att.applyMask(mask) 373 | c_code, att = self.att(h_code, word_embs) 374 | h_c_code = torch.cat((h_code, c_code), 1) 375 | out_code = self.residual(h_c_code) 376 | 377 | # state size ngf/2 x 2in_size x 2in_size 378 | out_code = self.upsample(out_code) 379 | 380 | return out_code, att 381 | 382 | 383 | class GET_IMAGE_G(nn.Module): 384 | def __init__(self, ngf): 385 | super(GET_IMAGE_G, self).__init__() 386 | self.gf_dim = ngf 387 | self.img = nn.Sequential( 388 | conv3x3(ngf, 3), 389 | nn.Tanh() 390 | ) 391 | 392 | def forward(self, h_code): 393 | out_img = self.img(h_code) 394 | return out_img 395 | 396 | 397 | class G_NET(nn.Module): 398 | def __init__(self): 399 | super(G_NET, self).__init__() 400 | ngf = cfg.GAN.GF_DIM 401 | nef = cfg.TEXT.EMBEDDING_DIM 402 | ncf = cfg.GAN.CONDITION_DIM 403 | self.ca_net = CA_NET() 404 | 405 | if cfg.TREE.BRANCH_NUM > 0: 406 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 407 | self.img_net1 = GET_IMAGE_G(ngf) 408 | # gf x 64 x 64 409 | if cfg.TREE.BRANCH_NUM > 1: 410 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) 411 | self.img_net2 = GET_IMAGE_G(ngf) 412 | if cfg.TREE.BRANCH_NUM > 2: 413 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) 414 | self.img_net3 = GET_IMAGE_G(ngf) 415 | 416 | def forward(self, z_code, sent_emb, word_embs, mask): 417 | """ 418 | :param z_code: batch x cfg.GAN.Z_DIM 419 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 420 | :param word_embs: batch x cdf x seq_len 421 | :param mask: batch x seq_len 422 | :return: 423 | """ 424 | fake_imgs = [] 425 | att_maps = [] 426 | c_code, mu, logvar = self.ca_net(sent_emb) 427 | 428 | if cfg.TREE.BRANCH_NUM > 0: 429 | h_code1 = self.h_net1(z_code, c_code) 430 | fake_img1 = self.img_net1(h_code1) 431 | fake_imgs.append(fake_img1) 432 | if cfg.TREE.BRANCH_NUM > 1: 433 | h_code2, att1 = \ 434 | self.h_net2(h_code1, c_code, word_embs, mask) 435 | fake_img2 = self.img_net2(h_code2) 436 | fake_imgs.append(fake_img2) 437 | if att1 is not None: 438 | att_maps.append(att1) 439 | if cfg.TREE.BRANCH_NUM > 2: 440 | h_code3, att2 = \ 441 | self.h_net3(h_code2, c_code, word_embs, mask) 442 | fake_img3 = self.img_net3(h_code3) 443 | fake_imgs.append(fake_img3) 444 | if att2 is not None: 445 | att_maps.append(att2) 446 | 447 | return fake_imgs, 
att_maps, mu, logvar 448 | 449 | 450 | 451 | class G_DCGAN(nn.Module): 452 | def __init__(self): 453 | super(G_DCGAN, self).__init__() 454 | ngf = cfg.GAN.GF_DIM 455 | nef = cfg.TEXT.EMBEDDING_DIM 456 | ncf = cfg.GAN.CONDITION_DIM 457 | self.ca_net = CA_NET() 458 | 459 | # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64 460 | if cfg.TREE.BRANCH_NUM > 0: 461 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 462 | # gf x 64 x 64 463 | if cfg.TREE.BRANCH_NUM > 1: 464 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) 465 | if cfg.TREE.BRANCH_NUM > 2: 466 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) 467 | self.img_net = GET_IMAGE_G(ngf) 468 | 469 | def forward(self, z_code, sent_emb, word_embs, mask): 470 | """ 471 | :param z_code: batch x cfg.GAN.Z_DIM 472 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 473 | :param word_embs: batch x cdf x seq_len 474 | :param mask: batch x seq_len 475 | :return: 476 | """ 477 | att_maps = [] 478 | c_code, mu, logvar = self.ca_net(sent_emb) 479 | if cfg.TREE.BRANCH_NUM > 0: 480 | h_code = self.h_net1(z_code, c_code) 481 | if cfg.TREE.BRANCH_NUM > 1: 482 | h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask) 483 | if att1 is not None: 484 | att_maps.append(att1) 485 | if cfg.TREE.BRANCH_NUM > 2: 486 | h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask) 487 | if att2 is not None: 488 | att_maps.append(att2) 489 | 490 | fake_imgs = self.img_net(h_code) 491 | return [fake_imgs], att_maps, mu, logvar 492 | 493 | 494 | # ############## D networks ########################## 495 | def Block3x3_leakRelu(in_planes, out_planes): 496 | block = nn.Sequential( 497 | conv3x3(in_planes, out_planes), 498 | nn.BatchNorm2d(out_planes), 499 | nn.LeakyReLU(0.2, inplace=True) 500 | ) 501 | return block 502 | 503 | 504 | # Downsale the spatial size by a factor of 2 505 | def downBlock(in_planes, out_planes): 506 | block = nn.Sequential( 507 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False), 508 | nn.BatchNorm2d(out_planes), 509 | nn.LeakyReLU(0.2, inplace=True) 510 | ) 511 | return block 512 | 513 | 514 | # Downsale the spatial size by a factor of 16 515 | def encode_image_by_16times(ndf): 516 | encode_img = nn.Sequential( 517 | # --> state size. 
ndf x in_size/2 x in_size/2 518 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 519 | nn.LeakyReLU(0.2, inplace=True), 520 | # --> state size 2ndf x x in_size/4 x in_size/4 521 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 522 | nn.BatchNorm2d(ndf * 2), 523 | nn.LeakyReLU(0.2, inplace=True), 524 | # --> state size 4ndf x in_size/8 x in_size/8 525 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 526 | nn.BatchNorm2d(ndf * 4), 527 | nn.LeakyReLU(0.2, inplace=True), 528 | # --> state size 8ndf x in_size/16 x in_size/16 529 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 530 | nn.BatchNorm2d(ndf * 8), 531 | nn.LeakyReLU(0.2, inplace=True) 532 | ) 533 | return encode_img 534 | 535 | 536 | class D_GET_LOGITS(nn.Module): 537 | def __init__(self, ndf, nef, bcondition=False): 538 | super(D_GET_LOGITS, self).__init__() 539 | self.df_dim = ndf 540 | self.ef_dim = nef 541 | self.bcondition = bcondition 542 | if self.bcondition: 543 | self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8) 544 | 545 | self.outlogits = nn.Sequential( 546 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 547 | nn.Sigmoid()) 548 | 549 | def forward(self, h_code, c_code=None): 550 | if self.bcondition and c_code is not None: 551 | # conditioning output 552 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 553 | c_code = c_code.repeat(1, 1, 4, 4) 554 | # state size (ngf+egf) x 4 x 4 555 | h_c_code = torch.cat((h_code, c_code), 1) 556 | # state size ngf x in_size x in_size 557 | h_c_code = self.jointConv(h_c_code) 558 | else: 559 | h_c_code = h_code 560 | 561 | output = self.outlogits(h_c_code) 562 | return output.view(-1) 563 | 564 | 565 | # For 64 x 64 images 566 | class D_NET64(nn.Module): 567 | def __init__(self, b_jcu=True): 568 | super(D_NET64, self).__init__() 569 | ndf = cfg.GAN.DF_DIM 570 | nef = cfg.TEXT.EMBEDDING_DIM 571 | self.img_code_s16 = encode_image_by_16times(ndf) 572 | if b_jcu: 573 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 574 | else: 575 | self.UNCOND_DNET = None 576 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 577 | 578 | def forward(self, x_var): 579 | x_code4 = self.img_code_s16(x_var) # 4 x 4 x 8df 580 | return x_code4 581 | 582 | 583 | # For 128 x 128 images 584 | class D_NET128(nn.Module): 585 | def __init__(self, b_jcu=True): 586 | super(D_NET128, self).__init__() 587 | ndf = cfg.GAN.DF_DIM 588 | nef = cfg.TEXT.EMBEDDING_DIM 589 | self.img_code_s16 = encode_image_by_16times(ndf) 590 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 591 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) 592 | # 593 | if b_jcu: 594 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 595 | else: 596 | self.UNCOND_DNET = None 597 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 598 | 599 | def forward(self, x_var): 600 | x_code8 = self.img_code_s16(x_var) # 8 x 8 x 8df 601 | x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df 602 | x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df 603 | return x_code4 604 | 605 | 606 | # For 256 x 256 images 607 | class D_NET256(nn.Module): 608 | def __init__(self, b_jcu=True): 609 | super(D_NET256, self).__init__() 610 | ndf = cfg.GAN.DF_DIM 611 | nef = cfg.TEXT.EMBEDDING_DIM 612 | self.img_code_s16 = encode_image_by_16times(ndf) 613 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 614 | self.img_code_s64 = downBlock(ndf * 16, ndf * 32) 615 | self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16) 616 | self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8) 617 | if b_jcu: 618 | self.UNCOND_DNET = 
D_GET_LOGITS(ndf, nef, bcondition=False) 619 | else: 620 | self.UNCOND_DNET = None 621 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 622 | 623 | def forward(self, x_var): 624 | x_code16 = self.img_code_s16(x_var) 625 | x_code8 = self.img_code_s32(x_code16) 626 | x_code4 = self.img_code_s64(x_code8) 627 | x_code4 = self.img_code_s64_1(x_code4) 628 | x_code4 = self.img_code_s64_2(x_code4) 629 | return x_code4 630 | -------------------------------------------------------------------------------- /code/pretrain_DAMSM.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from miscc.utils import mkdir_p 4 | from miscc.utils import build_super_images 5 | from miscc.losses import sent_loss, words_loss 6 | from miscc.config import cfg, cfg_from_file 7 | 8 | from datasets import TextDataset 9 | from datasets import prepare_data 10 | 11 | from model import RNN_ENCODER, CNN_ENCODER 12 | 13 | import os 14 | import sys 15 | import time 16 | import random 17 | import pprint 18 | import datetime 19 | import dateutil.tz 20 | import argparse 21 | import numpy as np 22 | from PIL import Image 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.optim as optim 27 | from torch.autograd import Variable 28 | import torch.backends.cudnn as cudnn 29 | import torchvision.transforms as transforms 30 | 31 | 32 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 33 | sys.path.append(dir_path) 34 | 35 | 36 | UPDATE_INTERVAL = 200 37 | def parse_args(): 38 | parser = argparse.ArgumentParser(description='Train a DAMSM network') 39 | parser.add_argument('--cfg', dest='cfg_file', 40 | help='optional config file', 41 | default='cfg/DAMSM/bird.yml', type=str) 42 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=0) 43 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 44 | parser.add_argument('--manualSeed', type=int, help='manual seed') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def train(dataloader, cnn_model, rnn_model, batch_size, 50 | labels, optimizer, epoch, ixtoword, image_dir): 51 | cnn_model.train() 52 | rnn_model.train() 53 | s_total_loss0 = 0 54 | s_total_loss1 = 0 55 | w_total_loss0 = 0 56 | w_total_loss1 = 0 57 | count = (epoch + 1) * len(dataloader) 58 | start_time = time.time() 59 | for step, data in enumerate(dataloader, 0): 60 | # print('step', step) 61 | rnn_model.zero_grad() 62 | cnn_model.zero_grad() 63 | 64 | imgs, captions, cap_lens, \ 65 | class_ids, keys = prepare_data(data) 66 | 67 | 68 | # words_features: batch_size x nef x 17 x 17 69 | # sent_code: batch_size x nef 70 | words_features, sent_code = cnn_model(imgs[-1]) 71 | # --> batch_size x nef x 17*17 72 | nef, att_sze = words_features.size(1), words_features.size(2) 73 | # words_features = words_features.view(batch_size, nef, -1) 74 | 75 | hidden = rnn_model.init_hidden(batch_size) 76 | # words_emb: batch_size x nef x seq_len 77 | # sent_emb: batch_size x nef 78 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) 79 | 80 | w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, 81 | cap_lens, class_ids, batch_size) 82 | w_total_loss0 += w_loss0.data 83 | w_total_loss1 += w_loss1.data 84 | loss = w_loss0 + w_loss1 85 | 86 | s_loss0, s_loss1 = \ 87 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) 88 | loss += s_loss0 + s_loss1 89 | s_total_loss0 += s_loss0.data 90 | s_total_loss1 += s_loss1.data 91 | # 92 | 
loss.backward() 93 | # 94 | # `clip_grad_norm` helps prevent 95 | # the exploding gradient problem in RNNs / LSTMs. 96 | torch.nn.utils.clip_grad_norm(rnn_model.parameters(), 97 | cfg.TRAIN.RNN_GRAD_CLIP) 98 | optimizer.step() 99 | 100 | if step % UPDATE_INTERVAL == 0: 101 | count = epoch * len(dataloader) + step 102 | 103 | s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL 104 | s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL 105 | 106 | w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL 107 | w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL 108 | 109 | elapsed = time.time() - start_time 110 | print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | ' 111 | 's_loss {:5.2f} {:5.2f} | ' 112 | 'w_loss {:5.2f} {:5.2f}' 113 | .format(epoch, step, len(dataloader), 114 | elapsed * 1000. / UPDATE_INTERVAL, 115 | s_cur_loss0, s_cur_loss1, 116 | w_cur_loss0, w_cur_loss1)) 117 | s_total_loss0 = 0 118 | s_total_loss1 = 0 119 | w_total_loss0 = 0 120 | w_total_loss1 = 0 121 | start_time = time.time() 122 | # attention Maps 123 | img_set, _ = \ 124 | build_super_images(imgs[-1].cpu(), captions, 125 | ixtoword, attn_maps, att_sze) 126 | if img_set is not None: 127 | im = Image.fromarray(img_set) 128 | fullpath = '%s/attention_maps%d.png' % (image_dir, step) 129 | im.save(fullpath) 130 | return count 131 | 132 | 133 | def evaluate(dataloader, cnn_model, rnn_model, batch_size): 134 | cnn_model.eval() 135 | rnn_model.eval() 136 | s_total_loss = 0 137 | w_total_loss = 0 138 | for step, data in enumerate(dataloader, 0): 139 | real_imgs, captions, cap_lens, \ 140 | class_ids, keys = prepare_data(data) 141 | 142 | words_features, sent_code = cnn_model(real_imgs[-1]) 143 | # nef = words_features.size(1) 144 | # words_features = words_features.view(batch_size, nef, -1) 145 | 146 | hidden = rnn_model.init_hidden(batch_size) 147 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) 148 | 149 | w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, 150 | cap_lens, class_ids, batch_size) 151 | w_total_loss += (w_loss0 + w_loss1).data 152 | 153 | s_loss0, s_loss1 = \ 154 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) 155 | s_total_loss += (s_loss0 + s_loss1).data 156 | 157 | if step == 50: 158 | break 159 | 160 | s_cur_loss = s_total_loss[0] / step 161 | w_cur_loss = w_total_loss[0] / step 162 | 163 | return s_cur_loss, w_cur_loss 164 | 165 | 166 | def build_models(): 167 | # build model ############################################################ 168 | text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 169 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) 170 | labels = Variable(torch.LongTensor(range(batch_size))) 171 | start_epoch = 0 172 | if cfg.TRAIN.NET_E != '': 173 | state_dict = torch.load(cfg.TRAIN.NET_E) 174 | text_encoder.load_state_dict(state_dict) 175 | print('Load ', cfg.TRAIN.NET_E) 176 | # 177 | name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 178 | state_dict = torch.load(name) 179 | image_encoder.load_state_dict(state_dict) 180 | print('Load ', name) 181 | 182 | istart = cfg.TRAIN.NET_E.rfind('_') + 8 183 | iend = cfg.TRAIN.NET_E.rfind('.') 184 | start_epoch = cfg.TRAIN.NET_E[istart:iend] 185 | start_epoch = int(start_epoch) + 1 186 | print('start_epoch', start_epoch) 187 | if cfg.CUDA: 188 | text_encoder = text_encoder.cuda() 189 | image_encoder = image_encoder.cuda() 190 | labels = labels.cuda() 191 | 192 | return text_encoder, image_encoder, labels, start_epoch 193 | 194 | 195 | if __name__ == "__main__": 
196 | args = parse_args() 197 | if args.cfg_file is not None: 198 | cfg_from_file(args.cfg_file) 199 | 200 | if args.gpu_id == -1: 201 | cfg.CUDA = False 202 | else: 203 | cfg.GPU_ID = args.gpu_id 204 | 205 | if args.data_dir != '': 206 | cfg.DATA_DIR = args.data_dir 207 | print('Using config:') 208 | pprint.pprint(cfg) 209 | 210 | if not cfg.TRAIN.FLAG: 211 | args.manualSeed = 100 212 | elif args.manualSeed is None: 213 | args.manualSeed = random.randint(1, 10000) 214 | random.seed(args.manualSeed) 215 | np.random.seed(args.manualSeed) 216 | torch.manual_seed(args.manualSeed) 217 | if cfg.CUDA: 218 | torch.cuda.manual_seed_all(args.manualSeed) 219 | 220 | ########################################################################## 221 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 222 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 223 | output_dir = '../output/%s_%s_%s' % \ 224 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 225 | 226 | model_dir = os.path.join(output_dir, 'Model') 227 | image_dir = os.path.join(output_dir, 'Image') 228 | mkdir_p(model_dir) 229 | mkdir_p(image_dir) 230 | 231 | torch.cuda.set_device(cfg.GPU_ID) 232 | cudnn.benchmark = True 233 | 234 | # Get data loader ################################################## 235 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) 236 | batch_size = cfg.TRAIN.BATCH_SIZE 237 | image_transform = transforms.Compose([ 238 | transforms.Scale(int(imsize * 76 / 64)), 239 | transforms.RandomCrop(imsize), 240 | transforms.RandomHorizontalFlip()]) 241 | dataset = TextDataset(cfg.DATA_DIR, 'train', 242 | base_size=cfg.TREE.BASE_SIZE, 243 | transform=image_transform) 244 | 245 | print(dataset.n_words, dataset.embeddings_num) 246 | assert dataset 247 | dataloader = torch.utils.data.DataLoader( 248 | dataset, batch_size=batch_size, drop_last=True, 249 | shuffle=True, num_workers=int(cfg.WORKERS)) 250 | 251 | # # validation data # 252 | dataset_val = TextDataset(cfg.DATA_DIR, 'test', 253 | base_size=cfg.TREE.BASE_SIZE, 254 | transform=image_transform) 255 | dataloader_val = torch.utils.data.DataLoader( 256 | dataset_val, batch_size=batch_size, drop_last=True, 257 | shuffle=True, num_workers=int(cfg.WORKERS)) 258 | 259 | # Train ############################################################## 260 | text_encoder, image_encoder, labels, start_epoch = build_models() 261 | para = list(text_encoder.parameters()) 262 | for v in image_encoder.parameters(): 263 | if v.requires_grad: 264 | para.append(v) 265 | # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999)) 266 | # At any point you can hit Ctrl + C to break out of training early. 
267 | try: 268 | lr = cfg.TRAIN.ENCODER_LR 269 | for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH): 270 | optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999)) 271 | epoch_start_time = time.time() 272 | count = train(dataloader, image_encoder, text_encoder, 273 | batch_size, labels, optimizer, epoch, 274 | dataset.ixtoword, image_dir) 275 | print('-' * 89) 276 | if len(dataloader_val) > 0: 277 | s_loss, w_loss = evaluate(dataloader_val, image_encoder, 278 | text_encoder, batch_size) 279 | print('| end epoch {:3d} | valid loss ' 280 | '{:5.2f} {:5.2f} | lr {:.5f}|' 281 | .format(epoch, s_loss, w_loss, lr)) 282 | print('-' * 89) 283 | if lr > cfg.TRAIN.ENCODER_LR/10.: 284 | lr *= 0.98 285 | 286 | if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or 287 | epoch == cfg.TRAIN.MAX_EPOCH): 288 | torch.save(image_encoder.state_dict(), 289 | '%s/image_encoder%d.pth' % (model_dir, epoch)) 290 | torch.save(text_encoder.state_dict(), 291 | '%s/text_encoder%d.pth' % (model_dir, epoch)) 292 | print('Save G/Ds models.') 293 | except KeyboardInterrupt: 294 | print('-' * 89) 295 | print('Exiting from training early') 296 | -------------------------------------------------------------------------------- /code/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from six.moves import range 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | import torch.backends.cudnn as cudnn 9 | 10 | from PIL import Image 11 | 12 | from miscc.config import cfg 13 | from miscc.utils import mkdir_p 14 | from miscc.utils import build_super_images, build_super_images2 15 | from miscc.utils import weights_init, load_params, copy_G_params 16 | from model import G_DCGAN, G_NET 17 | from datasets import prepare_data 18 | from model import RNN_ENCODER, CNN_ENCODER 19 | 20 | from miscc.losses import words_loss 21 | from miscc.losses import discriminator_loss, generator_loss, KL_loss 22 | import os 23 | import time 24 | import numpy as np 25 | import sys 26 | 27 | # ################# Text to image task############################ # 28 | class condGANTrainer(object): 29 | def __init__(self, output_dir, data_loader, n_words, ixtoword): 30 | if cfg.TRAIN.FLAG: 31 | self.model_dir = os.path.join(output_dir, 'Model') 32 | self.image_dir = os.path.join(output_dir, 'Image') 33 | mkdir_p(self.model_dir) 34 | mkdir_p(self.image_dir) 35 | 36 | torch.cuda.set_device(cfg.GPU_ID) 37 | cudnn.benchmark = True 38 | 39 | self.batch_size = cfg.TRAIN.BATCH_SIZE 40 | self.max_epoch = cfg.TRAIN.MAX_EPOCH 41 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL 42 | 43 | self.n_words = n_words 44 | self.ixtoword = ixtoword 45 | self.data_loader = data_loader 46 | self.num_batches = len(self.data_loader) 47 | 48 | def build_models(self): 49 | # ###################encoders######################################## # 50 | if cfg.TRAIN.NET_E == '': 51 | print('Error: no pretrained text-image encoders') 52 | return 53 | 54 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) 55 | img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 56 | state_dict = \ 57 | torch.load(img_encoder_path, map_location=lambda storage, loc: storage) 58 | image_encoder.load_state_dict(state_dict) 59 | for p in image_encoder.parameters(): 60 | p.requires_grad = False 61 | print('Load image encoder from:', img_encoder_path) 62 | image_encoder.eval() 63 | 64 | text_encoder = \ 65 | RNN_ENCODER(self.n_words, 
nhidden=cfg.TEXT.EMBEDDING_DIM) 66 | state_dict = \ 67 | torch.load(cfg.TRAIN.NET_E, 68 | map_location=lambda storage, loc: storage) 69 | text_encoder.load_state_dict(state_dict) 70 | for p in text_encoder.parameters(): 71 | p.requires_grad = False 72 | print('Load text encoder from:', cfg.TRAIN.NET_E) 73 | text_encoder.eval() 74 | 75 | # #######################generator and discriminators############## # 76 | netsD = [] 77 | if cfg.GAN.B_DCGAN: 78 | if cfg.TREE.BRANCH_NUM ==1: 79 | from model import D_NET64 as D_NET 80 | elif cfg.TREE.BRANCH_NUM == 2: 81 | from model import D_NET128 as D_NET 82 | else: # cfg.TREE.BRANCH_NUM == 3: 83 | from model import D_NET256 as D_NET 84 | # TODO: elif cfg.TREE.BRANCH_NUM > 3: 85 | netG = G_DCGAN() 86 | netsD = [D_NET(b_jcu=False)] 87 | else: 88 | from model import D_NET64, D_NET128, D_NET256 89 | netG = G_NET() 90 | if cfg.TREE.BRANCH_NUM > 0: 91 | netsD.append(D_NET64()) 92 | if cfg.TREE.BRANCH_NUM > 1: 93 | netsD.append(D_NET128()) 94 | if cfg.TREE.BRANCH_NUM > 2: 95 | netsD.append(D_NET256()) 96 | # TODO: if cfg.TREE.BRANCH_NUM > 3: 97 | netG.apply(weights_init) 98 | # print(netG) 99 | for i in range(len(netsD)): 100 | netsD[i].apply(weights_init) 101 | # print(netsD[i]) 102 | print('# of netsD', len(netsD)) 103 | # 104 | epoch = 0 105 | if cfg.TRAIN.NET_G != '': 106 | state_dict = \ 107 | torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) 108 | netG.load_state_dict(state_dict) 109 | print('Load G from: ', cfg.TRAIN.NET_G) 110 | istart = cfg.TRAIN.NET_G.rfind('_') + 1 111 | iend = cfg.TRAIN.NET_G.rfind('.') 112 | epoch = cfg.TRAIN.NET_G[istart:iend] 113 | epoch = int(epoch) + 1 114 | if cfg.TRAIN.B_NET_D: 115 | Gname = cfg.TRAIN.NET_G 116 | for i in range(len(netsD)): 117 | s_tmp = Gname[:Gname.rfind('/')] 118 | Dname = '%s/netD%d.pth' % (s_tmp, i) 119 | print('Load D from: ', Dname) 120 | state_dict = \ 121 | torch.load(Dname, map_location=lambda storage, loc: storage) 122 | netsD[i].load_state_dict(state_dict) 123 | # ########################################################### # 124 | if cfg.CUDA: 125 | text_encoder = text_encoder.cuda() 126 | image_encoder = image_encoder.cuda() 127 | netG.cuda() 128 | for i in range(len(netsD)): 129 | netsD[i].cuda() 130 | return [text_encoder, image_encoder, netG, netsD, epoch] 131 | 132 | def define_optimizers(self, netG, netsD): 133 | optimizersD = [] 134 | num_Ds = len(netsD) 135 | for i in range(num_Ds): 136 | opt = optim.Adam(netsD[i].parameters(), 137 | lr=cfg.TRAIN.DISCRIMINATOR_LR, 138 | betas=(0.5, 0.999)) 139 | optimizersD.append(opt) 140 | 141 | optimizerG = optim.Adam(netG.parameters(), 142 | lr=cfg.TRAIN.GENERATOR_LR, 143 | betas=(0.5, 0.999)) 144 | 145 | return optimizerG, optimizersD 146 | 147 | def prepare_labels(self): 148 | batch_size = self.batch_size 149 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) 150 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) 151 | match_labels = Variable(torch.LongTensor(range(batch_size))) 152 | if cfg.CUDA: 153 | real_labels = real_labels.cuda() 154 | fake_labels = fake_labels.cuda() 155 | match_labels = match_labels.cuda() 156 | 157 | return real_labels, fake_labels, match_labels 158 | 159 | def save_model(self, netG, avg_param_G, netsD, epoch): 160 | backup_para = copy_G_params(netG) 161 | load_params(netG, avg_param_G) 162 | torch.save(netG.state_dict(), 163 | '%s/netG_epoch_%d.pth' % (self.model_dir, epoch)) 164 | load_params(netG, backup_para) 165 | # 166 | for i in range(len(netsD)): 167 | netD = 
netsD[i] 168 | torch.save(netD.state_dict(), 169 | '%s/netD%d.pth' % (self.model_dir, i)) 170 | print('Save G/Ds models.') 171 | 172 | def set_requires_grad_value(self, models_list, brequires): 173 | for i in range(len(models_list)): 174 | for p in models_list[i].parameters(): 175 | p.requires_grad = brequires 176 | 177 | def save_img_results(self, netG, noise, sent_emb, words_embs, mask, 178 | image_encoder, captions, cap_lens, 179 | gen_iterations, name='current'): 180 | # Save images 181 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 182 | for i in range(len(attention_maps)): 183 | if len(fake_imgs) > 1: 184 | img = fake_imgs[i + 1].detach().cpu() 185 | lr_img = fake_imgs[i].detach().cpu() 186 | else: 187 | img = fake_imgs[0].detach().cpu() 188 | lr_img = None 189 | attn_maps = attention_maps[i] 190 | att_sze = attn_maps.size(2) 191 | img_set, _ = \ 192 | build_super_images(img, captions, self.ixtoword, 193 | attn_maps, att_sze, lr_imgs=lr_img) 194 | if img_set is not None: 195 | im = Image.fromarray(img_set) 196 | fullpath = '%s/G_%s_%d_%d.png'\ 197 | % (self.image_dir, name, gen_iterations, i) 198 | im.save(fullpath) 199 | 200 | # for i in range(len(netsD)): 201 | i = -1 202 | img = fake_imgs[i].detach() 203 | region_features, _ = image_encoder(img) 204 | att_sze = region_features.size(2) 205 | _, _, att_maps = words_loss(region_features.detach(), 206 | words_embs.detach(), 207 | None, cap_lens, 208 | None, self.batch_size) 209 | img_set, _ = \ 210 | build_super_images(fake_imgs[i].detach().cpu(), 211 | captions, self.ixtoword, att_maps, att_sze) 212 | if img_set is not None: 213 | im = Image.fromarray(img_set) 214 | fullpath = '%s/D_%s_%d.png'\ 215 | % (self.image_dir, name, gen_iterations) 216 | im.save(fullpath) 217 | 218 | def train(self): 219 | text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models() 220 | avg_param_G = copy_G_params(netG) 221 | optimizerG, optimizersD = self.define_optimizers(netG, netsD) 222 | real_labels, fake_labels, match_labels = self.prepare_labels() 223 | 224 | batch_size = self.batch_size 225 | nz = cfg.GAN.Z_DIM 226 | noise = Variable(torch.FloatTensor(batch_size, nz)) 227 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1)) 228 | if cfg.CUDA: 229 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 230 | 231 | gen_iterations = 0 232 | # gen_iterations = start_epoch * self.num_batches 233 | for epoch in range(start_epoch, self.max_epoch): 234 | start_t = time.time() 235 | 236 | data_iter = iter(self.data_loader) 237 | step = 0 238 | while step < self.num_batches: 239 | # reset requires_grad to be trainable for all Ds 240 | # self.set_requires_grad_value(netsD, True) 241 | 242 | ###################################################### 243 | # (1) Prepare training data and Compute text embeddings 244 | ###################################################### 245 | data = data_iter.next() 246 | imgs, captions, cap_lens, class_ids, keys = prepare_data(data) 247 | 248 | hidden = text_encoder.init_hidden(batch_size) 249 | # words_embs: batch_size x nef x seq_len 250 | # sent_emb: batch_size x nef 251 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 252 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach() 253 | mask = (captions == 0) 254 | num_words = words_embs.size(2) 255 | if mask.size(1) > num_words: 256 | mask = mask[:, :num_words] 257 | 258 | ####################################################### 259 | # (2) Generate fake images 260 | 
###################################################### 261 | noise.data.normal_(0, 1) 262 | fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask) 263 | 264 | ####################################################### 265 | # (3) Update D network 266 | ###################################################### 267 | errD_total = 0 268 | D_logs = '' 269 | for i in range(len(netsD)): 270 | netsD[i].zero_grad() 271 | errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i], 272 | sent_emb, real_labels, fake_labels) 273 | # backward and update parameters 274 | errD.backward() 275 | optimizersD[i].step() 276 | errD_total += errD 277 | D_logs += 'errD%d: %.2f ' % (i, errD.data[0]) 278 | 279 | ####################################################### 280 | # (4) Update G network: maximize log(D(G(z))) 281 | ###################################################### 282 | # compute total loss for training G 283 | step += 1 284 | gen_iterations += 1 285 | 286 | # do not need to compute gradient for Ds 287 | # self.set_requires_grad_value(netsD, False) 288 | netG.zero_grad() 289 | errG_total, G_logs = \ 290 | generator_loss(netsD, image_encoder, fake_imgs, real_labels, 291 | words_embs, sent_emb, match_labels, cap_lens, class_ids) 292 | kl_loss = KL_loss(mu, logvar) 293 | errG_total += kl_loss 294 | G_logs += 'kl_loss: %.2f ' % kl_loss.data[0] 295 | # backward and update parameters 296 | errG_total.backward() 297 | optimizerG.step() 298 | for p, avg_p in zip(netG.parameters(), avg_param_G): 299 | avg_p.mul_(0.999).add_(0.001, p.data) 300 | 301 | if gen_iterations % 100 == 0: 302 | print(D_logs + '\n' + G_logs) 303 | # save images 304 | if gen_iterations % 1000 == 0: 305 | backup_para = copy_G_params(netG) 306 | load_params(netG, avg_param_G) 307 | self.save_img_results(netG, fixed_noise, sent_emb, 308 | words_embs, mask, image_encoder, 309 | captions, cap_lens, epoch, name='average') 310 | load_params(netG, backup_para) 311 | # 312 | # self.save_img_results(netG, fixed_noise, sent_emb, 313 | # words_embs, mask, image_encoder, 314 | # captions, cap_lens, 315 | # epoch, name='current') 316 | end_t = time.time() 317 | 318 | print('''[%d/%d][%d] 319 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' 320 | % (epoch, self.max_epoch, self.num_batches, 321 | errD_total.data[0], errG_total.data[0], 322 | end_t - start_t)) 323 | 324 | if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: # and epoch != 0: 325 | self.save_model(netG, avg_param_G, netsD, epoch) 326 | 327 | self.save_model(netG, avg_param_G, netsD, self.max_epoch) 328 | 329 | def save_singleimages(self, images, filenames, save_dir, 330 | split_dir, sentenceID=0): 331 | for i in range(images.size(0)): 332 | s_tmp = '%s/single_samples/%s/%s' %\ 333 | (save_dir, split_dir, filenames[i]) 334 | folder = s_tmp[:s_tmp.rfind('/')] 335 | if not os.path.isdir(folder): 336 | print('Make a new folder: ', folder) 337 | mkdir_p(folder) 338 | 339 | fullpath = '%s_%d.jpg' % (s_tmp, sentenceID) 340 | # range from [-1, 1] to [0, 1] 341 | # img = (images[i] + 1.0) / 2 342 | img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte() 343 | # range from [0, 1] to [0, 255] 344 | ndarr = img.permute(1, 2, 0).data.cpu().numpy() 345 | im = Image.fromarray(ndarr) 346 | im.save(fullpath) 347 | 348 | def sampling(self, split_dir): 349 | if cfg.TRAIN.NET_G == '': 350 | print('Error: the path for morels is not found!') 351 | else: 352 | if split_dir == 'test': 353 | split_dir = 'valid' 354 | # Build and load the generator 355 | if cfg.GAN.B_DCGAN: 356 | netG = G_DCGAN() 357 | else: 358 | 
netG = G_NET() 359 | netG.apply(weights_init) 360 | netG.cuda() 361 | netG.eval() 362 | # 363 | text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 364 | state_dict = \ 365 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 366 | text_encoder.load_state_dict(state_dict) 367 | print('Load text encoder from:', cfg.TRAIN.NET_E) 368 | text_encoder = text_encoder.cuda() 369 | text_encoder.eval() 370 | 371 | batch_size = self.batch_size 372 | nz = cfg.GAN.Z_DIM 373 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 374 | noise = noise.cuda() 375 | 376 | model_dir = cfg.TRAIN.NET_G 377 | state_dict = \ 378 | torch.load(model_dir, map_location=lambda storage, loc: storage) 379 | # state_dict = torch.load(cfg.TRAIN.NET_G) 380 | netG.load_state_dict(state_dict) 381 | print('Load G from: ', model_dir) 382 | 383 | # the path to save generated images 384 | s_tmp = model_dir[:model_dir.rfind('.pth')] 385 | save_dir = '%s/%s' % (s_tmp, split_dir) 386 | mkdir_p(save_dir) 387 | 388 | cnt = 0 389 | 390 | for _ in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE): 391 | for step, data in enumerate(self.data_loader, 0): 392 | cnt += batch_size 393 | if step % 100 == 0: 394 | print('step: ', step) 395 | # if step > 50: 396 | # break 397 | 398 | imgs, captions, cap_lens, class_ids, keys = prepare_data(data) 399 | 400 | hidden = text_encoder.init_hidden(batch_size) 401 | # words_embs: batch_size x nef x seq_len 402 | # sent_emb: batch_size x nef 403 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 404 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach() 405 | mask = (captions == 0) 406 | num_words = words_embs.size(2) 407 | if mask.size(1) > num_words: 408 | mask = mask[:, :num_words] 409 | 410 | ####################################################### 411 | # (2) Generate fake images 412 | ###################################################### 413 | noise.data.normal_(0, 1) 414 | fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask) 415 | for j in range(batch_size): 416 | s_tmp = '%s/single/%s' % (save_dir, keys[j]) 417 | folder = s_tmp[:s_tmp.rfind('/')] 418 | if not os.path.isdir(folder): 419 | print('Make a new folder: ', folder) 420 | mkdir_p(folder) 421 | k = -1 422 | # for k in range(len(fake_imgs)): 423 | im = fake_imgs[k][j].data.cpu().numpy() 424 | # [-1, 1] --> [0, 255] 425 | im = (im + 1.0) * 127.5 426 | im = im.astype(np.uint8) 427 | im = np.transpose(im, (1, 2, 0)) 428 | im = Image.fromarray(im) 429 | fullpath = '%s_s%d.png' % (s_tmp, k) 430 | im.save(fullpath) 431 | 432 | def gen_example(self, data_dic): 433 | if cfg.TRAIN.NET_G == '': 434 | print('Error: the path for morels is not found!') 435 | else: 436 | # Build and load the generator 437 | text_encoder = \ 438 | RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 439 | state_dict = \ 440 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 441 | text_encoder.load_state_dict(state_dict) 442 | print('Load text encoder from:', cfg.TRAIN.NET_E) 443 | text_encoder = text_encoder.cuda() 444 | text_encoder.eval() 445 | 446 | # the path to save generated images 447 | if cfg.GAN.B_DCGAN: 448 | netG = G_DCGAN() 449 | else: 450 | netG = G_NET() 451 | s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')] 452 | model_dir = cfg.TRAIN.NET_G 453 | state_dict = \ 454 | torch.load(model_dir, map_location=lambda storage, loc: storage) 455 | netG.load_state_dict(state_dict) 456 | print('Load G from: ', model_dir) 457 | netG.cuda() 458 | netG.eval() 
459 | for key in data_dic: 460 | save_dir = '%s/%s' % (s_tmp, key) 461 | mkdir_p(save_dir) 462 | captions, cap_lens, sorted_indices = data_dic[key] 463 | 464 | batch_size = captions.shape[0] 465 | nz = cfg.GAN.Z_DIM 466 | captions = Variable(torch.from_numpy(captions), volatile=True) 467 | cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) 468 | 469 | captions = captions.cuda() 470 | cap_lens = cap_lens.cuda() 471 | for i in range(1): # 16 472 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 473 | noise = noise.cuda() 474 | ####################################################### 475 | # (1) Extract text embeddings 476 | ###################################################### 477 | hidden = text_encoder.init_hidden(batch_size) 478 | # words_embs: batch_size x nef x seq_len 479 | # sent_emb: batch_size x nef 480 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 481 | mask = (captions == 0) 482 | ####################################################### 483 | # (2) Generate fake images 484 | ###################################################### 485 | noise.data.normal_(0, 1) 486 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 487 | # G attention 488 | cap_lens_np = cap_lens.cpu().data.numpy() 489 | for j in range(batch_size): 490 | save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j]) 491 | for k in range(len(fake_imgs)): 492 | im = fake_imgs[k][j].data.cpu().numpy() 493 | im = (im + 1.0) * 127.5 494 | im = im.astype(np.uint8) 495 | # print('im', im.shape) 496 | im = np.transpose(im, (1, 2, 0)) 497 | # print('im', im.shape) 498 | im = Image.fromarray(im) 499 | fullpath = '%s_g%d.png' % (save_name, k) 500 | im.save(fullpath) 501 | 502 | for k in range(len(attention_maps)): 503 | if len(fake_imgs) > 1: 504 | im = fake_imgs[k + 1].detach().cpu() 505 | else: 506 | im = fake_imgs[0].detach().cpu() 507 | attn_maps = attention_maps[k] 508 | att_sze = attn_maps.size(2) 509 | img_set, sentences = \ 510 | build_super_images2(im[j].unsqueeze(0), 511 | captions[j].unsqueeze(0), 512 | [cap_lens_np[j]], self.ixtoword, 513 | [attn_maps[j]], att_sze) 514 | if img_set is not None: 515 | im = Image.fromarray(img_set) 516 | fullpath = '%s_a%d.png' % (save_name, k) 517 | im.save(fullpath) 518 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /eval/FreeMono.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/FreeMono.ttf -------------------------------------------------------------------------------- /eval/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query metrix. 3 | Based on each query vector q, it computes a parameterized convex combination of the matrix 4 | based. 5 | H_1 H_2 H_3 ... H_n 6 | q q q q 7 | | | | | 8 | \ | | / 9 | ..... 10 | \ | / 11 | a 12 | Constructs a unit mapping. 13 | $$(H_1 + H_n, q) => (a)$$ 14 | Where H is of `batch x n x dim` and q is of `batch x dim`. 
15 | 16 | References: 17 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules 18 | http://www.aclweb.org/anthology/D15-1166 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | def conv1x1(in_planes, out_planes): 26 | "1x1 convolution with padding" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | 30 | 31 | def func_attention(query, context, gamma1): 32 | """ 33 | query: batch x ndf x queryL 34 | context: batch x ndf x ih x iw (sourceL=ihxiw) 35 | mask: batch_size x sourceL 36 | """ 37 | batch_size, queryL = query.size(0), query.size(2) 38 | ih, iw = context.size(2), context.size(3) 39 | sourceL = ih * iw 40 | 41 | # --> batch x sourceL x ndf 42 | context = context.view(batch_size, -1, sourceL) 43 | contextT = torch.transpose(context, 1, 2).contiguous() 44 | 45 | # Get attention 46 | # (batch x sourceL x ndf)(batch x ndf x queryL) 47 | # -->batch x sourceL x queryL 48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper 49 | # --> batch*sourceL x queryL 50 | attn = attn.view(batch_size*sourceL, queryL) 51 | attn = nn.Softmax()(attn) # Eq. (8) 52 | 53 | # --> batch x sourceL x queryL 54 | attn = attn.view(batch_size, sourceL, queryL) 55 | # --> batch*queryL x sourceL 56 | attn = torch.transpose(attn, 1, 2).contiguous() 57 | attn = attn.view(batch_size*queryL, sourceL) 58 | # Eq. (9) 59 | attn = attn * gamma1 60 | attn = nn.Softmax()(attn) 61 | attn = attn.view(batch_size, queryL, sourceL) 62 | # --> batch x sourceL x queryL 63 | attnT = torch.transpose(attn, 1, 2).contiguous() 64 | 65 | # (batch x ndf x sourceL)(batch x sourceL x queryL) 66 | # --> batch x ndf x queryL 67 | weightedContext = torch.bmm(context, attnT) 68 | 69 | return weightedContext, attn.view(batch_size, -1, ih, iw) 70 | 71 | 72 | class GlobalAttentionGeneral(nn.Module): 73 | def __init__(self, idf, cdf): 74 | super(GlobalAttentionGeneral, self).__init__() 75 | self.conv_context = conv1x1(cdf, idf) 76 | self.sm = nn.Softmax() 77 | self.mask = None 78 | 79 | def applyMask(self, mask): 80 | self.mask = mask # batch x sourceL 81 | 82 | def forward(self, input, context): 83 | """ 84 | input: batch x idf x ih x iw (queryL=ihxiw) 85 | context: batch x cdf x sourceL 86 | """ 87 | ih, iw = input.size(2), input.size(3) 88 | queryL = ih * iw 89 | batch_size, sourceL = context.size(0), context.size(2) 90 | 91 | # --> batch x queryL x idf 92 | target = input.view(batch_size, -1, queryL) 93 | targetT = torch.transpose(target, 1, 2).contiguous() 94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1 95 | sourceT = context.unsqueeze(3) 96 | # --> batch x idf x sourceL 97 | sourceT = self.conv_context(sourceT).squeeze(3) 98 | 99 | # Get attention 100 | # (batch x queryL x idf)(batch x idf x sourceL) 101 | # -->batch x queryL x sourceL 102 | attn = torch.bmm(targetT, sourceT) 103 | # --> batch*queryL x sourceL 104 | attn = attn.view(batch_size*queryL, sourceL) 105 | if self.mask is not None: 106 | # batch_size x sourceL --> batch_size*queryL x sourceL 107 | mask = self.mask.repeat(queryL, 1) 108 | attn.data.masked_fill_(mask.data, -float('inf')) 109 | attn = self.sm(attn) # Eq. 
(2) 110 |         # --> batch x queryL x sourceL 111 |         attn = attn.view(batch_size, queryL, sourceL) 112 |         # --> batch x sourceL x queryL 113 |         attn = torch.transpose(attn, 1, 2).contiguous() 114 | 115 |         # (batch x idf x sourceL)(batch x sourceL x queryL) 116 |         # --> batch x idf x queryL 117 |         weightedContext = torch.bmm(sourceT, attn) 118 |         weightedContext = weightedContext.view(batch_size, -1, ih, iw) 119 |         attn = attn.view(batch_size, -1, ih, iw) 120 | 121 |         return weightedContext, attn 122 | -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # AttnGAN Eval API 2 | Model evaluation code is extracted here in order to create a separate inference mode for the project. The evaluation code is then embedded into a Flask app that accepts API requests. 3 | There are two Docker files: 4 | 1. [dockerfile.cpu](dockerfile.cpu) - creates a CPU-bound container 5 | 2. [dockerfile.gpu](dockerfile.gpu) - creates a GPU-bound container 6 | 7 | # Requirements 8 | The app uses Azure Blob Storage as an image repository as well as Application Insights for logging telemetry. 9 | 10 | # Running the Application 11 | There is a three-step process for running the application and generating bird images. 12 | 1. Create the container (optionally choose the CPU or GPU dockerfile): 13 | ``` 14 | docker build -t "attngan" -f dockerfile.cpu . 15 | ``` 16 | 2. Run the container (replace the keys with your Blob Storage account key and Application Insights key): 17 | ``` 18 | docker run -it -e BLOB_KEY=KEY -e TELEMETRY=TELEMETRY_KEY -p 5678:8080 attngan 19 | ``` 20 | 3. Call the API: 21 | ``` 22 | curl -H "Content-Type: application/json" -X POST -d '{"caption":"the bird has a yellow crown and a black eyering that is round"}' http://localhost:5678/api/v1.0/bird 23 | ``` 24 | 25 | # Images 26 | You should have your very own image generator. A sample of the JSON returned by the bird endpoint is shown below.
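A successful request returns HTTP 201 with a JSON payload; the sketch below shows a representative response shape, with field names taken from the payload assembled by `create_bird` in `eval/main.py`. The blob URLs and the `elapsed` value are illustrative placeholders; the actual paths depend on your storage account and the generation timestamp.
```
{
  "bird": {
    "small": "https://<account>.blob.core.windows.net/images/<timestamp>/bird_g0.png",
    "medium": "https://<account>.blob.core.windows.net/images/<timestamp>/bird_g1.png",
    "large": "https://<account>.blob.core.windows.net/images/<timestamp>/bird_g2.png",
    "map1": "https://<account>.blob.core.windows.net/images/<timestamp>/attmaps_a0.png",
    "map2": "https://<account>.blob.core.windows.net/images/<timestamp>/attmaps_a1.png",
    "caption": "the bird has a yellow crown and a black eyering that is round",
    "elapsed": 2.5
  }
}
```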
27 | 28 | 29 | -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /eval/data/bird_AttnGAN2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/data/bird_AttnGAN2.pth -------------------------------------------------------------------------------- /eval/data/captions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/data/captions.pickle -------------------------------------------------------------------------------- /eval/data/text_encoder200.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/data/text_encoder200.pth -------------------------------------------------------------------------------- /eval/dockerfile.cpu: -------------------------------------------------------------------------------- 1 | FROM python:2 2 | 3 | RUN mkdir -p /usr/src/app 4 | WORKDIR /usr/src/app 5 | 6 | COPY requirements.txt /usr/src/app/ 7 | RUN pip install --upgrade pip 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | RUN pip install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl 10 | RUN pip install torchvision 11 | 12 | COPY . /usr/src/app 13 | 14 | ENV GPU False 15 | ENV EXPORT_MODEL True 16 | 17 | EXPOSE 8080 18 | 19 | CMD ["python", "main.py"] 20 | 21 | 22 | -------------------------------------------------------------------------------- /eval/dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda 2 | 3 | RUN apt-get update \ 4 | && apt-get upgrade -y \ 5 | && apt-get install -y \ 6 | python-pip \ 7 | python2.7 \ 8 | && apt-get autoremove \ 9 | && apt-get clean 10 | 11 | RUN mkdir -p /usr/src/app 12 | WORKDIR /usr/src/app 13 | 14 | COPY requirements.txt /usr/src/app/ 15 | RUN pip install --upgrade pip 16 | RUN pip install --no-cache-dir -r requirements.txt 17 | RUN pip install http://download.pytorch.org/whl/cu90/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl 18 | RUN pip install torchvision 19 | 20 | COPY . 
/usr/src/app 21 | 22 | ENV NVIDIA_VISIBLE_DEVICES all 23 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility 24 | ENV GPU True 25 | ENV EXPORT_MODEL False 26 | 27 | EXPOSE 8080 28 | 29 | CMD ["python", "main.py"] -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import torch 6 | import io 7 | import time 8 | import numpy as np 9 | from PIL import Image 10 | import torch.onnx 11 | from datetime import datetime 12 | from torch.autograd import Variable 13 | from miscc.config import cfg 14 | from miscc.utils import build_super_images2 15 | from model import RNN_ENCODER, G_NET 16 | from azure.storage.blob import BlockBlobService 17 | 18 | if sys.version_info[0] == 2: 19 | import cPickle as pickle 20 | else: 21 | import pickle 22 | 23 | from werkzeug.contrib.cache import SimpleCache 24 | cache = SimpleCache() 25 | 26 | def vectorize_caption(wordtoix, caption, copies=2): 27 | # create caption vector 28 | tokens = caption.split(' ') 29 | cap_v = [] 30 | for t in tokens: 31 | t = t.strip().encode('ascii', 'ignore').decode('ascii') 32 | if len(t) > 0 and t in wordtoix: 33 | cap_v.append(wordtoix[t]) 34 | 35 | # expected state for single generation 36 | captions = np.zeros((copies, len(cap_v))) 37 | for i in range(copies): 38 | captions[i,:] = np.array(cap_v) 39 | cap_lens = np.zeros(copies) + len(cap_v) 40 | 41 | #print(captions.astype(int), cap_lens.astype(int)) 42 | #captions, cap_lens = np.array([cap_v, cap_v]), np.array([len(cap_v), len(cap_v)]) 43 | #print(captions, cap_lens) 44 | #return captions, cap_lens 45 | 46 | return captions.astype(int), cap_lens.astype(int) 47 | 48 | def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copies=2): 49 | # load word vector 50 | captions, cap_lens = vectorize_caption(wordtoix, caption, copies) 51 | n_words = len(wordtoix) 52 | 53 | # only one to generate 54 | batch_size = captions.shape[0] 55 | 56 | nz = cfg.GAN.Z_DIM 57 | captions = Variable(torch.from_numpy(captions), volatile=True) 58 | cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) 59 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 60 | 61 | if cfg.CUDA: 62 | captions = captions.cuda() 63 | cap_lens = cap_lens.cuda() 64 | noise = noise.cuda() 65 | 66 | 67 | 68 | ####################################################### 69 | # (1) Extract text embeddings 70 | ####################################################### 71 | hidden = text_encoder.init_hidden(batch_size) 72 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 73 | mask = (captions == 0) 74 | 75 | 76 | ####################################################### 77 | # (2) Generate fake images 78 | ####################################################### 79 | noise.data.normal_(0, 1) 80 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 81 | 82 | # ONNX EXPORT 83 | #export = os.environ["EXPORT_MODEL"].lower() == 'true' 84 | if False: 85 | print("saving text_encoder.onnx") 86 | text_encoder_out = torch.onnx._export(text_encoder, (captions, cap_lens, hidden), "text_encoder.onnx", export_params=True) 87 | print("uploading text_encoder.onnx") 88 | blob_service.create_blob_from_path('models', "text_encoder.onnx", os.path.abspath("text_encoder.onnx")) 89 | print("done") 90 | 91 | print("saving netg.onnx") 92 | netg_out = torch.onnx._export(netG, (noise, sent_emb, 
words_embs, mask), "netg.onnx", export_params=True) 93 | print("uploading netg.onnx") 94 | blob_service.create_blob_from_path('models', "netg.onnx", os.path.abspath("netg.onnx")) 95 | print("done") 96 | return 97 | 98 | # G attention 99 | cap_lens_np = cap_lens.cpu().data.numpy() 100 | 101 | # storing to blob storage 102 | container_name = "images" 103 | full_path = "https://attgan.blob.core.windows.net/images/%s" 104 | prefix = datetime.now().strftime('%Y/%B/%d/%H_%M_%S_%f') 105 | urls = [] 106 | # only look at first one 107 | #j = 0 108 | for j in range(batch_size): 109 | for k in range(len(fake_imgs)): 110 | im = fake_imgs[k][j].data.cpu().numpy() 111 | im = (im + 1.0) * 127.5 112 | im = im.astype(np.uint8) 113 | im = np.transpose(im, (1, 2, 0)) 114 | im = Image.fromarray(im) 115 | 116 | # save image to stream 117 | stream = io.BytesIO() 118 | im.save(stream, format="png") 119 | stream.seek(0) 120 | if copies > 2: 121 | blob_name = '%s/%d/%s_g%d.png' % (prefix, j, "bird", k) 122 | else: 123 | blob_name = '%s/%s_g%d.png' % (prefix, "bird", k) 124 | blob_service.create_blob_from_stream(container_name, blob_name, stream) 125 | urls.append(full_path % blob_name) 126 | 127 | if copies == 2: 128 | for k in range(len(attention_maps)): 129 | #if False: 130 | if len(fake_imgs) > 1: 131 | im = fake_imgs[k + 1].detach().cpu() 132 | else: 133 | im = fake_imgs[0].detach().cpu() 134 | 135 | attn_maps = attention_maps[k] 136 | att_sze = attn_maps.size(2) 137 | 138 | img_set, sentences = \ 139 | build_super_images2(im[j].unsqueeze(0), 140 | captions[j].unsqueeze(0), 141 | [cap_lens_np[j]], ixtoword, 142 | [attn_maps[j]], att_sze) 143 | 144 | if img_set is not None: 145 | im = Image.fromarray(img_set) 146 | stream = io.BytesIO() 147 | im.save(stream, format="png") 148 | stream.seek(0) 149 | 150 | blob_name = '%s/%s_a%d.png' % (prefix, "attmaps", k) 151 | blob_service.create_blob_from_stream(container_name, blob_name, stream) 152 | urls.append(full_path % blob_name) 153 | if copies == 2: 154 | break 155 | 156 | #print(len(urls), urls) 157 | return urls 158 | 159 | def word_index(): 160 | ixtoword = cache.get('ixtoword') 161 | wordtoix = cache.get('wordtoix') 162 | if ixtoword is None or wordtoix is None: 163 | #print("ix and word not cached") 164 | # load word to index dictionary 165 | x = pickle.load(open('data/captions.pickle', 'rb')) 166 | ixtoword = x[2] 167 | wordtoix = x[3] 168 | del x 169 | cache.set('ixtoword', ixtoword, timeout=60 * 60 * 24) 170 | cache.set('wordtoix', wordtoix, timeout=60 * 60 * 24) 171 | 172 | return wordtoix, ixtoword 173 | 174 | def models(word_len): 175 | #print(word_len) 176 | text_encoder = cache.get('text_encoder') 177 | if text_encoder is None: 178 | #print("text_encoder not cached") 179 | text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM) 180 | state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 181 | text_encoder.load_state_dict(state_dict) 182 | if cfg.CUDA: 183 | text_encoder.cuda() 184 | text_encoder.eval() 185 | cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24) 186 | 187 | netG = cache.get('netG') 188 | if netG is None: 189 | #print("netG not cached") 190 | netG = G_NET() 191 | state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) 192 | netG.load_state_dict(state_dict) 193 | if cfg.CUDA: 194 | netG.cuda() 195 | netG.eval() 196 | cache.set('netG', netG, timeout=60 * 60 * 24) 197 | 198 | return text_encoder, netG 199 | 200 | def eval(caption): 201 | # load word dictionaries 
202 | wordtoix, ixtoword = word_index() 203 | # lead models 204 | text_encoder, netG = models(len(wordtoix)) 205 | # load blob service 206 | blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"]) 207 | 208 | t0 = time.time() 209 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service) 210 | t1 = time.time() 211 | 212 | response = { 213 | 'small': urls[0], 214 | 'medium': urls[1], 215 | 'large': urls[2], 216 | 'map1': urls[3], 217 | 'map2': urls[4], 218 | 'caption': caption, 219 | 'elapsed': t1 - t0 220 | } 221 | 222 | return response 223 | 224 | if __name__ == "__main__": 225 | caption = "the bird has a yellow crown and a black eyering that is round" 226 | 227 | # load configuration 228 | #cfg_from_file('eval_bird.yml') 229 | # load word dictionaries 230 | wordtoix, ixtoword = word_index() 231 | # lead models 232 | text_encoder, netG = models(len(wordtoix)) 233 | # load blob service 234 | blob_service = BlockBlobService(account_name='attgan', account_key='[REDACTED]') 235 | 236 | t0 = time.time() 237 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service) 238 | t1 = time.time() 239 | print(t1-t0) 240 | print(urls) -------------------------------------------------------------------------------- /eval/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | from eval import * 5 | from flask import Flask, jsonify, request, abort 6 | from applicationinsights import TelemetryClient 7 | from applicationinsights.requests import WSGIApplication 8 | from applicationinsights.exceptions import enable 9 | from miscc.config import cfg 10 | #from werkzeug.contrib.profiler import ProfilerMiddleware 11 | 12 | enable(os.environ["TELEMETRY"]) 13 | app = Flask(__name__) 14 | app.wsgi_app = WSGIApplication(os.environ["TELEMETRY"], app.wsgi_app) 15 | 16 | @app.route('/api/v1.0/bird', methods=['POST']) 17 | def create_bird(): 18 | if not request.json or not 'caption' in request.json: 19 | abort(400) 20 | 21 | caption = request.json['caption'] 22 | 23 | t0 = time.time() 24 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service) 25 | t1 = time.time() 26 | 27 | response = { 28 | 'small': urls[0], 29 | 'medium': urls[1], 30 | 'large': urls[2], 31 | 'map1': urls[3], 32 | 'map2': urls[4], 33 | 'caption': caption, 34 | 'elapsed': t1 - t0 35 | } 36 | return jsonify({'bird': response}), 201 37 | 38 | @app.route('/api/v1.0/birds', methods=['POST']) 39 | def create_birds(): 40 | if not request.json or not 'caption' in request.json: 41 | abort(400) 42 | 43 | caption = request.json['caption'] 44 | 45 | t0 = time.time() 46 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copies=6) 47 | t1 = time.time() 48 | 49 | response = { 50 | 'bird1' : { 'small': urls[0], 'medium': urls[1], 'large': urls[2] }, 51 | 'bird2' : { 'small': urls[3], 'medium': urls[4], 'large': urls[5] }, 52 | 'bird3' : { 'small': urls[6], 'medium': urls[7], 'large': urls[8] }, 53 | 'bird4' : { 'small': urls[9], 'medium': urls[10], 'large': urls[11] }, 54 | 'bird5' : { 'small': urls[12], 'medium': urls[13], 'large': urls[14] }, 55 | 'bird6' : { 'small': urls[15], 'medium': urls[16], 'large': urls[17] }, 56 | 'caption': caption, 57 | 'elapsed': t1 - t0 58 | } 59 | return jsonify({'bird': response}), 201 60 | 61 | @app.route('/', methods=['GET']) 62 | def get_bird(): 63 | return 'Version 1' 64 | 65 | if __name__ == '__main__': 66 | t0 = 
time.time() 67 | tc = TelemetryClient(os.environ["TELEMETRY"]) 68 | 69 | # gpu based 70 | cfg.CUDA = os.environ["GPU"].lower() == 'true' 71 | tc.track_event('container initializing', {"CUDA": str(cfg.CUDA)}) 72 | 73 | # load word dictionaries 74 | wordtoix, ixtoword = word_index() 75 | # lead models 76 | text_encoder, netG = models(len(wordtoix)) 77 | # load blob service 78 | blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"]) 79 | 80 | seed = 100 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | torch.manual_seed(seed) 84 | if cfg.CUDA: 85 | torch.cuda.manual_seed_all(seed) 86 | 87 | #app.config['PROFILE'] = True 88 | #app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30]) 89 | #app.run(host='0.0.0.0', port=8080, debug = True) 90 | 91 | t1 = time.time() 92 | tc.track_event('container start', {"starttime": str(t1-t0)}) 93 | app.run(host='0.0.0.0', port=8080) 94 | -------------------------------------------------------------------------------- /eval/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /eval/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'birds' 14 | __C.CONFIG_NAME = 'attn2' 15 | __C.DATA_DIR = '' 16 | __C.GPU_ID = 0 17 | __C.CUDA = False 18 | __C.WORKERS = 1 19 | 20 | __C.RNN_TYPE = 'LSTM' # 'GRU' 21 | __C.B_VALIDATION = False 22 | 23 | __C.TREE = edict() 24 | __C.TREE.BRANCH_NUM = 3 25 | __C.TREE.BASE_SIZE = 64 26 | 27 | 28 | # Training options 29 | __C.TRAIN = edict() 30 | __C.TRAIN.BATCH_SIZE = 64 31 | __C.TRAIN.MAX_EPOCH = 600 32 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000 33 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 34 | __C.TRAIN.GENERATOR_LR = 2e-4 35 | __C.TRAIN.ENCODER_LR = 2e-4 36 | __C.TRAIN.RNN_GRAD_CLIP = 0.25 37 | __C.TRAIN.FLAG = False 38 | __C.TRAIN.NET_E = 'data/text_encoder200.pth' 39 | __C.TRAIN.NET_G = 'data/bird_AttnGAN2.pth' 40 | __C.TRAIN.B_NET_D = False 41 | 42 | __C.TRAIN.SMOOTH = edict() 43 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0 44 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0 45 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0 46 | __C.TRAIN.SMOOTH.LAMBDA = 1.0 47 | 48 | 49 | # Modal options 50 | __C.GAN = edict() 51 | __C.GAN.DF_DIM = 64 52 | __C.GAN.GF_DIM = 32 53 | __C.GAN.Z_DIM = 100 54 | __C.GAN.CONDITION_DIM = 100 55 | __C.GAN.R_NUM = 2 56 | __C.GAN.B_ATTENTION = True 57 | __C.GAN.B_DCGAN = False 58 | 59 | 60 | __C.TEXT = edict() 61 | __C.TEXT.CAPTIONS_PER_IMAGE = 10 62 | __C.TEXT.EMBEDDING_DIM = 256 63 | __C.TEXT.WORDS_NUM = 25 64 | 65 | -------------------------------------------------------------------------------- /eval/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from torch.nn import init 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | from copy import deepcopy 11 | import skimage.transform 12 | 13 | 14 | 15 | # For visualization ################################################ 16 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 17 | 
2:[70, 70, 70], 3:[102,102,156], 18 | 4:[190,153,153], 5:[153,153,153], 19 | 6:[250,170, 30], 7:[220, 220, 0], 20 | 8:[107,142, 35], 9:[152,251,152], 21 | 10:[70,130,180], 11:[220,20, 60], 22 | 12:[255, 0, 0], 13:[0, 0, 142], 23 | 14:[119,11, 32], 15:[0, 60,100], 24 | 16:[0, 80, 100], 17:[0, 0, 230], 25 | 18:[0, 0, 70], 19:[0, 0, 0]} 26 | FONT_MAX = 50 27 | 28 | 29 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): 30 | num = captions.size(0) 31 | img_txt = Image.fromarray(convas) 32 | # get a font 33 | # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 34 | #fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 35 | fnt = ImageFont.truetype('FreeMono.ttf', 50) 36 | # get a drawing context 37 | d = ImageDraw.Draw(img_txt) 38 | sentence_list = [] 39 | for i in range(num): 40 | cap = captions[i].data.cpu().numpy() 41 | sentence = [] 42 | for j in range(len(cap)): 43 | if cap[j] == 0: 44 | break 45 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') 46 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), 47 | font=fnt, fill=(255, 255, 255, 255)) 48 | sentence.append(word) 49 | sentence_list.append(sentence) 50 | return img_txt, sentence_list 51 | 52 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword, 53 | attn_maps, att_sze, vis_size=256, topK=5): 54 | batch_size = real_imgs.size(0) 55 | max_word_num = np.max(cap_lens) 56 | text_convas = np.ones([batch_size * FONT_MAX, 57 | max_word_num * (vis_size + 2), 3], 58 | dtype=np.uint8) 59 | 60 | real_imgs = \ 61 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 62 | # [-1, 1] --> [0, 1] 63 | real_imgs.add_(1).div_(2).mul_(255) 64 | real_imgs = real_imgs.data.numpy() 65 | # b x c x h x w --> b x h x w x c 66 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 67 | pad_sze = real_imgs.shape 68 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 69 | 70 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 71 | img_set = [] 72 | num = len(attn_maps) 73 | 74 | text_map, sentences = \ 75 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) 76 | text_map = np.asarray(text_map).astype(np.uint8) 77 | 78 | bUpdate = 1 79 | for i in range(num): 80 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 81 | # 82 | attn = attn.view(-1, 1, att_sze, att_sze) 83 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 84 | # n x c x h x w --> n x h x w x c 85 | attn = np.transpose(attn, (0, 2, 3, 1)) 86 | num_attn = cap_lens[i] 87 | thresh = 2./float(num_attn) 88 | # 89 | img = real_imgs[i] 90 | row = [] 91 | row_merge = [] 92 | row_txt = [] 93 | row_beforeNorm = [] 94 | conf_score = [] 95 | for j in range(num_attn): 96 | one_map = attn[j] 97 | mask0 = one_map > (2. 
* thresh) 98 | conf_score.append(np.sum(one_map * mask0)) 99 | mask = one_map > thresh 100 | one_map = one_map * mask 101 | if (vis_size // att_sze) > 1: 102 | one_map = \ 103 | skimage.transform.pyramid_expand(one_map, sigma=20, 104 | upscale=vis_size // att_sze) 105 | minV = one_map.min() 106 | maxV = one_map.max() 107 | one_map = (one_map - minV) / (maxV - minV) 108 | row_beforeNorm.append(one_map) 109 | sorted_indices = np.argsort(conf_score)[::-1] 110 | 111 | for j in range(num_attn): 112 | one_map = row_beforeNorm[j] 113 | one_map *= 255 114 | # 115 | PIL_im = Image.fromarray(np.uint8(img)) 116 | PIL_att = Image.fromarray(np.uint8(one_map)) 117 | merged = \ 118 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 119 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210) 120 | merged.paste(PIL_im, (0, 0)) 121 | merged.paste(PIL_att, (0, 0), mask) 122 | merged = np.array(merged)[:, :, :3] 123 | 124 | row.append(np.concatenate([one_map, middle_pad], 1)) 125 | # 126 | row_merge.append(np.concatenate([merged, middle_pad], 1)) 127 | # 128 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, 129 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :] 130 | row_txt.append(txt) 131 | # reorder 132 | row_new = [] 133 | row_merge_new = [] 134 | txt_new = [] 135 | for j in range(num_attn): 136 | idx = sorted_indices[j] 137 | row_new.append(row[idx]) 138 | row_merge_new.append(row_merge[idx]) 139 | txt_new.append(row_txt[idx]) 140 | row = np.concatenate(row_new[:topK], 1) 141 | row_merge = np.concatenate(row_merge_new[:topK], 1) 142 | txt = np.concatenate(txt_new[:topK], 1) 143 | if txt.shape[1] != row.shape[1]: 144 | print('Warnings: txt', txt.shape, 'row', row.shape, 145 | 'row_merge_new', row_merge_new.shape) 146 | bUpdate = 0 147 | break 148 | row = np.concatenate([txt, row_merge], 0) 149 | img_set.append(row) 150 | if bUpdate: 151 | img_set = np.concatenate(img_set, 0) 152 | img_set = img_set.astype(np.uint8) 153 | return img_set, sentences 154 | else: 155 | return None 156 | 157 | 158 | #################################################################### 159 | def weights_init(m): 160 | classname = m.__class__.__name__ 161 | if classname.find('Conv') != -1: 162 | nn.init.orthogonal(m.weight.data, 1.0) 163 | elif classname.find('BatchNorm') != -1: 164 | m.weight.data.normal_(1.0, 0.02) 165 | m.bias.data.fill_(0) 166 | elif classname.find('Linear') != -1: 167 | nn.init.orthogonal(m.weight.data, 1.0) 168 | if m.bias is not None: 169 | m.bias.data.fill_(0.0) 170 | 171 | 172 | def load_params(model, new_param): 173 | for p, new_p in zip(model.parameters(), new_param): 174 | p.data.copy_(new_p) 175 | 176 | 177 | def copy_G_params(model): 178 | flatten = deepcopy(list(p.data for p in model.parameters())) 179 | return flatten 180 | 181 | 182 | def mkdir_p(path): 183 | try: 184 | os.makedirs(path) 185 | except OSError as exc: # Python >2.5 186 | if exc.errno == errno.EEXIST and os.path.isdir(path): 187 | pass 188 | else: 189 | raise 190 | -------------------------------------------------------------------------------- /eval/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from torch.autograd import Variable 5 | from torchvision import models 6 | import torch.utils.model_zoo as model_zoo 7 | import torch.nn.functional as F 8 | 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | 11 | from miscc.config import cfg 12 | from GlobalAttention import 
GlobalAttentionGeneral as ATT_NET 13 | 14 | 15 | # ############## Text2Image Encoder-Decoder ####### 16 | class RNN_ENCODER(nn.Module): 17 | def __init__(self, ntoken, ninput=300, drop_prob=0.5, 18 | nhidden=128, nlayers=1, bidirectional=True): 19 | super(RNN_ENCODER, self).__init__() 20 | 21 | self.n_steps = cfg.TEXT.WORDS_NUM 22 | self.rnn_type = cfg.RNN_TYPE 23 | 24 | self.ntoken = ntoken # size of the dictionary 25 | self.ninput = ninput # size of each embedding vector 26 | self.drop_prob = drop_prob # probability of an element to be zeroed 27 | self.nlayers = nlayers # Number of recurrent layers 28 | self.bidirectional = bidirectional 29 | 30 | if bidirectional: 31 | self.num_directions = 2 32 | else: 33 | self.num_directions = 1 34 | # number of features in the hidden state 35 | self.nhidden = nhidden // self.num_directions 36 | 37 | self.define_module() 38 | self.init_weights() 39 | 40 | def define_module(self): 41 | self.encoder = nn.Embedding(self.ntoken, self.ninput) 42 | self.drop = nn.Dropout(self.drop_prob) 43 | if self.rnn_type == 'LSTM': 44 | # dropout: If non-zero, introduces a dropout layer on 45 | # the outputs of each RNN layer except the last layer 46 | self.rnn = nn.LSTM(self.ninput, self.nhidden, 47 | self.nlayers, batch_first=True, 48 | dropout=self.drop_prob, 49 | bidirectional=self.bidirectional) 50 | elif self.rnn_type == 'GRU': 51 | self.rnn = nn.GRU(self.ninput, self.nhidden, 52 | self.nlayers, batch_first=True, 53 | dropout=self.drop_prob, 54 | bidirectional=self.bidirectional) 55 | else: 56 | raise NotImplementedError 57 | 58 | def init_weights(self): 59 | initrange = 0.1 60 | self.encoder.weight.data.uniform_(-initrange, initrange) 61 | # Do not need to initialize RNN parameters, which have been initialized 62 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM 63 | # self.decoder.weight.data.uniform_(-initrange, initrange) 64 | # self.decoder.bias.data.fill_(0) 65 | 66 | def init_hidden(self, bsz): 67 | weight = next(self.parameters()).data 68 | if self.rnn_type == 'LSTM': 69 | return (Variable(weight.new(self.nlayers * self.num_directions, 70 | bsz, self.nhidden).zero_()), 71 | Variable(weight.new(self.nlayers * self.num_directions, 72 | bsz, self.nhidden).zero_())) 73 | else: 74 | return Variable(weight.new(self.nlayers * self.num_directions, 75 | bsz, self.nhidden).zero_()) 76 | 77 | def forward(self, captions, cap_lens, hidden, mask=None): 78 | # input: torch.LongTensor of size batch x n_steps 79 | # --> emb: batch x n_steps x ninput 80 | emb = self.drop(self.encoder(captions)) 81 | # 82 | # Returns: a PackedSequence object 83 | cap_lens = cap_lens.data.tolist() 84 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True) 85 | # #hidden and memory (num_layers * num_directions, batch, hidden_size): 86 | # tensor containing the initial hidden state for each element in batch. 
87 | # #output (batch, seq_len, hidden_size * num_directions) 88 | # #or a PackedSequence object: 89 | # tensor containing output features (h_t) from the last layer of RNN 90 | output, hidden = self.rnn(emb, hidden) 91 | # PackedSequence object 92 | # --> (batch, seq_len, hidden_size * num_directions) 93 | output = pad_packed_sequence(output, batch_first=True)[0] 94 | # output = self.drop(output) 95 | # --> batch x hidden_size*num_directions x seq_len 96 | words_emb = output.transpose(1, 2) 97 | # --> batch x num_directions*hidden_size 98 | if self.rnn_type == 'LSTM': 99 | sent_emb = hidden[0].transpose(0, 1).contiguous() 100 | else: 101 | sent_emb = hidden.transpose(0, 1).contiguous() 102 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) 103 | return words_emb, sent_emb 104 | 105 | 106 | 107 | 108 | class G_NET(nn.Module): 109 | def __init__(self): 110 | super(G_NET, self).__init__() 111 | 112 | ngf = cfg.GAN.GF_DIM 113 | nef = cfg.TEXT.EMBEDDING_DIM 114 | ncf = cfg.GAN.CONDITION_DIM 115 | 116 | 117 | self.ca_net = CA_NET() 118 | 119 | if cfg.TREE.BRANCH_NUM > 0: 120 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 121 | self.img_net1 = GET_IMAGE_G(ngf) 122 | # gf x 64 x 64 123 | if cfg.TREE.BRANCH_NUM > 1: 124 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) 125 | self.img_net2 = GET_IMAGE_G(ngf) 126 | if cfg.TREE.BRANCH_NUM > 2: 127 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) 128 | self.img_net3 = GET_IMAGE_G(ngf) 129 | 130 | def forward(self, z_code, sent_emb, word_embs, mask): 131 | """ 132 | :param z_code: batch x cfg.GAN.Z_DIM 133 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 134 | :param word_embs: batch x cdf x seq_len 135 | :param mask: batch x seq_len 136 | :return: 137 | """ 138 | fake_imgs = [] 139 | att_maps = [] 140 | c_code, mu, logvar = self.ca_net(sent_emb) 141 | 142 | if cfg.TREE.BRANCH_NUM > 0: 143 | h_code1 = self.h_net1(z_code, c_code) 144 | fake_img1 = self.img_net1(h_code1) 145 | fake_imgs.append(fake_img1) 146 | if cfg.TREE.BRANCH_NUM > 1: 147 | h_code2, att1 = \ 148 | self.h_net2(h_code1, c_code, word_embs, mask) 149 | fake_img2 = self.img_net2(h_code2) 150 | fake_imgs.append(fake_img2) 151 | if att1 is not None: 152 | att_maps.append(att1) 153 | if cfg.TREE.BRANCH_NUM > 2: 154 | h_code3, att2 = \ 155 | self.h_net3(h_code2, c_code, word_embs, mask) 156 | fake_img3 = self.img_net3(h_code3) 157 | fake_imgs.append(fake_img3) 158 | if att2 is not None: 159 | att_maps.append(att2) 160 | 161 | return fake_imgs, att_maps, mu, logvar 162 | 163 | 164 | # ############## G networks ################### 165 | class CA_NET(nn.Module): 166 | # some code is modified from vae examples 167 | # (https://github.com/pytorch/examples/blob/master/vae/main.py) 168 | def __init__(self): 169 | super(CA_NET, self).__init__() 170 | 171 | self.t_dim = cfg.TEXT.EMBEDDING_DIM 172 | self.c_dim = cfg.GAN.CONDITION_DIM 173 | 174 | self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True) 175 | self.relu = GLU() 176 | 177 | def encode(self, text_embedding): 178 | x = self.relu(self.fc(text_embedding)) 179 | mu = x[:, :self.c_dim] 180 | logvar = x[:, self.c_dim:] 181 | return mu, logvar 182 | 183 | def reparametrize(self, mu, logvar): 184 | std = logvar.mul(0.5).exp_() 185 | if cfg.CUDA: 186 | eps = torch.cuda.FloatTensor(std.size()).normal_() 187 | else: 188 | eps = torch.FloatTensor(std.size()).normal_() 189 | eps = Variable(eps) 190 | return eps.mul(std).add_(mu) 191 | 192 | def forward(self, text_embedding): 193 | mu, logvar = self.encode(text_embedding) 194 | c_code = 
self.reparametrize(mu, logvar) 195 | return c_code, mu, logvar 196 | 197 | 198 | 199 | class GLU(nn.Module): 200 | def __init__(self): 201 | super(GLU, self).__init__() 202 | 203 | def forward(self, x): 204 | nc = x.size(1) 205 | assert nc % 2 == 0, 'channels dont divide 2!' 206 | nc = int(nc/2) 207 | return x[:, :nc] * F.sigmoid(x[:, nc:]) 208 | 209 | 210 | def conv1x1(in_planes, out_planes, bias=False): 211 | "1x1 convolution with padding" 212 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 213 | padding=0, bias=bias) 214 | 215 | 216 | def conv3x3(in_planes, out_planes): 217 | "3x3 convolution with padding" 218 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 219 | padding=1, bias=False) 220 | 221 | 222 | # Upsale the spatial size by a factor of 2 223 | def upBlock(in_planes, out_planes): 224 | block = nn.Sequential( 225 | nn.Upsample(scale_factor=2, mode='nearest'), 226 | conv3x3(in_planes, out_planes * 2), 227 | nn.BatchNorm2d(out_planes * 2), 228 | GLU()) 229 | return block 230 | 231 | 232 | # Keep the spatial size 233 | def Block3x3_relu(in_planes, out_planes): 234 | block = nn.Sequential( 235 | conv3x3(in_planes, out_planes * 2), 236 | nn.BatchNorm2d(out_planes * 2), 237 | GLU()) 238 | return block 239 | 240 | 241 | class ResBlock(nn.Module): 242 | def __init__(self, channel_num): 243 | super(ResBlock, self).__init__() 244 | self.block = nn.Sequential( 245 | conv3x3(channel_num, channel_num * 2), 246 | nn.BatchNorm2d(channel_num * 2), 247 | GLU(), 248 | conv3x3(channel_num, channel_num), 249 | nn.BatchNorm2d(channel_num)) 250 | 251 | def forward(self, x): 252 | residual = x 253 | out = self.block(x) 254 | out += residual 255 | return out 256 | 257 | 258 | 259 | class CNN_ENCODER(nn.Module): 260 | def __init__(self, nef): 261 | super(CNN_ENCODER, self).__init__() 262 | if cfg.TRAIN.FLAG: 263 | self.nef = nef 264 | else: 265 | self.nef = 256 # define a uniform ranker 266 | 267 | model = models.inception_v3() 268 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' 269 | model.load_state_dict(model_zoo.load_url(url)) 270 | for param in model.parameters(): 271 | param.requires_grad = False 272 | print('Load pretrained model from ', url) 273 | # print(model) 274 | 275 | self.define_module(model) 276 | self.init_trainable_weights() 277 | 278 | def define_module(self, model): 279 | self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3 280 | self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3 281 | self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3 282 | self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1 283 | self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3 284 | self.Mixed_5b = model.Mixed_5b 285 | self.Mixed_5c = model.Mixed_5c 286 | self.Mixed_5d = model.Mixed_5d 287 | self.Mixed_6a = model.Mixed_6a 288 | self.Mixed_6b = model.Mixed_6b 289 | self.Mixed_6c = model.Mixed_6c 290 | self.Mixed_6d = model.Mixed_6d 291 | self.Mixed_6e = model.Mixed_6e 292 | self.Mixed_7a = model.Mixed_7a 293 | self.Mixed_7b = model.Mixed_7b 294 | self.Mixed_7c = model.Mixed_7c 295 | 296 | self.emb_features = conv1x1(768, self.nef) 297 | self.emb_cnn_code = nn.Linear(2048, self.nef) 298 | 299 | def init_trainable_weights(self): 300 | initrange = 0.1 301 | self.emb_features.weight.data.uniform_(-initrange, initrange) 302 | self.emb_cnn_code.weight.data.uniform_(-initrange, initrange) 303 | 304 | def forward(self, x): 305 | features = None 306 | # --> fixed-size input: batch x 3 x 299 x 299 307 | x = nn.Upsample(size=(299, 299), mode='bilinear')(x) 308 | # 299 x 299 x 3 309 | x = self.Conv2d_1a_3x3(x) 
310 | # 149 x 149 x 32 311 | x = self.Conv2d_2a_3x3(x) 312 | # 147 x 147 x 32 313 | x = self.Conv2d_2b_3x3(x) 314 | # 147 x 147 x 64 315 | x = F.max_pool2d(x, kernel_size=3, stride=2) 316 | # 73 x 73 x 64 317 | x = self.Conv2d_3b_1x1(x) 318 | # 73 x 73 x 80 319 | x = self.Conv2d_4a_3x3(x) 320 | # 71 x 71 x 192 321 | 322 | x = F.max_pool2d(x, kernel_size=3, stride=2) 323 | # 35 x 35 x 192 324 | x = self.Mixed_5b(x) 325 | # 35 x 35 x 256 326 | x = self.Mixed_5c(x) 327 | # 35 x 35 x 288 328 | x = self.Mixed_5d(x) 329 | # 35 x 35 x 288 330 | 331 | x = self.Mixed_6a(x) 332 | # 17 x 17 x 768 333 | x = self.Mixed_6b(x) 334 | # 17 x 17 x 768 335 | x = self.Mixed_6c(x) 336 | # 17 x 17 x 768 337 | x = self.Mixed_6d(x) 338 | # 17 x 17 x 768 339 | x = self.Mixed_6e(x) 340 | # 17 x 17 x 768 341 | 342 | # image region features 343 | features = x 344 | # 17 x 17 x 768 345 | 346 | x = self.Mixed_7a(x) 347 | # 8 x 8 x 1280 348 | x = self.Mixed_7b(x) 349 | # 8 x 8 x 2048 350 | x = self.Mixed_7c(x) 351 | # 8 x 8 x 2048 352 | x = F.avg_pool2d(x, kernel_size=8) 353 | # 1 x 1 x 2048 354 | # x = F.dropout(x, training=self.training) 355 | # 1 x 1 x 2048 356 | x = x.view(x.size(0), -1) 357 | # 2048 358 | 359 | # global image features 360 | cnn_code = self.emb_cnn_code(x) 361 | # 512 362 | if features is not None: 363 | features = self.emb_features(features) 364 | return features, cnn_code 365 | 366 | 367 | class INIT_STAGE_G(nn.Module): 368 | def __init__(self, ngf, ncf): 369 | super(INIT_STAGE_G, self).__init__() 370 | self.gf_dim = ngf 371 | self.in_dim = cfg.GAN.Z_DIM + ncf # cfg.TEXT.EMBEDDING_DIM 372 | 373 | self.define_module() 374 | 375 | def define_module(self): 376 | nz, ngf = self.in_dim, self.gf_dim 377 | self.fc = nn.Sequential( 378 | nn.Linear(nz, ngf * 4 * 4 * 2, bias=False), 379 | # removing for single instance caption 380 | nn.BatchNorm1d(ngf * 4 * 4 * 2), 381 | GLU()) 382 | 383 | self.upsample1 = upBlock(ngf, ngf // 2) 384 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 385 | self.upsample3 = upBlock(ngf // 4, ngf // 8) 386 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 387 | 388 | def forward(self, z_code, c_code): 389 | """ 390 | :param z_code: batch x cfg.GAN.Z_DIM 391 | :param c_code: batch x cfg.TEXT.EMBEDDING_DIM 392 | :return: batch x ngf/16 x 64 x 64 393 | """ 394 | c_z_code = torch.cat((c_code, z_code), 1) 395 | # state size ngf x 4 x 4 396 | out_code = self.fc(c_z_code) 397 | out_code = out_code.view(-1, self.gf_dim, 4, 4) 398 | # state size ngf/3 x 8 x 8 399 | out_code = self.upsample1(out_code) 400 | # state size ngf/4 x 16 x 16 401 | out_code = self.upsample2(out_code) 402 | # state size ngf/8 x 32 x 32 403 | out_code32 = self.upsample3(out_code) 404 | # state size ngf/16 x 64 x 64 405 | out_code64 = self.upsample4(out_code32) 406 | 407 | return out_code64 408 | 409 | 410 | class NEXT_STAGE_G(nn.Module): 411 | def __init__(self, ngf, nef, ncf): 412 | super(NEXT_STAGE_G, self).__init__() 413 | self.gf_dim = ngf 414 | self.ef_dim = nef 415 | self.cf_dim = ncf 416 | self.num_residual = cfg.GAN.R_NUM 417 | self.define_module() 418 | 419 | def _make_layer(self, block, channel_num): 420 | layers = [] 421 | for i in range(cfg.GAN.R_NUM): 422 | layers.append(block(channel_num)) 423 | return nn.Sequential(*layers) 424 | 425 | def define_module(self): 426 | ngf = self.gf_dim 427 | self.att = ATT_NET(ngf, self.ef_dim) 428 | self.residual = self._make_layer(ResBlock, ngf * 2) 429 | self.upsample = upBlock(ngf * 2, ngf) 430 | 431 | def forward(self, h_code, c_code, word_embs, mask): 432 | """ 433 
| h_code1(query): batch x idf x ih x iw (queryL=ihxiw) 434 | word_embs(context): batch x cdf x sourceL (sourceL=seq_len) 435 | c_code1: batch x idf x queryL 436 | att1: batch x sourceL x queryL 437 | """ 438 | self.att.applyMask(mask) 439 | c_code, att = self.att(h_code, word_embs) 440 | h_c_code = torch.cat((h_code, c_code), 1) 441 | out_code = self.residual(h_c_code) 442 | 443 | # state size ngf/2 x 2in_size x 2in_size 444 | out_code = self.upsample(out_code) 445 | 446 | return out_code, att 447 | 448 | 449 | class GET_IMAGE_G(nn.Module): 450 | def __init__(self, ngf): 451 | super(GET_IMAGE_G, self).__init__() 452 | self.gf_dim = ngf 453 | self.img = nn.Sequential( 454 | conv3x3(ngf, 3), 455 | nn.Tanh() 456 | ) 457 | 458 | def forward(self, h_code): 459 | out_img = self.img(h_code) 460 | return out_img 461 | 462 | 463 | class G_DCGAN(nn.Module): 464 | def __init__(self): 465 | super(G_DCGAN, self).__init__() 466 | ngf = cfg.GAN.GF_DIM 467 | nef = cfg.TEXT.EMBEDDING_DIM 468 | ncf = cfg.GAN.CONDITION_DIM 469 | self.ca_net = CA_NET() 470 | 471 | # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64 472 | if cfg.TREE.BRANCH_NUM > 0: 473 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 474 | # gf x 64 x 64 475 | if cfg.TREE.BRANCH_NUM > 1: 476 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) 477 | if cfg.TREE.BRANCH_NUM > 2: 478 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) 479 | self.img_net = GET_IMAGE_G(ngf) 480 | 481 | def forward(self, z_code, sent_emb, word_embs, mask): 482 | """ 483 | :param z_code: batch x cfg.GAN.Z_DIM 484 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 485 | :param word_embs: batch x cdf x seq_len 486 | :param mask: batch x seq_len 487 | :return: 488 | """ 489 | att_maps = [] 490 | c_code, mu, logvar = self.ca_net(sent_emb) 491 | if cfg.TREE.BRANCH_NUM > 0: 492 | h_code = self.h_net1(z_code, c_code) 493 | if cfg.TREE.BRANCH_NUM > 1: 494 | h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask) 495 | if att1 is not None: 496 | att_maps.append(att1) 497 | if cfg.TREE.BRANCH_NUM > 2: 498 | h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask) 499 | if att2 is not None: 500 | att_maps.append(att2) 501 | 502 | fake_imgs = self.img_net(h_code) 503 | return [fake_imgs], att_maps, mu, logvar 504 | 505 | 506 | # ############## D networks ########################## 507 | def Block3x3_leakRelu(in_planes, out_planes): 508 | block = nn.Sequential( 509 | conv3x3(in_planes, out_planes), 510 | nn.BatchNorm2d(out_planes), 511 | nn.LeakyReLU(0.2, inplace=True) 512 | ) 513 | return block 514 | 515 | 516 | # Downsale the spatial size by a factor of 2 517 | def downBlock(in_planes, out_planes): 518 | block = nn.Sequential( 519 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False), 520 | nn.BatchNorm2d(out_planes), 521 | nn.LeakyReLU(0.2, inplace=True) 522 | ) 523 | return block 524 | 525 | 526 | # Downsale the spatial size by a factor of 16 527 | def encode_image_by_16times(ndf): 528 | encode_img = nn.Sequential( 529 | # --> state size. 
ndf x in_size/2 x in_size/2 530 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 531 | nn.LeakyReLU(0.2, inplace=True), 532 | # --> state size 2ndf x x in_size/4 x in_size/4 533 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 534 | nn.BatchNorm2d(ndf * 2), 535 | nn.LeakyReLU(0.2, inplace=True), 536 | # --> state size 4ndf x in_size/8 x in_size/8 537 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 538 | nn.BatchNorm2d(ndf * 4), 539 | nn.LeakyReLU(0.2, inplace=True), 540 | # --> state size 8ndf x in_size/16 x in_size/16 541 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 542 | nn.BatchNorm2d(ndf * 8), 543 | nn.LeakyReLU(0.2, inplace=True) 544 | ) 545 | return encode_img 546 | 547 | 548 | class D_GET_LOGITS(nn.Module): 549 | def __init__(self, ndf, nef, bcondition=False): 550 | super(D_GET_LOGITS, self).__init__() 551 | self.df_dim = ndf 552 | self.ef_dim = nef 553 | self.bcondition = bcondition 554 | if self.bcondition: 555 | self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8) 556 | 557 | self.outlogits = nn.Sequential( 558 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 559 | nn.Sigmoid()) 560 | 561 | def forward(self, h_code, c_code=None): 562 | if self.bcondition and c_code is not None: 563 | # conditioning output 564 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 565 | c_code = c_code.repeat(1, 1, 4, 4) 566 | # state size (ngf+egf) x 4 x 4 567 | h_c_code = torch.cat((h_code, c_code), 1) 568 | # state size ngf x in_size x in_size 569 | h_c_code = self.jointConv(h_c_code) 570 | else: 571 | h_c_code = h_code 572 | 573 | output = self.outlogits(h_c_code) 574 | return output.view(-1) 575 | 576 | 577 | # For 64 x 64 images 578 | class D_NET64(nn.Module): 579 | def __init__(self, b_jcu=True): 580 | super(D_NET64, self).__init__() 581 | ndf = cfg.GAN.DF_DIM 582 | nef = cfg.TEXT.EMBEDDING_DIM 583 | self.img_code_s16 = encode_image_by_16times(ndf) 584 | if b_jcu: 585 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 586 | else: 587 | self.UNCOND_DNET = None 588 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 589 | 590 | def forward(self, x_var): 591 | x_code4 = self.img_code_s16(x_var) # 4 x 4 x 8df 592 | return x_code4 593 | 594 | 595 | # For 128 x 128 images 596 | class D_NET128(nn.Module): 597 | def __init__(self, b_jcu=True): 598 | super(D_NET128, self).__init__() 599 | ndf = cfg.GAN.DF_DIM 600 | nef = cfg.TEXT.EMBEDDING_DIM 601 | self.img_code_s16 = encode_image_by_16times(ndf) 602 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 603 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) 604 | # 605 | if b_jcu: 606 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 607 | else: 608 | self.UNCOND_DNET = None 609 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 610 | 611 | def forward(self, x_var): 612 | x_code8 = self.img_code_s16(x_var) # 8 x 8 x 8df 613 | x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df 614 | x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df 615 | return x_code4 616 | 617 | 618 | # For 256 x 256 images 619 | class D_NET256(nn.Module): 620 | def __init__(self, b_jcu=True): 621 | super(D_NET256, self).__init__() 622 | ndf = cfg.GAN.DF_DIM 623 | nef = cfg.TEXT.EMBEDDING_DIM 624 | self.img_code_s16 = encode_image_by_16times(ndf) 625 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 626 | self.img_code_s64 = downBlock(ndf * 16, ndf * 32) 627 | self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16) 628 | self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8) 629 | if b_jcu: 630 | self.UNCOND_DNET = 
D_GET_LOGITS(ndf, nef, bcondition=False) 631 | else: 632 | self.UNCOND_DNET = None 633 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 634 | 635 | def forward(self, x_var): 636 | x_code16 = self.img_code_s16(x_var) 637 | x_code8 = self.img_code_s32(x_code16) 638 | x_code4 = self.img_code_s64(x_code8) 639 | x_code4 = self.img_code_s64_1(x_code4) 640 | x_code4 = self.img_code_s64_2(x_code4) 641 | return x_code4 642 | -------------------------------------------------------------------------------- /eval/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask 2 | python-dateutil 3 | easydict 4 | scikit-image 5 | azure-storage-blob 6 | applicationinsights 7 | libmc -------------------------------------------------------------------------------- /example_bird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/example_bird.png -------------------------------------------------------------------------------- /example_coco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/example_coco.png -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/framework.png -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !README.md 3 | !.gitignore --------------------------------------------------------------------------------
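Usage sketch for the containerized evaluation API defined in eval/main.py: the Flask app listens on port 8080 (see EXPOSE 8080 and CMD ["python", "main.py"] in the dockerfile) and serves POST /api/v1.0/bird, which expects a JSON body with a "caption" field and returns blob-storage URLs for the generated images and attention maps. The client below is a minimal, hypothetical example, not part of the repository: it assumes the container is reachable at http://localhost:8080 and that the requests package is installed on the client side (it is not listed in eval/requirements.txt).

# Minimal client sketch (assumptions: local container on port 8080, `requests` installed client-side).
import requests

API_URL = "http://localhost:8080/api/v1.0/bird"  # hypothetical local deployment of the eval container

payload = {"caption": "the bird has a yellow crown and a black eyering that is round"}
resp = requests.post(API_URL, json=payload, timeout=120)
resp.raise_for_status()

# main.py wraps the result as {"bird": {...}} with keys small/medium/large/map1/map2/caption/elapsed.
bird = resp.json()["bird"]
print("caption:", bird["caption"])
print("elapsed:", bird["elapsed"], "seconds")
print("images :", bird["small"], bird["medium"], bird["large"])  # outputs of the three generator stages
print("attmaps:", bird["map1"], bird["map2"])                    # attention-map visualizations

POST /api/v1.0/birds accepts the same payload and instead returns six image sets keyed 'bird1' through 'bird6', each with its own small/medium/large URLs.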