├── .gitignore
├── DAMSMencoders
│   └── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── .gitignore
│   ├── GlobalAttention.py
│   ├── cfg
│   │   ├── DAMSM
│   │   │   ├── bird.yml
│   │   │   └── coco.yml
│   │   ├── bird_attn2.yml
│   │   ├── bird_attnDCGAN2.yml
│   │   ├── coco_attn2.yml
│   │   ├── eval_bird.yml
│   │   ├── eval_bird_attnDCGAN2.yml
│   │   └── eval_coco.yml
│   ├── datasets.py
│   ├── main.py
│   ├── miscc
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── losses.py
│   │   └── utils.py
│   ├── model.py
│   ├── pretrain_DAMSM.py
│   └── trainer.py
├── data
│   └── .gitignore
├── eval
│   ├── FreeMono.ttf
│   ├── GlobalAttention.py
│   ├── README.md
│   ├── __init__.py
│   ├── data
│   │   ├── bird_AttnGAN2.pth
│   │   ├── captions.pickle
│   │   └── text_encoder200.pth
│   ├── dockerfile.cpu
│   ├── dockerfile.gpu
│   ├── eval.py
│   ├── main.py
│   ├── miscc
│   │   ├── __init__.py
│   │   ├── config.py
│   │   └── utils.py
│   ├── model.py
│   └── requirements.txt
├── example_bird.png
├── example_coco.png
├── framework.png
└── models
    └── .gitignore
/.gitignore:
--------------------------------------------------------------------------------
1 | backup
2 | output
3 | code/*.pyc
4 | code/miscc/*.pyc
5 | .DS_Store
6 | .idea/
7 |
--------------------------------------------------------------------------------
/DAMSMencoders/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tao Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AttnGAN
2 |
3 | PyTorch implementation for reproducing AttnGAN results in the paper [AttnGAN: Fine-Grained Text to Image Generation
4 | with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf) by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research).
5 |
6 |
7 |
8 |
9 | ### Dependencies
10 | Python 2.7
11 |
12 | PyTorch
13 |
14 | In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
15 | - `python-dateutil`
16 | - `easydict`
17 | - `pandas`
18 | - `torchfile`
19 | - `nltk`
20 | - `scikit-image`
21 |
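For example, from a shell (a minimal sketch; replace `/path/to/AttnGAN` with the actual location of this repository):

```
export PYTHONPATH=/path/to/AttnGAN:$PYTHONPATH
pip install python-dateutil easydict pandas torchfile nltk scikit-image
```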
22 |
23 |
24 | **Data**
25 |
26 | 1. Download our preprocessed metadata for [birds](https://drive.google.com/open?id=1O_LtUP9sch09QH3s_EBAgLEctBQ5JBSJ) and [coco](https://drive.google.com/open?id=1rSnbIGNDGZeHlsUlLdahj0RJ9oo6lgH9), and save them to `data/`
27 | 2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data. Extract them to `data/birds/`
28 | 3. Download [coco](http://cocodataset.org/#download) dataset and extract the images to `data/coco/`
29 |
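After these steps, `data/` should look roughly as sketched below; the paths are the ones read by `code/datasets.py` (birds shown, coco is laid out analogously with `data/coco/images/`):

```
data/birds/
├── CUB_200_2011/            # extracted CUB images plus images.txt, bounding_boxes.txt, ...
├── text/                    # per-image caption .txt files from the preprocessed metadata
├── train/filenames.pickle
├── test/filenames.pickle
├── example_filenames.txt
└── example_captions.txt
```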
30 |
31 |
32 | **Training**
33 | - Pre-train DAMSM models:
34 | - For bird dataset: `python pretrain_DAMSM.py --cfg cfg/DAMSM/bird.yml --gpu 0`
35 | - For coco dataset: `python pretrain_DAMSM.py --cfg cfg/DAMSM/coco.yml --gpu 1`
36 |
37 | - Train AttnGAN models:
38 | - For bird dataset: `python main.py --cfg cfg/bird_attn2.yml --gpu 2`
39 | - For coco dataset: `python main.py --cfg cfg/coco_attn2.yml --gpu 3`
40 |
41 | - `*.yml` files are example configuration files for training/evaluating our models.
42 |
43 |
44 |
45 | **Pretrained Model**
46 | - [DAMSM for bird](https://drive.google.com/open?id=1GNUKjVeyWYBJ8hEU-yrfYQpDOkxEyP3V). Download and save it to `DAMSMencoders/`
47 | - [DAMSM for coco](https://drive.google.com/open?id=1zIrXCE9F6yfbEJIbNP5-YrEe2pZcPSGJ). Download and save it to `DAMSMencoders/`
48 | - [AttnGAN for bird](https://drive.google.com/open?id=1lqNG75suOuR_8gjoEPYNp8VyT_ufPPig). Download and save it to `models/`
49 | - [AttnGAN for coco](https://drive.google.com/open?id=1i9Xkg9nU74RAvkcqKE-rJYhjvzKAMnCi). Download and save it to `models/`
50 |
51 | - [AttnDCGAN for bird](https://drive.google.com/open?id=19TG0JUoXurxsmZLaJ82Yo6O0UJ6aDBpg). Download and save it to `models/`
52 |   - This is a variant of AttnGAN which applies the proposed attention mechanisms to the DCGAN framework.
53 |
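The configs under `code/cfg/` expect these downloads at fixed relative paths; the `NET_E`/`NET_G` entries in the yml files point to:

```
DAMSMencoders/bird/text_encoder200.pth
DAMSMencoders/coco/text_encoder100.pth
models/bird_AttnGAN2.pth
models/bird_AttnDCGAN2.pth
models/coco_AttnGAN2.pth
```
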
54 | **Sampling**
55 | - Run `python main.py --cfg cfg/eval_bird.yml --gpu 1` to generate examples from captions in files listed in "./data/birds/example_filenames.txt". Results are saved to `DAMSMencoders/`.
56 | - Change the `eval_*.yml` files to generate images from other pre-trained models.
57 | - Input your own sentences in "./data/birds/example_captions.txt" if you want to generate images from customized sentences.
58 |
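A minimal end-to-end sketch for custom captions, run from the repository root (assuming `example_captions` is already listed in `./data/birds/example_filenames.txt`, which is how `gen_example()` in `code/main.py` finds the caption file):

```
echo "this bird has a red head and a short yellow beak" > data/birds/example_captions.txt
cd code
python main.py --cfg cfg/eval_bird.yml --gpu 1
```
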
59 | **Validation**
60 | - To generate images for all captions in the validation dataset, change `B_VALIDATION` to `True` in the `eval_*.yml` file and then run `python main.py --cfg cfg/eval_bird.yml --gpu 1`
61 | - We compute inception score for models trained on birds using [StackGAN-inception-model](https://github.com/hanzhanggit/StackGAN-inception-model).
62 | - We compute inception score for models trained on coco using [improved-gan/inception_score](https://github.com/openai/improved-gan/tree/master/inception_score).
63 |
64 |
65 | **Examples generated by AttnGAN [[Blog]](https://blogs.microsoft.com/ai/drawing-ai/)**
66 |
67 | bird example | coco example
68 | :-------------------------:|:-------------------------:
69 | ![bird example generated by AttnGAN](example_bird.png) | ![coco example generated by AttnGAN](example_coco.png)
70 |
71 |
72 | ### Creating an API
73 | [Evaluation code](eval) embedded into a callable containerized API is included in the `eval/` folder.
74 |
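The folder ships `dockerfile.cpu` and `dockerfile.gpu` plus its own `requirements.txt`; a hypothetical CPU build would look like the sketch below, but see `eval/README.md` for the exact build and run steps (the image tag here is illustrative only):

```
cd eval
docker build -f dockerfile.cpu -t attngan-eval .
```
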
75 | ### Citing AttnGAN
76 | If you find AttnGAN useful in your research, please consider citing:
77 |
78 | ```
79 | @article{Tao18attngan,
80 | author = {Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He},
81 | title = {AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks},
82 | Year = {2018},
83 | booktitle = {{CVPR}}
84 | }
85 | ```
86 |
87 | **Reference**
88 |
89 | - [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916) [[code]](https://github.com/hanzhanggit/StackGAN-v2)
90 | - [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) [[code]](https://github.com/carpedm20/DCGAN-tensorflow)
91 |
--------------------------------------------------------------------------------
/code/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | !.gitignore
--------------------------------------------------------------------------------
/code/GlobalAttention.py:
--------------------------------------------------------------------------------
1 | """
2 | Global attention takes a matrix and a query matrix.
3 | Based on each query vector q, it computes a parameterized convex combination of the
4 | matrix.
5 | H_1 H_2 H_3 ... H_n
6 | q q q q
7 | | | | |
8 | \ | | /
9 | .....
10 | \ | /
11 | a
12 | Constructs a unit mapping.
13 | $$(H_1 + H_n, q) => (a)$$
14 | Where H is of `batch x n x dim` and q is of `batch x dim`.
15 |
16 | References:
17 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules
18 | http://www.aclweb.org/anthology/D15-1166
19 | """
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 |
25 | def conv1x1(in_planes, out_planes):
26 | "1x1 convolution with padding"
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
28 | padding=0, bias=False)
29 |
30 |
31 | def func_attention(query, context, gamma1):
32 | """
33 | query: batch x ndf x queryL
34 | context: batch x ndf x ih x iw (sourceL=ihxiw)
35 | mask: batch_size x sourceL
36 | """
37 | batch_size, queryL = query.size(0), query.size(2)
38 | ih, iw = context.size(2), context.size(3)
39 | sourceL = ih * iw
40 |
41 | # --> batch x sourceL x ndf
42 | context = context.view(batch_size, -1, sourceL)
43 | contextT = torch.transpose(context, 1, 2).contiguous()
44 |
45 | # Get attention
46 | # (batch x sourceL x ndf)(batch x ndf x queryL)
47 | # -->batch x sourceL x queryL
48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper
49 | # --> batch*sourceL x queryL
50 | attn = attn.view(batch_size*sourceL, queryL)
51 | attn = nn.Softmax()(attn) # Eq. (8)
52 |
53 | # --> batch x sourceL x queryL
54 | attn = attn.view(batch_size, sourceL, queryL)
55 | # --> batch*queryL x sourceL
56 | attn = torch.transpose(attn, 1, 2).contiguous()
57 | attn = attn.view(batch_size*queryL, sourceL)
58 | # Eq. (9)
59 | attn = attn * gamma1
60 | attn = nn.Softmax()(attn)
61 | attn = attn.view(batch_size, queryL, sourceL)
62 | # --> batch x sourceL x queryL
63 | attnT = torch.transpose(attn, 1, 2).contiguous()
64 |
65 | # (batch x ndf x sourceL)(batch x sourceL x queryL)
66 | # --> batch x ndf x queryL
67 | weightedContext = torch.bmm(context, attnT)
68 |
69 | return weightedContext, attn.view(batch_size, -1, ih, iw)
70 |
71 |
72 | class GlobalAttentionGeneral(nn.Module):
73 | def __init__(self, idf, cdf):
74 | super(GlobalAttentionGeneral, self).__init__()
75 | self.conv_context = conv1x1(cdf, idf)
76 | self.sm = nn.Softmax()
77 | self.mask = None
78 |
79 | def applyMask(self, mask):
80 | self.mask = mask # batch x sourceL
81 |
82 | def forward(self, input, context):
83 | """
84 | input: batch x idf x ih x iw (queryL=ihxiw)
85 | context: batch x cdf x sourceL
86 | """
87 | ih, iw = input.size(2), input.size(3)
88 | queryL = ih * iw
89 | batch_size, sourceL = context.size(0), context.size(2)
90 |
91 | # --> batch x queryL x idf
92 | target = input.view(batch_size, -1, queryL)
93 | targetT = torch.transpose(target, 1, 2).contiguous()
94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1
95 | sourceT = context.unsqueeze(3)
96 | # --> batch x idf x sourceL
97 | sourceT = self.conv_context(sourceT).squeeze(3)
98 |
99 | # Get attention
100 | # (batch x queryL x idf)(batch x idf x sourceL)
101 | # -->batch x queryL x sourceL
102 | attn = torch.bmm(targetT, sourceT)
103 | # --> batch*queryL x sourceL
104 | attn = attn.view(batch_size*queryL, sourceL)
105 | if self.mask is not None:
106 | # batch_size x sourceL --> batch_size*queryL x sourceL
107 | mask = self.mask.repeat(queryL, 1)
108 | attn.data.masked_fill_(mask.data, -float('inf'))
109 | attn = self.sm(attn) # Eq. (2)
110 | # --> batch x queryL x sourceL
111 | attn = attn.view(batch_size, queryL, sourceL)
112 | # --> batch x sourceL x queryL
113 | attn = torch.transpose(attn, 1, 2).contiguous()
114 |
115 | # (batch x idf x sourceL)(batch x sourceL x queryL)
116 | # --> batch x idf x queryL
117 | weightedContext = torch.bmm(sourceT, attn)
118 | weightedContext = weightedContext.view(batch_size, -1, ih, iw)
119 | attn = attn.view(batch_size, -1, ih, iw)
120 |
121 | return weightedContext, attn
122 |
--------------------------------------------------------------------------------
/code/cfg/DAMSM/bird.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'DAMSM'
2 |
3 | DATASET_NAME: 'birds'
4 | DATA_DIR: '../data/birds'
5 | GPU_ID: 0
6 | WORKERS: 1
7 |
8 |
9 | TREE:
10 | BRANCH_NUM: 1
11 | BASE_SIZE: 299
12 |
13 |
14 | TRAIN:
15 | FLAG: True
16 | NET_E: '' # '../DAMSMencoders/bird/text_encoder200.pth'
17 | BATCH_SIZE: 48
18 | MAX_EPOCH: 600
19 | SNAPSHOT_INTERVAL: 50
20 | ENCODER_LR: 0.002 # 0.0002best; 0.002good; scott: 0.0007 with 0.98decay
21 | RNN_GRAD_CLIP: 0.25
22 | SMOOTH:
23 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad
24 | GAMMA2: 5.0
25 | GAMMA3: 10.0 # 10good 1&100bad
26 |
27 |
28 |
29 | TEXT:
30 | EMBEDDING_DIM: 256
31 | CAPTIONS_PER_IMAGE: 10
32 |
--------------------------------------------------------------------------------
/code/cfg/DAMSM/coco.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'DAMSM'
2 |
3 | DATASET_NAME: 'coco'
4 | DATA_DIR: '../data/coco'
5 | GPU_ID: 0
6 | WORKERS: 1
7 |
8 |
9 | TREE:
10 | BRANCH_NUM: 1
11 | BASE_SIZE: 299
12 |
13 |
14 | TRAIN:
15 | FLAG: True
16 | NET_E: '' # '../DAMSMencoders/coco/text_encoder100.pth'
17 | BATCH_SIZE: 48
18 | MAX_EPOCH: 600
19 | SNAPSHOT_INTERVAL: 5
20 | ENCODER_LR: 0.002 # 0.0002best; 0.002good
21 | RNN_GRAD_CLIP: 0.25
22 | SMOOTH:
23 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad
24 | GAMMA2: 5.0
25 | GAMMA3: 10.0 # 10good 1&100bad
26 |
27 |
28 | TEXT:
29 | EMBEDDING_DIM: 256
30 | CAPTIONS_PER_IMAGE: 5
31 | WORDS_NUM: 15
32 |
--------------------------------------------------------------------------------
/code/cfg/bird_attn2.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'attn2'
2 |
3 | DATASET_NAME: 'birds'
4 | DATA_DIR: '../data/birds'
5 | GPU_ID: 0
6 | WORKERS: 4
7 |
8 |
9 | TREE:
10 | BRANCH_NUM: 3
11 |
12 |
13 | TRAIN:
14 | FLAG: True
15 | NET_G: '' # '../models/bird_AttnGAN2.pth'
16 | B_NET_D: True
17 | BATCH_SIZE: 20 # 22
18 | MAX_EPOCH: 600
19 | SNAPSHOT_INTERVAL: 50
20 | DISCRIMINATOR_LR: 0.0002
21 | GENERATOR_LR: 0.0002
22 | #
23 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth'
24 | SMOOTH:
25 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad
26 | GAMMA2: 5.0
27 | GAMMA3: 10.0 # 10good 1&100bad
28 | LAMBDA: 5.0
29 |
30 |
31 | GAN:
32 | DF_DIM: 64
33 | GF_DIM: 32
34 | Z_DIM: 100
35 | R_NUM: 2
36 |
37 | TEXT:
38 | EMBEDDING_DIM: 256
39 | CAPTIONS_PER_IMAGE: 10
40 |
--------------------------------------------------------------------------------
/code/cfg/bird_attnDCGAN2.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'attn2-dcgan'
2 |
3 | DATASET_NAME: 'birds'
4 | DATA_DIR: '../data/birds'
5 | GPU_ID: 0
6 | WORKERS: 4
7 |
8 |
9 | TREE:
10 | BRANCH_NUM: 3
11 |
12 |
13 | TRAIN:
14 | FLAG: True
15 | NET_G: '' # '../models/bird_AttnDCGAN2.pth'
16 | B_NET_D: True
17 | BATCH_SIZE: 30
18 | MAX_EPOCH: 400
19 | SNAPSHOT_INTERVAL: 50
20 | DISCRIMINATOR_LR: 0.0002
21 | GENERATOR_LR: 0.0002
22 | #
23 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth'
24 | SMOOTH:
25 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad
26 | GAMMA2: 5.0
27 | GAMMA3: 10.0 # 10good 1&100bad
28 | LAMBDA: 1.0
29 |
30 |
31 | GAN:
32 | DF_DIM: 64
33 | GF_DIM: 32
34 | Z_DIM: 100
35 | R_NUM: 0
36 | B_DCGAN: True
37 |
38 | TEXT:
39 | EMBEDDING_DIM: 256
40 | CAPTIONS_PER_IMAGE: 10
41 |
--------------------------------------------------------------------------------
/code/cfg/coco_attn2.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'glu-gan2'
2 |
3 | DATASET_NAME: 'coco'
4 | DATA_DIR: '../data/coco'
5 | GPU_ID: 0
6 | WORKERS: 4
7 |
8 |
9 | TREE:
10 | BRANCH_NUM: 3
11 |
12 |
13 | TRAIN:
14 | FLAG: True
15 | NET_G: '' # '../models/coco_AttnGAN2.pth'
16 | B_NET_D: True
17 | BATCH_SIZE: 14 # 32
18 | MAX_EPOCH: 120
19 | SNAPSHOT_INTERVAL: 5
20 | DISCRIMINATOR_LR: 0.0002
21 | GENERATOR_LR: 0.0002
22 | #
23 | NET_E: '../DAMSMencoders/coco/text_encoder100.pth'
24 | SMOOTH:
25 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad
26 | GAMMA2: 5.0
27 | GAMMA3: 10.0 # 10good 1&100bad
28 | LAMBDA: 50.0
29 |
30 |
31 | GAN:
32 | DF_DIM: 96
33 | GF_DIM: 48
34 | Z_DIM: 100
35 | R_NUM: 3
36 |
37 | TEXT:
38 | EMBEDDING_DIM: 256
39 | CAPTIONS_PER_IMAGE: 5
40 | WORDS_NUM: 12
41 |
--------------------------------------------------------------------------------
/code/cfg/eval_bird.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'attn2'
2 |
3 | DATASET_NAME: 'birds'
4 | DATA_DIR: '../data/birds'
5 | GPU_ID: 3
6 | WORKERS: 1
7 |
8 | B_VALIDATION: False # True # False
9 | TREE:
10 | BRANCH_NUM: 3
11 |
12 |
13 | TRAIN:
14 | FLAG: False
15 | NET_G: '../models/bird_AttnGAN2.pth'
16 | B_NET_D: False
17 | BATCH_SIZE: 100
18 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth'
19 |
20 |
21 | GAN:
22 | DF_DIM: 64
23 | GF_DIM: 32
24 | Z_DIM: 100
25 | R_NUM: 2
26 |
27 | TEXT:
28 | EMBEDDING_DIM: 256
29 | CAPTIONS_PER_IMAGE: 10
30 | WORDS_NUM: 25
31 |
--------------------------------------------------------------------------------
/code/cfg/eval_bird_attnDCGAN2.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'attn2-dcgan'
2 |
3 | DATASET_NAME: 'birds'
4 | DATA_DIR: '../data/birds'
5 | GPU_ID: 3
6 | WORKERS: 1
7 |
8 | B_VALIDATION: False
9 | TREE:
10 | BRANCH_NUM: 3
11 |
12 |
13 | TRAIN:
14 | FLAG: False
15 | NET_G: '../models/bird_AttnDCGAN2.pth'
16 | B_NET_D: False
17 | BATCH_SIZE: 100
18 | NET_E: '../DAMSMencoders/bird/text_encoder200.pth'
19 |
20 |
21 | GAN:
22 | DF_DIM: 64
23 | GF_DIM: 32
24 | Z_DIM: 100
25 | R_NUM: 0
26 | B_DCGAN: True
27 |
28 | TEXT:
29 | EMBEDDING_DIM: 256
30 | CAPTIONS_PER_IMAGE: 10
31 | WORDS_NUM: 25
32 |
--------------------------------------------------------------------------------
/code/cfg/eval_coco.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'attn2'
2 |
3 | DATASET_NAME: 'coco'
4 | DATA_DIR: '../data/coco'
5 | GPU_ID: 3
6 | WORKERS: 1
7 |
8 | B_VALIDATION: False
9 | TREE:
10 | BRANCH_NUM: 3
11 |
12 |
13 | TRAIN:
14 | FLAG: False
15 | NET_G: '../models/coco_AttnGAN2.pth'
16 | B_NET_D: False
17 | BATCH_SIZE: 100
18 | NET_E: '../DAMSMencoders/coco/text_encoder100.pth'
19 |
20 |
21 | GAN:
22 | DF_DIM: 96
23 | GF_DIM: 48
24 | Z_DIM: 100
25 | R_NUM: 3
26 |
27 | TEXT:
28 | EMBEDDING_DIM: 256
29 | CAPTIONS_PER_IMAGE: 5
30 | WORDS_NUM: 20
31 |
--------------------------------------------------------------------------------
/code/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 |
7 | from nltk.tokenize import RegexpTokenizer
8 | from collections import defaultdict
9 | from miscc.config import cfg
10 |
11 | import torch
12 | import torch.utils.data as data
13 | from torch.autograd import Variable
14 | import torchvision.transforms as transforms
15 |
16 | import os
17 | import sys
18 | import numpy as np
19 | import pandas as pd
20 | from PIL import Image
21 | import numpy.random as random
22 | if sys.version_info[0] == 2:
23 | import cPickle as pickle
24 | else:
25 | import pickle
26 |
27 |
28 | def prepare_data(data):
29 | imgs, captions, captions_lens, class_ids, keys = data
30 |
31 | # sort data by the length in a decreasing order
32 | sorted_cap_lens, sorted_cap_indices = \
33 | torch.sort(captions_lens, 0, True)
34 |
35 | real_imgs = []
36 | for i in range(len(imgs)):
37 | imgs[i] = imgs[i][sorted_cap_indices]
38 | if cfg.CUDA:
39 | real_imgs.append(Variable(imgs[i]).cuda())
40 | else:
41 | real_imgs.append(Variable(imgs[i]))
42 |
43 | captions = captions[sorted_cap_indices].squeeze()
44 | class_ids = class_ids[sorted_cap_indices].numpy()
45 | # sent_indices = sent_indices[sorted_cap_indices]
46 | keys = [keys[i] for i in sorted_cap_indices.numpy()]
47 | # print('keys', type(keys), keys[-1]) # list
48 | if cfg.CUDA:
49 | captions = Variable(captions).cuda()
50 | sorted_cap_lens = Variable(sorted_cap_lens).cuda()
51 | else:
52 | captions = Variable(captions)
53 | sorted_cap_lens = Variable(sorted_cap_lens)
54 |
55 | return [real_imgs, captions, sorted_cap_lens,
56 | class_ids, keys]
57 |
58 |
59 | def get_imgs(img_path, imsize, bbox=None,
60 | transform=None, normalize=None):
61 | img = Image.open(img_path).convert('RGB')
62 | width, height = img.size
63 | if bbox is not None:
64 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
65 | center_x = int((2 * bbox[0] + bbox[2]) / 2)
66 | center_y = int((2 * bbox[1] + bbox[3]) / 2)
67 | y1 = np.maximum(0, center_y - r)
68 | y2 = np.minimum(height, center_y + r)
69 | x1 = np.maximum(0, center_x - r)
70 | x2 = np.minimum(width, center_x + r)
71 | img = img.crop([x1, y1, x2, y2])
72 |
73 | if transform is not None:
74 | img = transform(img)
75 |
76 | ret = []
77 | if cfg.GAN.B_DCGAN:
78 | ret = [normalize(img)]
79 | else:
80 | for i in range(cfg.TREE.BRANCH_NUM):
81 | # print(imsize[i])
82 | if i < (cfg.TREE.BRANCH_NUM - 1):
83 | re_img = transforms.Scale(imsize[i])(img)
84 | else:
85 | re_img = img
86 | ret.append(normalize(re_img))
87 |
88 | return ret
89 |
90 |
91 | class TextDataset(data.Dataset):
92 | def __init__(self, data_dir, split='train',
93 | base_size=64,
94 | transform=None, target_transform=None):
95 | self.transform = transform
96 | self.norm = transforms.Compose([
97 | transforms.ToTensor(),
98 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
99 | self.target_transform = target_transform
100 | self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE
101 |
102 | self.imsize = []
103 | for i in range(cfg.TREE.BRANCH_NUM):
104 | self.imsize.append(base_size)
105 | base_size = base_size * 2
106 |
107 | self.data = []
108 | self.data_dir = data_dir
109 | if data_dir.find('birds') != -1:
110 | self.bbox = self.load_bbox()
111 | else:
112 | self.bbox = None
113 | split_dir = os.path.join(data_dir, split)
114 |
115 | self.filenames, self.captions, self.ixtoword, \
116 | self.wordtoix, self.n_words = self.load_text_data(data_dir, split)
117 |
118 | self.class_id = self.load_class_id(split_dir, len(self.filenames))
119 | self.number_example = len(self.filenames)
120 |
121 | def load_bbox(self):
122 | data_dir = self.data_dir
123 | bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
124 | df_bounding_boxes = pd.read_csv(bbox_path,
125 | delim_whitespace=True,
126 | header=None).astype(int)
127 | #
128 | filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
129 | df_filenames = \
130 | pd.read_csv(filepath, delim_whitespace=True, header=None)
131 | filenames = df_filenames[1].tolist()
132 | print('Total filenames: ', len(filenames), filenames[0])
133 | #
134 | filename_bbox = {img_file[:-4]: [] for img_file in filenames}
135 | numImgs = len(filenames)
136 | for i in xrange(0, numImgs):
137 | # bbox = [x-left, y-top, width, height]
138 | bbox = df_bounding_boxes.iloc[i][1:].tolist()
139 |
140 | key = filenames[i][:-4]
141 | filename_bbox[key] = bbox
142 | #
143 | return filename_bbox
144 |
145 | def load_captions(self, data_dir, filenames):
146 | all_captions = []
147 | for i in range(len(filenames)):
148 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])
149 | with open(cap_path, "r") as f:
150 | captions = f.read().decode('utf8').split('\n')
151 | cnt = 0
152 | for cap in captions:
153 | if len(cap) == 0:
154 | continue
155 | cap = cap.replace("\ufffd\ufffd", " ")
156 | # picks out sequences of alphanumeric characters as tokens
157 | # and drops everything else
158 | tokenizer = RegexpTokenizer(r'\w+')
159 | tokens = tokenizer.tokenize(cap.lower())
160 | # print('tokens', tokens)
161 | if len(tokens) == 0:
162 | print('cap', cap)
163 | continue
164 |
165 | tokens_new = []
166 | for t in tokens:
167 | t = t.encode('ascii', 'ignore').decode('ascii')
168 | if len(t) > 0:
169 | tokens_new.append(t)
170 | all_captions.append(tokens_new)
171 | cnt += 1
172 | if cnt == self.embeddings_num:
173 | break
174 | if cnt < self.embeddings_num:
175 | print('ERROR: the captions for %s less than %d'
176 | % (filenames[i], cnt))
177 | return all_captions
178 |
179 | def build_dictionary(self, train_captions, test_captions):
180 | word_counts = defaultdict(float)
181 | captions = train_captions + test_captions
182 | for sent in captions:
183 | for word in sent:
184 | word_counts[word] += 1
185 |
186 | vocab = [w for w in word_counts if word_counts[w] >= 0]
187 |
188 | ixtoword = {}
189 | ixtoword[0] = ''
190 | wordtoix = {}
191 | wordtoix[''] = 0
192 | ix = 1
193 | for w in vocab:
194 | wordtoix[w] = ix
195 | ixtoword[ix] = w
196 | ix += 1
197 |
198 | train_captions_new = []
199 | for t in train_captions:
200 | rev = []
201 | for w in t:
202 | if w in wordtoix:
203 | rev.append(wordtoix[w])
204 | # rev.append(0) # do not need '' token
205 | train_captions_new.append(rev)
206 |
207 | test_captions_new = []
208 | for t in test_captions:
209 | rev = []
210 | for w in t:
211 | if w in wordtoix:
212 | rev.append(wordtoix[w])
213 | # rev.append(0) # do not need '' token
214 | test_captions_new.append(rev)
215 |
216 | return [train_captions_new, test_captions_new,
217 | ixtoword, wordtoix, len(ixtoword)]
218 |
219 | def load_text_data(self, data_dir, split):
220 | filepath = os.path.join(data_dir, 'captions.pickle')
221 | train_names = self.load_filenames(data_dir, 'train')
222 | test_names = self.load_filenames(data_dir, 'test')
223 | if not os.path.isfile(filepath):
224 | train_captions = self.load_captions(data_dir, train_names)
225 | test_captions = self.load_captions(data_dir, test_names)
226 |
227 | train_captions, test_captions, ixtoword, wordtoix, n_words = \
228 | self.build_dictionary(train_captions, test_captions)
229 | with open(filepath, 'wb') as f:
230 | pickle.dump([train_captions, test_captions,
231 | ixtoword, wordtoix], f, protocol=2)
232 | print('Save to: ', filepath)
233 | else:
234 | with open(filepath, 'rb') as f:
235 | x = pickle.load(f)
236 | train_captions, test_captions = x[0], x[1]
237 | ixtoword, wordtoix = x[2], x[3]
238 | del x
239 | n_words = len(ixtoword)
240 | print('Load from: ', filepath)
241 | if split == 'train':
242 | # a list of list: each list contains
243 | # the indices of words in a sentence
244 | captions = train_captions
245 | filenames = train_names
246 | else: # split=='test'
247 | captions = test_captions
248 | filenames = test_names
249 | return filenames, captions, ixtoword, wordtoix, n_words
250 |
251 | def load_class_id(self, data_dir, total_num):
252 | if os.path.isfile(data_dir + '/class_info.pickle'):
253 | with open(data_dir + '/class_info.pickle', 'rb') as f:
254 | class_id = pickle.load(f)
255 | else:
256 | class_id = np.arange(total_num)
257 | return class_id
258 |
259 | def load_filenames(self, data_dir, split):
260 | filepath = '%s/%s/filenames.pickle' % (data_dir, split)
261 | if os.path.isfile(filepath):
262 | with open(filepath, 'rb') as f:
263 | filenames = pickle.load(f)
264 | print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
265 | else:
266 | filenames = []
267 | return filenames
268 |
269 | def get_caption(self, sent_ix):
270 | # a list of indices for a sentence
271 | sent_caption = np.asarray(self.captions[sent_ix]).astype('int64')
272 | if (sent_caption == 0).sum() > 0:
273 | print('ERROR: do not need END (0) token', sent_caption)
274 | num_words = len(sent_caption)
275 | # pad with 0s (i.e., '')
276 | x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64')
277 | x_len = num_words
278 | if num_words <= cfg.TEXT.WORDS_NUM:
279 | x[:num_words, 0] = sent_caption
280 | else:
281 | ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum
282 | np.random.shuffle(ix)
283 | ix = ix[:cfg.TEXT.WORDS_NUM]
284 | ix = np.sort(ix)
285 | x[:, 0] = sent_caption[ix]
286 | x_len = cfg.TEXT.WORDS_NUM
287 | return x, x_len
288 |
289 | def __getitem__(self, index):
290 | #
291 | key = self.filenames[index]
292 | cls_id = self.class_id[index]
293 | #
294 | if self.bbox is not None:
295 | bbox = self.bbox[key]
296 | data_dir = '%s/CUB_200_2011' % self.data_dir
297 | else:
298 | bbox = None
299 | data_dir = self.data_dir
300 | #
301 | img_name = '%s/images/%s.jpg' % (data_dir, key)
302 | imgs = get_imgs(img_name, self.imsize,
303 | bbox, self.transform, normalize=self.norm)
304 | # random select a sentence
305 | sent_ix = random.randint(0, self.embeddings_num)
306 | new_sent_ix = index * self.embeddings_num + sent_ix
307 | caps, cap_len = self.get_caption(new_sent_ix)
308 | return imgs, caps, cap_len, cls_id, key
309 |
310 |
311 | def __len__(self):
312 | return len(self.filenames)
313 |
--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from miscc.config import cfg, cfg_from_file
4 | from datasets import TextDataset
5 | from trainer import condGANTrainer as trainer
6 |
7 | import os
8 | import sys
9 | import time
10 | import random
11 | import pprint
12 | import datetime
13 | import dateutil.tz
14 | import argparse
15 | import numpy as np
16 |
17 | import torch
18 | import torchvision.transforms as transforms
19 |
20 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
21 | sys.path.append(dir_path)
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description='Train a AttnGAN network')
26 | parser.add_argument('--cfg', dest='cfg_file',
27 | help='optional config file',
28 | default='cfg/bird_attn2.yml', type=str)
29 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1)
30 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
31 | parser.add_argument('--manualSeed', type=int, help='manual seed')
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | def gen_example(wordtoix, algo):
37 | '''generate images from example sentences'''
38 | from nltk.tokenize import RegexpTokenizer
39 | filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)
40 | data_dic = {}
41 | with open(filepath, "r") as f:
42 | filenames = f.read().decode('utf8').split('\n')
43 | for name in filenames:
44 | if len(name) == 0:
45 | continue
46 | filepath = '%s/%s.txt' % (cfg.DATA_DIR, name)
47 | with open(filepath, "r") as f:
48 | print('Load from:', name)
49 | sentences = f.read().decode('utf8').split('\n')
50 | # a list of indices for a sentence
51 | captions = []
52 | cap_lens = []
53 | for sent in sentences:
54 | if len(sent) == 0:
55 | continue
56 | sent = sent.replace("\ufffd\ufffd", " ")
57 | tokenizer = RegexpTokenizer(r'\w+')
58 | tokens = tokenizer.tokenize(sent.lower())
59 | if len(tokens) == 0:
60 | print('sent', sent)
61 | continue
62 |
63 | rev = []
64 | for t in tokens:
65 | t = t.encode('ascii', 'ignore').decode('ascii')
66 | if len(t) > 0 and t in wordtoix:
67 | rev.append(wordtoix[t])
68 | captions.append(rev)
69 | cap_lens.append(len(rev))
70 | max_len = np.max(cap_lens)
71 |
72 | sorted_indices = np.argsort(cap_lens)[::-1]
73 | cap_lens = np.asarray(cap_lens)
74 | cap_lens = cap_lens[sorted_indices]
75 | cap_array = np.zeros((len(captions), max_len), dtype='int64')
76 | for i in range(len(captions)):
77 | idx = sorted_indices[i]
78 | cap = captions[idx]
79 | c_len = len(cap)
80 | cap_array[i, :c_len] = cap
81 | key = name[(name.rfind('/') + 1):]
82 | data_dic[key] = [cap_array, cap_lens, sorted_indices]
83 | algo.gen_example(data_dic)
84 |
85 |
86 | if __name__ == "__main__":
87 | args = parse_args()
88 | if args.cfg_file is not None:
89 | cfg_from_file(args.cfg_file)
90 |
91 | if args.gpu_id != -1:
92 | cfg.GPU_ID = args.gpu_id
93 | else:
94 | cfg.CUDA = False
95 |
96 | if args.data_dir != '':
97 | cfg.DATA_DIR = args.data_dir
98 | print('Using config:')
99 | pprint.pprint(cfg)
100 |
101 | if not cfg.TRAIN.FLAG:
102 | args.manualSeed = 100
103 | elif args.manualSeed is None:
104 | args.manualSeed = random.randint(1, 10000)
105 | random.seed(args.manualSeed)
106 | np.random.seed(args.manualSeed)
107 | torch.manual_seed(args.manualSeed)
108 | if cfg.CUDA:
109 | torch.cuda.manual_seed_all(args.manualSeed)
110 |
111 | now = datetime.datetime.now(dateutil.tz.tzlocal())
112 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
113 | output_dir = '../output/%s_%s_%s' % \
114 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
115 |
116 | split_dir, bshuffle = 'train', True
117 | if not cfg.TRAIN.FLAG:
118 | # bshuffle = False
119 | split_dir = 'test'
120 |
121 | # Get data loader
122 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
123 | image_transform = transforms.Compose([
124 | transforms.Scale(int(imsize * 76 / 64)),
125 | transforms.RandomCrop(imsize),
126 | transforms.RandomHorizontalFlip()])
127 | dataset = TextDataset(cfg.DATA_DIR, split_dir,
128 | base_size=cfg.TREE.BASE_SIZE,
129 | transform=image_transform)
130 | assert dataset
131 | dataloader = torch.utils.data.DataLoader(
132 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
133 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))
134 |
135 | # Define models and go to train/evaluate
136 | algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword)
137 |
138 | start_t = time.time()
139 | if cfg.TRAIN.FLAG:
140 | algo.train()
141 | else:
142 | '''generate images from pre-extracted embeddings'''
143 | if cfg.B_VALIDATION:
144 | algo.sampling(split_dir) # generate images for the whole valid dataset
145 | else:
146 | gen_example(dataset.wordtoix, algo) # generate images for customized captions
147 | end_t = time.time()
148 | print('Total time for training:', end_t - start_t)
149 |
--------------------------------------------------------------------------------
/code/miscc/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
--------------------------------------------------------------------------------
/code/miscc/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import os.path as osp
5 | import numpy as np
6 | from easydict import EasyDict as edict
7 |
8 |
9 | __C = edict()
10 | cfg = __C
11 |
12 | # Dataset name: flowers, birds
13 | __C.DATASET_NAME = 'birds'
14 | __C.CONFIG_NAME = ''
15 | __C.DATA_DIR = ''
16 | __C.GPU_ID = 0
17 | __C.CUDA = True
18 | __C.WORKERS = 6
19 |
20 | __C.RNN_TYPE = 'LSTM' # 'GRU'
21 | __C.B_VALIDATION = False
22 |
23 | __C.TREE = edict()
24 | __C.TREE.BRANCH_NUM = 3
25 | __C.TREE.BASE_SIZE = 64
26 |
27 |
28 | # Training options
29 | __C.TRAIN = edict()
30 | __C.TRAIN.BATCH_SIZE = 64
31 | __C.TRAIN.MAX_EPOCH = 600
32 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000
33 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4
34 | __C.TRAIN.GENERATOR_LR = 2e-4
35 | __C.TRAIN.ENCODER_LR = 2e-4
36 | __C.TRAIN.RNN_GRAD_CLIP = 0.25
37 | __C.TRAIN.FLAG = True
38 | __C.TRAIN.NET_E = ''
39 | __C.TRAIN.NET_G = ''
40 | __C.TRAIN.B_NET_D = True
41 |
42 | __C.TRAIN.SMOOTH = edict()
43 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0
44 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0
45 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0
46 | __C.TRAIN.SMOOTH.LAMBDA = 1.0
47 |
48 |
49 | # Model options
50 | __C.GAN = edict()
51 | __C.GAN.DF_DIM = 64
52 | __C.GAN.GF_DIM = 128
53 | __C.GAN.Z_DIM = 100
54 | __C.GAN.CONDITION_DIM = 100
55 | __C.GAN.R_NUM = 2
56 | __C.GAN.B_ATTENTION = True
57 | __C.GAN.B_DCGAN = False
58 |
59 |
60 | __C.TEXT = edict()
61 | __C.TEXT.CAPTIONS_PER_IMAGE = 10
62 | __C.TEXT.EMBEDDING_DIM = 256
63 | __C.TEXT.WORDS_NUM = 18
64 |
65 |
66 | def _merge_a_into_b(a, b):
67 | """Merge config dictionary a into config dictionary b, clobbering the
68 | options in b whenever they are also specified in a.
69 | """
70 | if type(a) is not edict:
71 | return
72 |
73 | for k, v in a.iteritems():
74 | # a must specify keys that are in b
75 | if not b.has_key(k):
76 | raise KeyError('{} is not a valid config key'.format(k))
77 |
78 | # the types must match, too
79 | old_type = type(b[k])
80 | if old_type is not type(v):
81 | if isinstance(b[k], np.ndarray):
82 | v = np.array(v, dtype=b[k].dtype)
83 | else:
84 | raise ValueError(('Type mismatch ({} vs. {}) '
85 | 'for config key: {}').format(type(b[k]),
86 | type(v), k))
87 |
88 | # recursively merge dicts
89 | if type(v) is edict:
90 | try:
91 | _merge_a_into_b(a[k], b[k])
92 | except:
93 | print('Error under config key: {}'.format(k))
94 | raise
95 | else:
96 | b[k] = v
97 |
98 |
99 | def cfg_from_file(filename):
100 | """Load a config file and merge it into the default options."""
101 | import yaml
102 | with open(filename, 'r') as f:
103 | yaml_cfg = edict(yaml.load(f))
104 |
105 | _merge_a_into_b(yaml_cfg, __C)
106 |
--------------------------------------------------------------------------------
/code/miscc/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import numpy as np
5 | from miscc.config import cfg
6 |
7 | from GlobalAttention import func_attention
8 |
9 |
10 | # ##################Loss for matching text-image###################
11 | def cosine_similarity(x1, x2, dim=1, eps=1e-8):
12 | """Returns cosine similarity between x1 and x2, computed along dim.
13 | """
14 | w12 = torch.sum(x1 * x2, dim)
15 | w1 = torch.norm(x1, 2, dim)
16 | w2 = torch.norm(x2, 2, dim)
17 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze()
18 |
19 |
20 | def sent_loss(cnn_code, rnn_code, labels, class_ids,
21 | batch_size, eps=1e-8):
22 | # ### Mask mis-match samples ###
23 | # that come from the same class as the real sample ###
24 | masks = []
25 | if class_ids is not None:
26 | for i in range(batch_size):
27 | mask = (class_ids == class_ids[i]).astype(np.uint8)
28 | mask[i] = 0
29 | masks.append(mask.reshape((1, -1)))
30 | masks = np.concatenate(masks, 0)
31 | # masks: batch_size x batch_size
32 | masks = torch.ByteTensor(masks)
33 | if cfg.CUDA:
34 | masks = masks.cuda()
35 |
36 | # --> seq_len x batch_size x nef
37 | if cnn_code.dim() == 2:
38 | cnn_code = cnn_code.unsqueeze(0)
39 | rnn_code = rnn_code.unsqueeze(0)
40 |
41 | # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1
42 | cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
43 | rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
44 | # scores* / norm*: seq_len x batch_size x batch_size
45 | scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
46 | norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
47 | scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3
48 |
49 | # --> batch_size x batch_size
50 | scores0 = scores0.squeeze()
51 | if class_ids is not None:
52 | scores0.data.masked_fill_(masks, -float('inf'))
53 | scores1 = scores0.transpose(0, 1)
54 | if labels is not None:
55 | loss0 = nn.CrossEntropyLoss()(scores0, labels)
56 | loss1 = nn.CrossEntropyLoss()(scores1, labels)
57 | else:
58 | loss0, loss1 = None, None
59 | return loss0, loss1
60 |
61 |
62 | def words_loss(img_features, words_emb, labels,
63 | cap_lens, class_ids, batch_size):
64 | """
65 | words_emb(query): batch x nef x seq_len
66 | img_features(context): batch x nef x 17 x 17
67 | """
68 | masks = []
69 | att_maps = []
70 | similarities = []
71 | cap_lens = cap_lens.data.tolist()
72 | for i in range(batch_size):
73 | if class_ids is not None:
74 | mask = (class_ids == class_ids[i]).astype(np.uint8)
75 | mask[i] = 0
76 | masks.append(mask.reshape((1, -1)))
77 | # Get the i-th text description
78 | words_num = cap_lens[i]
79 | # -> 1 x nef x words_num
80 | word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
81 | # -> batch_size x nef x words_num
82 | word = word.repeat(batch_size, 1, 1)
83 | # batch x nef x 17*17
84 | context = img_features
85 | """
86 | word(query): batch x nef x words_num
87 | context: batch x nef x 17 x 17
88 | weiContext: batch x nef x words_num
89 | attn: batch x words_num x 17 x 17
90 | """
91 | weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1)
92 | att_maps.append(attn[i].unsqueeze(0).contiguous())
93 | # --> batch_size x words_num x nef
94 | word = word.transpose(1, 2).contiguous()
95 | weiContext = weiContext.transpose(1, 2).contiguous()
96 | # --> batch_size*words_num x nef
97 | word = word.view(batch_size * words_num, -1)
98 | weiContext = weiContext.view(batch_size * words_num, -1)
99 | #
100 | # -->batch_size*words_num
101 | row_sim = cosine_similarity(word, weiContext)
102 | # --> batch_size x words_num
103 | row_sim = row_sim.view(batch_size, words_num)
104 |
105 | # Eq. (10)
106 | row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_()
107 | row_sim = row_sim.sum(dim=1, keepdim=True)
108 | row_sim = torch.log(row_sim)
109 |
110 | # --> 1 x batch_size
111 | # similarities(i, j): the similarity between the i-th image and the j-th text description
112 | similarities.append(row_sim)
113 |
114 | # batch_size x batch_size
115 | similarities = torch.cat(similarities, 1)
116 | if class_ids is not None:
117 | masks = np.concatenate(masks, 0)
118 | # masks: batch_size x batch_size
119 | masks = torch.ByteTensor(masks)
120 | if cfg.CUDA:
121 | masks = masks.cuda()
122 |
123 | similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3
124 | if class_ids is not None:
125 | similarities.data.masked_fill_(masks, -float('inf'))
126 | similarities1 = similarities.transpose(0, 1)
127 | if labels is not None:
128 | loss0 = nn.CrossEntropyLoss()(similarities, labels)
129 | loss1 = nn.CrossEntropyLoss()(similarities1, labels)
130 | else:
131 | loss0, loss1 = None, None
132 | return loss0, loss1, att_maps
133 |
134 |
135 | # ##################Loss for G and Ds##############################
136 | def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
137 | real_labels, fake_labels):
138 | # Forward
139 | real_features = netD(real_imgs)
140 | fake_features = netD(fake_imgs.detach())
141 | # loss
142 | #
143 | cond_real_logits = netD.COND_DNET(real_features, conditions)
144 | cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels)
145 | cond_fake_logits = netD.COND_DNET(fake_features, conditions)
146 | cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels)
147 | #
148 | batch_size = real_features.size(0)
149 | cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size])
150 | cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size])
151 |
152 | if netD.UNCOND_DNET is not None:
153 | real_logits = netD.UNCOND_DNET(real_features)
154 | fake_logits = netD.UNCOND_DNET(fake_features)
155 | real_errD = nn.BCELoss()(real_logits, real_labels)
156 | fake_errD = nn.BCELoss()(fake_logits, fake_labels)
157 | errD = ((real_errD + cond_real_errD) / 2. +
158 | (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.)
159 | else:
160 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2.
161 | return errD
162 |
163 |
164 | def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
165 | words_embs, sent_emb, match_labels,
166 | cap_lens, class_ids):
167 | numDs = len(netsD)
168 | batch_size = real_labels.size(0)
169 | logs = ''
170 | # Forward
171 | errG_total = 0
172 | for i in range(numDs):
173 | features = netsD[i](fake_imgs[i])
174 | cond_logits = netsD[i].COND_DNET(features, sent_emb)
175 | cond_errG = nn.BCELoss()(cond_logits, real_labels)
176 | if netsD[i].UNCOND_DNET is not None:
177 | logits = netsD[i].UNCOND_DNET(features)
178 | errG = nn.BCELoss()(logits, real_labels)
179 | g_loss = errG + cond_errG
180 | else:
181 | g_loss = cond_errG
182 | errG_total += g_loss
183 | # err_img = errG_total.data[0]
184 | logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0])
185 |
186 | # Ranking loss
187 | if i == (numDs - 1):
188 | # words_features: batch_size x nef x 17 x 17
189 | # sent_code: batch_size x nef
190 | region_features, cnn_code = image_encoder(fake_imgs[i])
191 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
192 | match_labels, cap_lens,
193 | class_ids, batch_size)
194 | w_loss = (w_loss0 + w_loss1) * \
195 | cfg.TRAIN.SMOOTH.LAMBDA
196 | # err_words = err_words + w_loss.data[0]
197 |
198 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
199 | match_labels, class_ids, batch_size)
200 | s_loss = (s_loss0 + s_loss1) * \
201 | cfg.TRAIN.SMOOTH.LAMBDA
202 | # err_sent = err_sent + s_loss.data[0]
203 |
204 | errG_total += w_loss + s_loss
205 | logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0])
206 | return errG_total, logs
207 |
208 |
209 | ##################################################################
210 | def KL_loss(mu, logvar):
211 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
212 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
213 | KLD = torch.mean(KLD_element).mul_(-0.5)
214 | return KLD
215 |
--------------------------------------------------------------------------------
/code/miscc/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import numpy as np
4 | from torch.nn import init
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from PIL import Image, ImageDraw, ImageFont
10 | from copy import deepcopy
11 | import skimage.transform
12 |
13 | from miscc.config import cfg
14 |
15 |
16 | # For visualization ################################################
17 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232],
18 | 2:[70, 70, 70], 3:[102,102,156],
19 | 4:[190,153,153], 5:[153,153,153],
20 | 6:[250,170, 30], 7:[220, 220, 0],
21 | 8:[107,142, 35], 9:[152,251,152],
22 | 10:[70,130,180], 11:[220,20, 60],
23 | 12:[255, 0, 0], 13:[0, 0, 142],
24 | 14:[119,11, 32], 15:[0, 60,100],
25 | 16:[0, 80, 100], 17:[0, 0, 230],
26 | 18:[0, 0, 70], 19:[0, 0, 0]}
27 | FONT_MAX = 50
28 |
29 |
30 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2):
31 | num = captions.size(0)
32 | img_txt = Image.fromarray(convas)
33 | # get a font
34 | # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
35 | fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
36 | # get a drawing context
37 | d = ImageDraw.Draw(img_txt)
38 | sentence_list = []
39 | for i in range(num):
40 | cap = captions[i].data.cpu().numpy()
41 | sentence = []
42 | for j in range(len(cap)):
43 | if cap[j] == 0:
44 | break
45 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
46 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]),
47 | font=fnt, fill=(255, 255, 255, 255))
48 | sentence.append(word)
49 | sentence_list.append(sentence)
50 | return img_txt, sentence_list
51 |
52 |
53 | def build_super_images(real_imgs, captions, ixtoword,
54 | attn_maps, att_sze, lr_imgs=None,
55 | batch_size=cfg.TRAIN.BATCH_SIZE,
56 | max_word_num=cfg.TEXT.WORDS_NUM):
57 | nvis = 8
58 | real_imgs = real_imgs[:nvis]
59 | if lr_imgs is not None:
60 | lr_imgs = lr_imgs[:nvis]
61 | if att_sze == 17:
62 | vis_size = att_sze * 16
63 | else:
64 | vis_size = real_imgs.size(2)
65 |
66 | text_convas = \
67 | np.ones([batch_size * FONT_MAX,
68 | (max_word_num + 2) * (vis_size + 2), 3],
69 | dtype=np.uint8)
70 |
71 | for i in range(max_word_num):
72 | istart = (i + 2) * (vis_size + 2)
73 | iend = (i + 3) * (vis_size + 2)
74 | text_convas[:, istart:iend, :] = COLOR_DIC[i]
75 |
76 |
77 | real_imgs = \
78 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
79 | # [-1, 1] --> [0, 1]
80 | real_imgs.add_(1).div_(2).mul_(255)
81 | real_imgs = real_imgs.data.numpy()
82 | # b x c x h x w --> b x h x w x c
83 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
84 | pad_sze = real_imgs.shape
85 | middle_pad = np.zeros([pad_sze[2], 2, 3])
86 | post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])
87 | if lr_imgs is not None:
88 | lr_imgs = \
89 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs)
90 | # [-1, 1] --> [0, 1]
91 | lr_imgs.add_(1).div_(2).mul_(255)
92 | lr_imgs = lr_imgs.data.numpy()
93 | # b x c x h x w --> b x h x w x c
94 | lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1))
95 |
96 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17
97 | seq_len = max_word_num
98 | img_set = []
99 | num = nvis # len(attn_maps)
100 |
101 | text_map, sentences = \
102 | drawCaption(text_convas, captions, ixtoword, vis_size)
103 | text_map = np.asarray(text_map).astype(np.uint8)
104 |
105 | bUpdate = 1
106 | for i in range(num):
107 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
108 | # --> 1 x 1 x 17 x 17
109 | attn_max = attn.max(dim=1, keepdim=True)
110 | attn = torch.cat([attn_max[0], attn], 1)
111 | #
112 | attn = attn.view(-1, 1, att_sze, att_sze)
113 | attn = attn.repeat(1, 3, 1, 1).data.numpy()
114 | # n x c x h x w --> n x h x w x c
115 | attn = np.transpose(attn, (0, 2, 3, 1))
116 | num_attn = attn.shape[0]
117 | #
118 | img = real_imgs[i]
119 | if lr_imgs is None:
120 | lrI = img
121 | else:
122 | lrI = lr_imgs[i]
123 | row = [lrI, middle_pad]
124 | row_merge = [img, middle_pad]
125 | row_beforeNorm = []
126 | minVglobal, maxVglobal = 1, 0
127 | for j in range(num_attn):
128 | one_map = attn[j]
129 | if (vis_size // att_sze) > 1:
130 | one_map = \
131 | skimage.transform.pyramid_expand(one_map, sigma=20,
132 | upscale=vis_size // att_sze)
133 | row_beforeNorm.append(one_map)
134 | minV = one_map.min()
135 | maxV = one_map.max()
136 | if minVglobal > minV:
137 | minVglobal = minV
138 | if maxVglobal < maxV:
139 | maxVglobal = maxV
140 | for j in range(seq_len + 1):
141 | if j < num_attn:
142 | one_map = row_beforeNorm[j]
143 | one_map = (one_map - minVglobal) / (maxVglobal - minVglobal)
144 | one_map *= 255
145 | #
146 | PIL_im = Image.fromarray(np.uint8(img))
147 | PIL_att = Image.fromarray(np.uint8(one_map))
148 | merged = \
149 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
150 | mask = Image.new('L', (vis_size, vis_size), (210))
151 | merged.paste(PIL_im, (0, 0))
152 | merged.paste(PIL_att, (0, 0), mask)
153 | merged = np.array(merged)[:, :, :3]
154 | else:
155 | one_map = post_pad
156 | merged = post_pad
157 | row.append(one_map)
158 | row.append(middle_pad)
159 | #
160 | row_merge.append(merged)
161 | row_merge.append(middle_pad)
162 | row = np.concatenate(row, 1)
163 | row_merge = np.concatenate(row_merge, 1)
164 | txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX]
165 | if txt.shape[1] != row.shape[1]:
166 | print('txt', txt.shape, 'row', row.shape)
167 | bUpdate = 0
168 | break
169 | row = np.concatenate([txt, row, row_merge], 0)
170 | img_set.append(row)
171 | if bUpdate:
172 | img_set = np.concatenate(img_set, 0)
173 | img_set = img_set.astype(np.uint8)
174 | return img_set, sentences
175 | else:
176 | return None
177 |
178 |
179 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
180 | attn_maps, att_sze, vis_size=256, topK=5):
181 | batch_size = real_imgs.size(0)
182 | max_word_num = np.max(cap_lens)
183 | text_convas = np.ones([batch_size * FONT_MAX,
184 | max_word_num * (vis_size + 2), 3],
185 | dtype=np.uint8)
186 |
187 | real_imgs = \
188 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
189 | # [-1, 1] --> [0, 1]
190 | real_imgs.add_(1).div_(2).mul_(255)
191 | real_imgs = real_imgs.data.numpy()
192 | # b x c x h x w --> b x h x w x c
193 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
194 | pad_sze = real_imgs.shape
195 | middle_pad = np.zeros([pad_sze[2], 2, 3])
196 |
197 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17
198 | img_set = []
199 | num = len(attn_maps)
200 |
201 | text_map, sentences = \
202 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0)
203 | text_map = np.asarray(text_map).astype(np.uint8)
204 |
205 | bUpdate = 1
206 | for i in range(num):
207 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
208 | #
209 | attn = attn.view(-1, 1, att_sze, att_sze)
210 | attn = attn.repeat(1, 3, 1, 1).data.numpy()
211 | # n x c x h x w --> n x h x w x c
212 | attn = np.transpose(attn, (0, 2, 3, 1))
213 | num_attn = cap_lens[i]
214 | thresh = 2./float(num_attn)
215 | #
216 | img = real_imgs[i]
217 | row = []
218 | row_merge = []
219 | row_txt = []
220 | row_beforeNorm = []
221 | conf_score = []
222 | for j in range(num_attn):
223 | one_map = attn[j]
224 | mask0 = one_map > (2. * thresh)
225 | conf_score.append(np.sum(one_map * mask0))
226 | mask = one_map > thresh
227 | one_map = one_map * mask
228 | if (vis_size // att_sze) > 1:
229 | one_map = \
230 | skimage.transform.pyramid_expand(one_map, sigma=20,
231 | upscale=vis_size // att_sze)
232 | minV = one_map.min()
233 | maxV = one_map.max()
234 | one_map = (one_map - minV) / (maxV - minV)
235 | row_beforeNorm.append(one_map)
236 | sorted_indices = np.argsort(conf_score)[::-1]
237 |
238 | for j in range(num_attn):
239 | one_map = row_beforeNorm[j]
240 | one_map *= 255
241 | #
242 | PIL_im = Image.fromarray(np.uint8(img))
243 | PIL_att = Image.fromarray(np.uint8(one_map))
244 | merged = \
245 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
246 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210)
247 | merged.paste(PIL_im, (0, 0))
248 | merged.paste(PIL_att, (0, 0), mask)
249 | merged = np.array(merged)[:, :, :3]
250 |
251 | row.append(np.concatenate([one_map, middle_pad], 1))
252 | #
253 | row_merge.append(np.concatenate([merged, middle_pad], 1))
254 | #
255 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX,
256 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :]
257 | row_txt.append(txt)
258 | # reorder
259 | row_new = []
260 | row_merge_new = []
261 | txt_new = []
262 | for j in range(num_attn):
263 | idx = sorted_indices[j]
264 | row_new.append(row[idx])
265 | row_merge_new.append(row_merge[idx])
266 | txt_new.append(row_txt[idx])
267 | row = np.concatenate(row_new[:topK], 1)
268 | row_merge = np.concatenate(row_merge_new[:topK], 1)
269 | txt = np.concatenate(txt_new[:topK], 1)
270 | if txt.shape[1] != row.shape[1]:
271 | print('Warnings: txt', txt.shape, 'row', row.shape,
272 | 'row_merge_new', row_merge_new.shape)
273 | bUpdate = 0
274 | break
275 | row = np.concatenate([txt, row_merge], 0)
276 | img_set.append(row)
277 | if bUpdate:
278 | img_set = np.concatenate(img_set, 0)
279 | img_set = img_set.astype(np.uint8)
280 | return img_set, sentences
281 | else:
282 | return None
283 |
284 |
285 | ####################################################################
286 | def weights_init(m):
287 | classname = m.__class__.__name__
288 | if classname.find('Conv') != -1:
289 | nn.init.orthogonal(m.weight.data, 1.0)
290 | elif classname.find('BatchNorm') != -1:
291 | m.weight.data.normal_(1.0, 0.02)
292 | m.bias.data.fill_(0)
293 | elif classname.find('Linear') != -1:
294 | nn.init.orthogonal(m.weight.data, 1.0)
295 | if m.bias is not None:
296 | m.bias.data.fill_(0.0)
297 |
298 |
299 | def load_params(model, new_param):
300 | for p, new_p in zip(model.parameters(), new_param):
301 | p.data.copy_(new_p)
302 |
303 |
304 | def copy_G_params(model):
305 | flatten = deepcopy(list(p.data for p in model.parameters()))
306 | return flatten
307 |
308 |
309 | def mkdir_p(path):
310 | try:
311 | os.makedirs(path)
312 | except OSError as exc: # Python >2.5
313 | if exc.errno == errno.EEXIST and os.path.isdir(path):
314 | pass
315 | else:
316 | raise
317 |
--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | from torch.autograd import Variable
5 | from torchvision import models
6 | import torch.utils.model_zoo as model_zoo
7 | import torch.nn.functional as F
8 |
9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
10 |
11 | from miscc.config import cfg
12 | from GlobalAttention import GlobalAttentionGeneral as ATT_NET
13 |
14 |
15 | class GLU(nn.Module):
16 | def __init__(self):
17 | super(GLU, self).__init__()
18 |
19 | def forward(self, x):
20 | nc = x.size(1)
21 | assert nc % 2 == 0, 'channels dont divide 2!'
22 | nc = int(nc/2)
23 | return x[:, :nc] * F.sigmoid(x[:, nc:])
24 |
25 |
26 | def conv1x1(in_planes, out_planes, bias=False):
27 | "1x1 convolution with padding"
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
29 | padding=0, bias=bias)
30 |
31 |
32 | def conv3x3(in_planes, out_planes):
33 | "3x3 convolution with padding"
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
35 | padding=1, bias=False)
36 |
37 |
38 | # Upscale the spatial size by a factor of 2
39 | def upBlock(in_planes, out_planes):
40 | block = nn.Sequential(
41 | nn.Upsample(scale_factor=2, mode='nearest'),
42 | conv3x3(in_planes, out_planes * 2),
43 | nn.BatchNorm2d(out_planes * 2),
44 | GLU())
45 | return block
46 |
47 |
48 | # Keep the spatial size
49 | def Block3x3_relu(in_planes, out_planes):
50 | block = nn.Sequential(
51 | conv3x3(in_planes, out_planes * 2),
52 | nn.BatchNorm2d(out_planes * 2),
53 | GLU())
54 | return block
55 |
56 |
57 | class ResBlock(nn.Module):
58 | def __init__(self, channel_num):
59 | super(ResBlock, self).__init__()
60 | self.block = nn.Sequential(
61 | conv3x3(channel_num, channel_num * 2),
62 | nn.BatchNorm2d(channel_num * 2),
63 | GLU(),
64 | conv3x3(channel_num, channel_num),
65 | nn.BatchNorm2d(channel_num))
66 |
67 | def forward(self, x):
68 | residual = x
69 | out = self.block(x)
70 | out += residual
71 | return out
72 |
73 |
74 | # ############## Text2Image Encoder-Decoder #######
75 | class RNN_ENCODER(nn.Module):
76 | def __init__(self, ntoken, ninput=300, drop_prob=0.5,
77 | nhidden=128, nlayers=1, bidirectional=True):
78 | super(RNN_ENCODER, self).__init__()
79 | self.n_steps = cfg.TEXT.WORDS_NUM
80 | self.ntoken = ntoken # size of the dictionary
81 | self.ninput = ninput # size of each embedding vector
82 | self.drop_prob = drop_prob # probability of an element to be zeroed
83 | self.nlayers = nlayers # Number of recurrent layers
84 | self.bidirectional = bidirectional
85 | self.rnn_type = cfg.RNN_TYPE
86 | if bidirectional:
87 | self.num_directions = 2
88 | else:
89 | self.num_directions = 1
90 | # number of features in the hidden state
91 | self.nhidden = nhidden // self.num_directions
92 |
93 | self.define_module()
94 | self.init_weights()
95 |
96 | def define_module(self):
97 | self.encoder = nn.Embedding(self.ntoken, self.ninput)
98 | self.drop = nn.Dropout(self.drop_prob)
99 | if self.rnn_type == 'LSTM':
100 | # dropout: If non-zero, introduces a dropout layer on
101 | # the outputs of each RNN layer except the last layer
102 | self.rnn = nn.LSTM(self.ninput, self.nhidden,
103 | self.nlayers, batch_first=True,
104 | dropout=self.drop_prob,
105 | bidirectional=self.bidirectional)
106 | elif self.rnn_type == 'GRU':
107 | self.rnn = nn.GRU(self.ninput, self.nhidden,
108 | self.nlayers, batch_first=True,
109 | dropout=self.drop_prob,
110 | bidirectional=self.bidirectional)
111 | else:
112 | raise NotImplementedError
113 |
114 | def init_weights(self):
115 | initrange = 0.1
116 | self.encoder.weight.data.uniform_(-initrange, initrange)
117 | # Do not need to initialize RNN parameters, which have been initialized
118 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM
119 | # self.decoder.weight.data.uniform_(-initrange, initrange)
120 | # self.decoder.bias.data.fill_(0)
121 |
122 | def init_hidden(self, bsz):
123 | weight = next(self.parameters()).data
124 | if self.rnn_type == 'LSTM':
125 | return (Variable(weight.new(self.nlayers * self.num_directions,
126 | bsz, self.nhidden).zero_()),
127 | Variable(weight.new(self.nlayers * self.num_directions,
128 | bsz, self.nhidden).zero_()))
129 | else:
130 | return Variable(weight.new(self.nlayers * self.num_directions,
131 | bsz, self.nhidden).zero_())
132 |
133 | def forward(self, captions, cap_lens, hidden, mask=None):
134 | # input: torch.LongTensor of size batch x n_steps
135 | # --> emb: batch x n_steps x ninput
136 | emb = self.drop(self.encoder(captions))
137 | #
138 | # Returns: a PackedSequence object
139 | cap_lens = cap_lens.data.tolist()
140 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True)
141 | # #hidden and memory (num_layers * num_directions, batch, hidden_size):
142 | # tensor containing the initial hidden state for each element in batch.
143 | # #output (batch, seq_len, hidden_size * num_directions)
144 | # #or a PackedSequence object:
145 | # tensor containing output features (h_t) from the last layer of RNN
146 | output, hidden = self.rnn(emb, hidden)
147 | # PackedSequence object
148 | # --> (batch, seq_len, hidden_size * num_directions)
149 | output = pad_packed_sequence(output, batch_first=True)[0]
150 | # output = self.drop(output)
151 | # --> batch x hidden_size*num_directions x seq_len
152 | words_emb = output.transpose(1, 2)
153 | # --> batch x num_directions*hidden_size
154 | if self.rnn_type == 'LSTM':
155 | sent_emb = hidden[0].transpose(0, 1).contiguous()
156 | else:
157 | sent_emb = hidden.transpose(0, 1).contiguous()
158 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)
159 | return words_emb, sent_emb
160 |
161 |
162 | class CNN_ENCODER(nn.Module):
163 | def __init__(self, nef):
164 | super(CNN_ENCODER, self).__init__()
165 | if cfg.TRAIN.FLAG:
166 | self.nef = nef
167 | else:
168 | self.nef = 256 # define a uniform ranker
169 |
170 | model = models.inception_v3()
171 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
172 | model.load_state_dict(model_zoo.load_url(url))
173 | for param in model.parameters():
174 | param.requires_grad = False
175 | print('Load pretrained model from ', url)
176 | # print(model)
177 |
178 | self.define_module(model)
179 | self.init_trainable_weights()
180 |
181 | def define_module(self, model):
182 | self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
183 | self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3
184 | self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3
185 | self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1
186 | self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3
187 | self.Mixed_5b = model.Mixed_5b
188 | self.Mixed_5c = model.Mixed_5c
189 | self.Mixed_5d = model.Mixed_5d
190 | self.Mixed_6a = model.Mixed_6a
191 | self.Mixed_6b = model.Mixed_6b
192 | self.Mixed_6c = model.Mixed_6c
193 | self.Mixed_6d = model.Mixed_6d
194 | self.Mixed_6e = model.Mixed_6e
195 | self.Mixed_7a = model.Mixed_7a
196 | self.Mixed_7b = model.Mixed_7b
197 | self.Mixed_7c = model.Mixed_7c
198 |
199 | self.emb_features = conv1x1(768, self.nef)
200 | self.emb_cnn_code = nn.Linear(2048, self.nef)
201 |
202 | def init_trainable_weights(self):
203 | initrange = 0.1
204 | self.emb_features.weight.data.uniform_(-initrange, initrange)
205 | self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)
206 |
207 | def forward(self, x):
208 | features = None
209 | # --> fixed-size input: batch x 3 x 299 x 299
210 | x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
211 | # 299 x 299 x 3
212 | x = self.Conv2d_1a_3x3(x)
213 | # 149 x 149 x 32
214 | x = self.Conv2d_2a_3x3(x)
215 | # 147 x 147 x 32
216 | x = self.Conv2d_2b_3x3(x)
217 | # 147 x 147 x 64
218 | x = F.max_pool2d(x, kernel_size=3, stride=2)
219 | # 73 x 73 x 64
220 | x = self.Conv2d_3b_1x1(x)
221 | # 73 x 73 x 80
222 | x = self.Conv2d_4a_3x3(x)
223 | # 71 x 71 x 192
224 |
225 | x = F.max_pool2d(x, kernel_size=3, stride=2)
226 | # 35 x 35 x 192
227 | x = self.Mixed_5b(x)
228 | # 35 x 35 x 256
229 | x = self.Mixed_5c(x)
230 | # 35 x 35 x 288
231 | x = self.Mixed_5d(x)
232 | # 35 x 35 x 288
233 |
234 | x = self.Mixed_6a(x)
235 | # 17 x 17 x 768
236 | x = self.Mixed_6b(x)
237 | # 17 x 17 x 768
238 | x = self.Mixed_6c(x)
239 | # 17 x 17 x 768
240 | x = self.Mixed_6d(x)
241 | # 17 x 17 x 768
242 | x = self.Mixed_6e(x)
243 | # 17 x 17 x 768
244 |
245 | # image region features
246 | features = x
247 | # 17 x 17 x 768
248 |
249 | x = self.Mixed_7a(x)
250 | # 8 x 8 x 1280
251 | x = self.Mixed_7b(x)
252 | # 8 x 8 x 2048
253 | x = self.Mixed_7c(x)
254 | # 8 x 8 x 2048
255 | x = F.avg_pool2d(x, kernel_size=8)
256 | # 1 x 1 x 2048
257 | # x = F.dropout(x, training=self.training)
258 | # 1 x 1 x 2048
259 | x = x.view(x.size(0), -1)
260 | # 2048
261 |
262 | # global image features
263 | cnn_code = self.emb_cnn_code(x)
264 | # 512
265 | if features is not None:
266 | features = self.emb_features(features)
267 | return features, cnn_code
268 |
269 |
270 | # ############## G networks ###################
271 | class CA_NET(nn.Module):
272 | # some code is modified from vae examples
273 | # (https://github.com/pytorch/examples/blob/master/vae/main.py)
274 | def __init__(self):
275 | super(CA_NET, self).__init__()
276 | self.t_dim = cfg.TEXT.EMBEDDING_DIM
277 | self.c_dim = cfg.GAN.CONDITION_DIM
278 | self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
279 | self.relu = GLU()
280 |
281 | def encode(self, text_embedding):
282 | x = self.relu(self.fc(text_embedding))
283 | mu = x[:, :self.c_dim]
284 | logvar = x[:, self.c_dim:]
285 | return mu, logvar
286 |
287 | def reparametrize(self, mu, logvar):
288 | std = logvar.mul(0.5).exp_()
289 | if cfg.CUDA:
290 | eps = torch.cuda.FloatTensor(std.size()).normal_()
291 | else:
292 | eps = torch.FloatTensor(std.size()).normal_()
293 | eps = Variable(eps)
294 | return eps.mul(std).add_(mu)
295 |
296 | def forward(self, text_embedding):
297 | mu, logvar = self.encode(text_embedding)
298 | c_code = self.reparametrize(mu, logvar)
299 | return c_code, mu, logvar
300 |
301 |
302 | class INIT_STAGE_G(nn.Module):
303 | def __init__(self, ngf, ncf):
304 | super(INIT_STAGE_G, self).__init__()
305 | self.gf_dim = ngf
306 | self.in_dim = cfg.GAN.Z_DIM + ncf # cfg.TEXT.EMBEDDING_DIM
307 |
308 | self.define_module()
309 |
310 | def define_module(self):
311 | nz, ngf = self.in_dim, self.gf_dim
312 | self.fc = nn.Sequential(
313 | nn.Linear(nz, ngf * 4 * 4 * 2, bias=False),
314 | nn.BatchNorm1d(ngf * 4 * 4 * 2),
315 | GLU())
316 |
317 | self.upsample1 = upBlock(ngf, ngf // 2)
318 | self.upsample2 = upBlock(ngf // 2, ngf // 4)
319 | self.upsample3 = upBlock(ngf // 4, ngf // 8)
320 | self.upsample4 = upBlock(ngf // 8, ngf // 16)
321 |
322 | def forward(self, z_code, c_code):
323 | """
324 | :param z_code: batch x cfg.GAN.Z_DIM
325 |         :param c_code: batch x cfg.GAN.CONDITION_DIM
326 | :return: batch x ngf/16 x 64 x 64
327 | """
328 | c_z_code = torch.cat((c_code, z_code), 1)
329 | # state size ngf x 4 x 4
330 | out_code = self.fc(c_z_code)
331 | out_code = out_code.view(-1, self.gf_dim, 4, 4)
332 |         # state size ngf/2 x 8 x 8
333 | out_code = self.upsample1(out_code)
334 | # state size ngf/4 x 16 x 16
335 | out_code = self.upsample2(out_code)
336 | # state size ngf/8 x 32 x 32
337 | out_code32 = self.upsample3(out_code)
338 | # state size ngf/16 x 64 x 64
339 | out_code64 = self.upsample4(out_code32)
340 |
341 | return out_code64
342 |
343 |
344 | class NEXT_STAGE_G(nn.Module):
345 | def __init__(self, ngf, nef, ncf):
346 | super(NEXT_STAGE_G, self).__init__()
347 | self.gf_dim = ngf
348 | self.ef_dim = nef
349 | self.cf_dim = ncf
350 | self.num_residual = cfg.GAN.R_NUM
351 | self.define_module()
352 |
353 | def _make_layer(self, block, channel_num):
354 | layers = []
355 | for i in range(cfg.GAN.R_NUM):
356 | layers.append(block(channel_num))
357 | return nn.Sequential(*layers)
358 |
359 | def define_module(self):
360 | ngf = self.gf_dim
361 | self.att = ATT_NET(ngf, self.ef_dim)
362 | self.residual = self._make_layer(ResBlock, ngf * 2)
363 | self.upsample = upBlock(ngf * 2, ngf)
364 |
365 | def forward(self, h_code, c_code, word_embs, mask):
366 | """
367 | h_code1(query): batch x idf x ih x iw (queryL=ihxiw)
368 | word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
369 | c_code1: batch x idf x queryL
370 | att1: batch x sourceL x queryL
371 | """
372 | self.att.applyMask(mask)
373 | c_code, att = self.att(h_code, word_embs)
374 | h_c_code = torch.cat((h_code, c_code), 1)
375 | out_code = self.residual(h_c_code)
376 |
377 |         # state size ngf x 2in_size x 2in_size
378 | out_code = self.upsample(out_code)
379 |
380 | return out_code, att
381 |
382 |
383 | class GET_IMAGE_G(nn.Module):
384 | def __init__(self, ngf):
385 | super(GET_IMAGE_G, self).__init__()
386 | self.gf_dim = ngf
387 | self.img = nn.Sequential(
388 | conv3x3(ngf, 3),
389 | nn.Tanh()
390 | )
391 |
392 | def forward(self, h_code):
393 | out_img = self.img(h_code)
394 | return out_img
395 |
396 |
397 | class G_NET(nn.Module):
398 | def __init__(self):
399 | super(G_NET, self).__init__()
400 | ngf = cfg.GAN.GF_DIM
401 | nef = cfg.TEXT.EMBEDDING_DIM
402 | ncf = cfg.GAN.CONDITION_DIM
403 | self.ca_net = CA_NET()
404 |
405 | if cfg.TREE.BRANCH_NUM > 0:
406 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
407 | self.img_net1 = GET_IMAGE_G(ngf)
408 | # gf x 64 x 64
409 | if cfg.TREE.BRANCH_NUM > 1:
410 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
411 | self.img_net2 = GET_IMAGE_G(ngf)
412 | if cfg.TREE.BRANCH_NUM > 2:
413 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)
414 | self.img_net3 = GET_IMAGE_G(ngf)
415 |
416 | def forward(self, z_code, sent_emb, word_embs, mask):
417 | """
418 | :param z_code: batch x cfg.GAN.Z_DIM
419 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
420 | :param word_embs: batch x cdf x seq_len
421 | :param mask: batch x seq_len
422 | :return:
423 | """
424 | fake_imgs = []
425 | att_maps = []
426 | c_code, mu, logvar = self.ca_net(sent_emb)
427 |
428 | if cfg.TREE.BRANCH_NUM > 0:
429 | h_code1 = self.h_net1(z_code, c_code)
430 | fake_img1 = self.img_net1(h_code1)
431 | fake_imgs.append(fake_img1)
432 | if cfg.TREE.BRANCH_NUM > 1:
433 | h_code2, att1 = \
434 | self.h_net2(h_code1, c_code, word_embs, mask)
435 | fake_img2 = self.img_net2(h_code2)
436 | fake_imgs.append(fake_img2)
437 | if att1 is not None:
438 | att_maps.append(att1)
439 | if cfg.TREE.BRANCH_NUM > 2:
440 | h_code3, att2 = \
441 | self.h_net3(h_code2, c_code, word_embs, mask)
442 | fake_img3 = self.img_net3(h_code3)
443 | fake_imgs.append(fake_img3)
444 | if att2 is not None:
445 | att_maps.append(att2)
446 |
447 | return fake_imgs, att_maps, mu, logvar
448 |
449 |
450 |
451 | class G_DCGAN(nn.Module):
452 | def __init__(self):
453 | super(G_DCGAN, self).__init__()
454 | ngf = cfg.GAN.GF_DIM
455 | nef = cfg.TEXT.EMBEDDING_DIM
456 | ncf = cfg.GAN.CONDITION_DIM
457 | self.ca_net = CA_NET()
458 |
459 | # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64
460 | if cfg.TREE.BRANCH_NUM > 0:
461 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
462 | # gf x 64 x 64
463 | if cfg.TREE.BRANCH_NUM > 1:
464 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
465 | if cfg.TREE.BRANCH_NUM > 2:
466 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)
467 | self.img_net = GET_IMAGE_G(ngf)
468 |
469 | def forward(self, z_code, sent_emb, word_embs, mask):
470 | """
471 | :param z_code: batch x cfg.GAN.Z_DIM
472 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
473 | :param word_embs: batch x cdf x seq_len
474 | :param mask: batch x seq_len
475 | :return:
476 | """
477 | att_maps = []
478 | c_code, mu, logvar = self.ca_net(sent_emb)
479 | if cfg.TREE.BRANCH_NUM > 0:
480 | h_code = self.h_net1(z_code, c_code)
481 | if cfg.TREE.BRANCH_NUM > 1:
482 | h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask)
483 | if att1 is not None:
484 | att_maps.append(att1)
485 | if cfg.TREE.BRANCH_NUM > 2:
486 | h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask)
487 | if att2 is not None:
488 | att_maps.append(att2)
489 |
490 | fake_imgs = self.img_net(h_code)
491 | return [fake_imgs], att_maps, mu, logvar
492 |
493 |
494 | # ############## D networks ##########################
495 | def Block3x3_leakRelu(in_planes, out_planes):
496 | block = nn.Sequential(
497 | conv3x3(in_planes, out_planes),
498 | nn.BatchNorm2d(out_planes),
499 | nn.LeakyReLU(0.2, inplace=True)
500 | )
501 | return block
502 |
503 |
504 | # Downscale the spatial size by a factor of 2
505 | def downBlock(in_planes, out_planes):
506 | block = nn.Sequential(
507 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
508 | nn.BatchNorm2d(out_planes),
509 | nn.LeakyReLU(0.2, inplace=True)
510 | )
511 | return block
512 |
513 |
514 | # Downscale the spatial size by a factor of 16
515 | def encode_image_by_16times(ndf):
516 | encode_img = nn.Sequential(
517 | # --> state size. ndf x in_size/2 x in_size/2
518 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
519 | nn.LeakyReLU(0.2, inplace=True),
520 |         # --> state size 2ndf x in_size/4 x in_size/4
521 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
522 | nn.BatchNorm2d(ndf * 2),
523 | nn.LeakyReLU(0.2, inplace=True),
524 | # --> state size 4ndf x in_size/8 x in_size/8
525 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
526 | nn.BatchNorm2d(ndf * 4),
527 | nn.LeakyReLU(0.2, inplace=True),
528 | # --> state size 8ndf x in_size/16 x in_size/16
529 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
530 | nn.BatchNorm2d(ndf * 8),
531 | nn.LeakyReLU(0.2, inplace=True)
532 | )
533 | return encode_img
534 |
535 |
536 | class D_GET_LOGITS(nn.Module):
537 | def __init__(self, ndf, nef, bcondition=False):
538 | super(D_GET_LOGITS, self).__init__()
539 | self.df_dim = ndf
540 | self.ef_dim = nef
541 | self.bcondition = bcondition
542 | if self.bcondition:
543 | self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8)
544 |
545 | self.outlogits = nn.Sequential(
546 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
547 | nn.Sigmoid())
548 |
549 | def forward(self, h_code, c_code=None):
550 | if self.bcondition and c_code is not None:
551 | # conditioning output
552 | c_code = c_code.view(-1, self.ef_dim, 1, 1)
553 | c_code = c_code.repeat(1, 1, 4, 4)
554 | # state size (ngf+egf) x 4 x 4
555 | h_c_code = torch.cat((h_code, c_code), 1)
556 | # state size ngf x in_size x in_size
557 | h_c_code = self.jointConv(h_c_code)
558 | else:
559 | h_c_code = h_code
560 |
561 | output = self.outlogits(h_c_code)
562 | return output.view(-1)
563 |
564 |
565 | # For 64 x 64 images
566 | class D_NET64(nn.Module):
567 | def __init__(self, b_jcu=True):
568 | super(D_NET64, self).__init__()
569 | ndf = cfg.GAN.DF_DIM
570 | nef = cfg.TEXT.EMBEDDING_DIM
571 | self.img_code_s16 = encode_image_by_16times(ndf)
572 | if b_jcu:
573 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
574 | else:
575 | self.UNCOND_DNET = None
576 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)
577 |
578 | def forward(self, x_var):
579 | x_code4 = self.img_code_s16(x_var) # 4 x 4 x 8df
580 | return x_code4
581 |
582 |
583 | # For 128 x 128 images
584 | class D_NET128(nn.Module):
585 | def __init__(self, b_jcu=True):
586 | super(D_NET128, self).__init__()
587 | ndf = cfg.GAN.DF_DIM
588 | nef = cfg.TEXT.EMBEDDING_DIM
589 | self.img_code_s16 = encode_image_by_16times(ndf)
590 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
591 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8)
592 | #
593 | if b_jcu:
594 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
595 | else:
596 | self.UNCOND_DNET = None
597 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)
598 |
599 | def forward(self, x_var):
600 | x_code8 = self.img_code_s16(x_var) # 8 x 8 x 8df
601 | x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df
602 | x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df
603 | return x_code4
604 |
605 |
606 | # For 256 x 256 images
607 | class D_NET256(nn.Module):
608 | def __init__(self, b_jcu=True):
609 | super(D_NET256, self).__init__()
610 | ndf = cfg.GAN.DF_DIM
611 | nef = cfg.TEXT.EMBEDDING_DIM
612 | self.img_code_s16 = encode_image_by_16times(ndf)
613 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
614 | self.img_code_s64 = downBlock(ndf * 16, ndf * 32)
615 | self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16)
616 | self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8)
617 | if b_jcu:
618 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
619 | else:
620 | self.UNCOND_DNET = None
621 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)
622 |
623 | def forward(self, x_var):
624 | x_code16 = self.img_code_s16(x_var)
625 | x_code8 = self.img_code_s32(x_code16)
626 | x_code4 = self.img_code_s64(x_code8)
627 | x_code4 = self.img_code_s64_1(x_code4)
628 | x_code4 = self.img_code_s64_2(x_code4)
629 | return x_code4
630 |
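631 | # A minimal usage sketch of the generator (shapes follow the forward()
632 | # docstrings above; batch size 4 and seq_len 15 are arbitrary choices):
633 | #
634 | #   netG = G_NET()
635 | #   noise = Variable(torch.FloatTensor(4, cfg.GAN.Z_DIM).normal_(0, 1))
636 | #   sent_emb = Variable(torch.FloatTensor(4, cfg.TEXT.EMBEDDING_DIM).normal_(0, 1))
637 | #   word_embs = Variable(torch.FloatTensor(4, cfg.TEXT.EMBEDDING_DIM, 15).normal_(0, 1))
638 | #   mask = Variable(torch.ByteTensor(4, 15).zero_())
639 | #   fake_imgs, att_maps, mu, logvar = netG(noise, sent_emb, word_embs, mask)
640 | #   # fake_imgs holds one image per branch (64x64, 128x128 and 256x256 when
641 | #   # cfg.TREE.BRANCH_NUM is 3); att_maps holds the word attention maps.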
--------------------------------------------------------------------------------
/code/pretrain_DAMSM.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from miscc.utils import mkdir_p
4 | from miscc.utils import build_super_images
5 | from miscc.losses import sent_loss, words_loss
6 | from miscc.config import cfg, cfg_from_file
7 |
8 | from datasets import TextDataset
9 | from datasets import prepare_data
10 |
11 | from model import RNN_ENCODER, CNN_ENCODER
12 |
13 | import os
14 | import sys
15 | import time
16 | import random
17 | import pprint
18 | import datetime
19 | import dateutil.tz
20 | import argparse
21 | import numpy as np
22 | from PIL import Image
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.optim as optim
27 | from torch.autograd import Variable
28 | import torch.backends.cudnn as cudnn
29 | import torchvision.transforms as transforms
30 |
31 |
32 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
33 | sys.path.append(dir_path)
34 |
35 |
36 | UPDATE_INTERVAL = 200
37 | def parse_args():
38 | parser = argparse.ArgumentParser(description='Train a DAMSM network')
39 | parser.add_argument('--cfg', dest='cfg_file',
40 | help='optional config file',
41 | default='cfg/DAMSM/bird.yml', type=str)
42 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)
43 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
44 | parser.add_argument('--manualSeed', type=int, help='manual seed')
45 | args = parser.parse_args()
46 | return args
47 |
48 |
49 | def train(dataloader, cnn_model, rnn_model, batch_size,
50 | labels, optimizer, epoch, ixtoword, image_dir):
51 | cnn_model.train()
52 | rnn_model.train()
53 | s_total_loss0 = 0
54 | s_total_loss1 = 0
55 | w_total_loss0 = 0
56 | w_total_loss1 = 0
57 | count = (epoch + 1) * len(dataloader)
58 | start_time = time.time()
59 | for step, data in enumerate(dataloader, 0):
60 | # print('step', step)
61 | rnn_model.zero_grad()
62 | cnn_model.zero_grad()
63 |
64 | imgs, captions, cap_lens, \
65 | class_ids, keys = prepare_data(data)
66 |
67 |
68 | # words_features: batch_size x nef x 17 x 17
69 | # sent_code: batch_size x nef
70 | words_features, sent_code = cnn_model(imgs[-1])
71 | # --> batch_size x nef x 17*17
72 | nef, att_sze = words_features.size(1), words_features.size(2)
73 | # words_features = words_features.view(batch_size, nef, -1)
74 |
75 | hidden = rnn_model.init_hidden(batch_size)
76 | # words_emb: batch_size x nef x seq_len
77 | # sent_emb: batch_size x nef
78 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
79 |
80 | w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
81 | cap_lens, class_ids, batch_size)
82 | w_total_loss0 += w_loss0.data
83 | w_total_loss1 += w_loss1.data
84 | loss = w_loss0 + w_loss1
85 |
86 | s_loss0, s_loss1 = \
87 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
88 | loss += s_loss0 + s_loss1
89 | s_total_loss0 += s_loss0.data
90 | s_total_loss1 += s_loss1.data
91 | #
92 | loss.backward()
93 | #
94 | # `clip_grad_norm` helps prevent
95 | # the exploding gradient problem in RNNs / LSTMs.
96 | torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
97 | cfg.TRAIN.RNN_GRAD_CLIP)
98 | optimizer.step()
99 |
100 | if step % UPDATE_INTERVAL == 0:
101 | count = epoch * len(dataloader) + step
102 |
103 | s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL
104 | s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL
105 |
106 | w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL
107 | w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL
108 |
109 | elapsed = time.time() - start_time
110 | print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
111 | 's_loss {:5.2f} {:5.2f} | '
112 | 'w_loss {:5.2f} {:5.2f}'
113 | .format(epoch, step, len(dataloader),
114 | elapsed * 1000. / UPDATE_INTERVAL,
115 | s_cur_loss0, s_cur_loss1,
116 | w_cur_loss0, w_cur_loss1))
117 | s_total_loss0 = 0
118 | s_total_loss1 = 0
119 | w_total_loss0 = 0
120 | w_total_loss1 = 0
121 | start_time = time.time()
122 | # attention Maps
123 | img_set, _ = \
124 | build_super_images(imgs[-1].cpu(), captions,
125 | ixtoword, attn_maps, att_sze)
126 | if img_set is not None:
127 | im = Image.fromarray(img_set)
128 | fullpath = '%s/attention_maps%d.png' % (image_dir, step)
129 | im.save(fullpath)
130 | return count
131 |
132 |
133 | def evaluate(dataloader, cnn_model, rnn_model, batch_size):
134 | cnn_model.eval()
135 | rnn_model.eval()
136 | s_total_loss = 0
137 | w_total_loss = 0
138 | for step, data in enumerate(dataloader, 0):
139 | real_imgs, captions, cap_lens, \
140 | class_ids, keys = prepare_data(data)
141 |
142 | words_features, sent_code = cnn_model(real_imgs[-1])
143 | # nef = words_features.size(1)
144 | # words_features = words_features.view(batch_size, nef, -1)
145 |
146 | hidden = rnn_model.init_hidden(batch_size)
147 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
148 |
149 | w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
150 | cap_lens, class_ids, batch_size)
151 | w_total_loss += (w_loss0 + w_loss1).data
152 |
153 | s_loss0, s_loss1 = \
154 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
155 | s_total_loss += (s_loss0 + s_loss1).data
156 |
157 | if step == 50:
158 | break
159 |
160 | s_cur_loss = s_total_loss[0] / step
161 | w_cur_loss = w_total_loss[0] / step
162 |
163 | return s_cur_loss, w_cur_loss
164 |
165 |
166 | def build_models():
167 | # build model ############################################################
168 | text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
169 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
170 | labels = Variable(torch.LongTensor(range(batch_size)))
171 | start_epoch = 0
172 | if cfg.TRAIN.NET_E != '':
173 | state_dict = torch.load(cfg.TRAIN.NET_E)
174 | text_encoder.load_state_dict(state_dict)
175 | print('Load ', cfg.TRAIN.NET_E)
176 | #
177 | name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
178 | state_dict = torch.load(name)
179 | image_encoder.load_state_dict(state_dict)
180 | print('Load ', name)
181 |
182 | istart = cfg.TRAIN.NET_E.rfind('_') + 8
183 | iend = cfg.TRAIN.NET_E.rfind('.')
184 | start_epoch = cfg.TRAIN.NET_E[istart:iend]
185 | start_epoch = int(start_epoch) + 1
186 | print('start_epoch', start_epoch)
187 | if cfg.CUDA:
188 | text_encoder = text_encoder.cuda()
189 | image_encoder = image_encoder.cuda()
190 | labels = labels.cuda()
191 |
192 | return text_encoder, image_encoder, labels, start_epoch
193 |
194 |
195 | if __name__ == "__main__":
196 | args = parse_args()
197 | if args.cfg_file is not None:
198 | cfg_from_file(args.cfg_file)
199 |
200 | if args.gpu_id == -1:
201 | cfg.CUDA = False
202 | else:
203 | cfg.GPU_ID = args.gpu_id
204 |
205 | if args.data_dir != '':
206 | cfg.DATA_DIR = args.data_dir
207 | print('Using config:')
208 | pprint.pprint(cfg)
209 |
210 | if not cfg.TRAIN.FLAG:
211 | args.manualSeed = 100
212 | elif args.manualSeed is None:
213 | args.manualSeed = random.randint(1, 10000)
214 | random.seed(args.manualSeed)
215 | np.random.seed(args.manualSeed)
216 | torch.manual_seed(args.manualSeed)
217 | if cfg.CUDA:
218 | torch.cuda.manual_seed_all(args.manualSeed)
219 |
220 | ##########################################################################
221 | now = datetime.datetime.now(dateutil.tz.tzlocal())
222 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
223 | output_dir = '../output/%s_%s_%s' % \
224 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
225 |
226 | model_dir = os.path.join(output_dir, 'Model')
227 | image_dir = os.path.join(output_dir, 'Image')
228 | mkdir_p(model_dir)
229 | mkdir_p(image_dir)
230 |
231 | torch.cuda.set_device(cfg.GPU_ID)
232 | cudnn.benchmark = True
233 |
234 | # Get data loader ##################################################
235 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
236 | batch_size = cfg.TRAIN.BATCH_SIZE
237 | image_transform = transforms.Compose([
238 | transforms.Scale(int(imsize * 76 / 64)),
239 | transforms.RandomCrop(imsize),
240 | transforms.RandomHorizontalFlip()])
241 | dataset = TextDataset(cfg.DATA_DIR, 'train',
242 | base_size=cfg.TREE.BASE_SIZE,
243 | transform=image_transform)
244 |
245 | print(dataset.n_words, dataset.embeddings_num)
246 | assert dataset
247 | dataloader = torch.utils.data.DataLoader(
248 | dataset, batch_size=batch_size, drop_last=True,
249 | shuffle=True, num_workers=int(cfg.WORKERS))
250 |
251 | # # validation data #
252 | dataset_val = TextDataset(cfg.DATA_DIR, 'test',
253 | base_size=cfg.TREE.BASE_SIZE,
254 | transform=image_transform)
255 | dataloader_val = torch.utils.data.DataLoader(
256 | dataset_val, batch_size=batch_size, drop_last=True,
257 | shuffle=True, num_workers=int(cfg.WORKERS))
258 |
259 | # Train ##############################################################
260 | text_encoder, image_encoder, labels, start_epoch = build_models()
261 | para = list(text_encoder.parameters())
262 | for v in image_encoder.parameters():
263 | if v.requires_grad:
264 | para.append(v)
265 | # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))
266 | # At any point you can hit Ctrl + C to break out of training early.
267 | try:
268 | lr = cfg.TRAIN.ENCODER_LR
269 | for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
270 | optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
271 | epoch_start_time = time.time()
272 | count = train(dataloader, image_encoder, text_encoder,
273 | batch_size, labels, optimizer, epoch,
274 | dataset.ixtoword, image_dir)
275 | print('-' * 89)
276 | if len(dataloader_val) > 0:
277 | s_loss, w_loss = evaluate(dataloader_val, image_encoder,
278 | text_encoder, batch_size)
279 | print('| end epoch {:3d} | valid loss '
280 | '{:5.2f} {:5.2f} | lr {:.5f}|'
281 | .format(epoch, s_loss, w_loss, lr))
282 | print('-' * 89)
283 | if lr > cfg.TRAIN.ENCODER_LR/10.:
284 | lr *= 0.98
285 |
286 | if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or
287 | epoch == cfg.TRAIN.MAX_EPOCH):
288 | torch.save(image_encoder.state_dict(),
289 | '%s/image_encoder%d.pth' % (model_dir, epoch))
290 | torch.save(text_encoder.state_dict(),
291 | '%s/text_encoder%d.pth' % (model_dir, epoch))
292 |                 print('Save image and text encoders.')
293 | except KeyboardInterrupt:
294 | print('-' * 89)
295 | print('Exiting from training early')
296 |
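297 |
298 | # Example invocation from the code/ directory (flags as defined in
299 | # parse_args above; --data_dir is only needed to override cfg.DATA_DIR):
300 | #   python pretrain_DAMSM.py --cfg cfg/DAMSM/bird.yml --gpu 0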
--------------------------------------------------------------------------------
/code/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from six.moves import range
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.autograd import Variable
8 | import torch.backends.cudnn as cudnn
9 |
10 | from PIL import Image
11 |
12 | from miscc.config import cfg
13 | from miscc.utils import mkdir_p
14 | from miscc.utils import build_super_images, build_super_images2
15 | from miscc.utils import weights_init, load_params, copy_G_params
16 | from model import G_DCGAN, G_NET
17 | from datasets import prepare_data
18 | from model import RNN_ENCODER, CNN_ENCODER
19 |
20 | from miscc.losses import words_loss
21 | from miscc.losses import discriminator_loss, generator_loss, KL_loss
22 | import os
23 | import time
24 | import numpy as np
25 | import sys
26 |
27 | # ################# Text to image task############################ #
28 | class condGANTrainer(object):
29 | def __init__(self, output_dir, data_loader, n_words, ixtoword):
30 | if cfg.TRAIN.FLAG:
31 | self.model_dir = os.path.join(output_dir, 'Model')
32 | self.image_dir = os.path.join(output_dir, 'Image')
33 | mkdir_p(self.model_dir)
34 | mkdir_p(self.image_dir)
35 |
36 | torch.cuda.set_device(cfg.GPU_ID)
37 | cudnn.benchmark = True
38 |
39 | self.batch_size = cfg.TRAIN.BATCH_SIZE
40 | self.max_epoch = cfg.TRAIN.MAX_EPOCH
41 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
42 |
43 | self.n_words = n_words
44 | self.ixtoword = ixtoword
45 | self.data_loader = data_loader
46 | self.num_batches = len(self.data_loader)
47 |
48 | def build_models(self):
49 | # ###################encoders######################################## #
50 | if cfg.TRAIN.NET_E == '':
51 | print('Error: no pretrained text-image encoders')
52 | return
53 |
54 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
55 | img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
56 | state_dict = \
57 | torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
58 | image_encoder.load_state_dict(state_dict)
59 | for p in image_encoder.parameters():
60 | p.requires_grad = False
61 | print('Load image encoder from:', img_encoder_path)
62 | image_encoder.eval()
63 |
64 | text_encoder = \
65 | RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
66 | state_dict = \
67 | torch.load(cfg.TRAIN.NET_E,
68 | map_location=lambda storage, loc: storage)
69 | text_encoder.load_state_dict(state_dict)
70 | for p in text_encoder.parameters():
71 | p.requires_grad = False
72 | print('Load text encoder from:', cfg.TRAIN.NET_E)
73 | text_encoder.eval()
74 |
75 | # #######################generator and discriminators############## #
76 | netsD = []
77 | if cfg.GAN.B_DCGAN:
78 | if cfg.TREE.BRANCH_NUM ==1:
79 | from model import D_NET64 as D_NET
80 | elif cfg.TREE.BRANCH_NUM == 2:
81 | from model import D_NET128 as D_NET
82 | else: # cfg.TREE.BRANCH_NUM == 3:
83 | from model import D_NET256 as D_NET
84 | # TODO: elif cfg.TREE.BRANCH_NUM > 3:
85 | netG = G_DCGAN()
86 | netsD = [D_NET(b_jcu=False)]
87 | else:
88 | from model import D_NET64, D_NET128, D_NET256
89 | netG = G_NET()
90 | if cfg.TREE.BRANCH_NUM > 0:
91 | netsD.append(D_NET64())
92 | if cfg.TREE.BRANCH_NUM > 1:
93 | netsD.append(D_NET128())
94 | if cfg.TREE.BRANCH_NUM > 2:
95 | netsD.append(D_NET256())
96 | # TODO: if cfg.TREE.BRANCH_NUM > 3:
97 | netG.apply(weights_init)
98 | # print(netG)
99 | for i in range(len(netsD)):
100 | netsD[i].apply(weights_init)
101 | # print(netsD[i])
102 | print('# of netsD', len(netsD))
103 | #
104 | epoch = 0
105 | if cfg.TRAIN.NET_G != '':
106 | state_dict = \
107 | torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
108 | netG.load_state_dict(state_dict)
109 | print('Load G from: ', cfg.TRAIN.NET_G)
110 | istart = cfg.TRAIN.NET_G.rfind('_') + 1
111 | iend = cfg.TRAIN.NET_G.rfind('.')
112 | epoch = cfg.TRAIN.NET_G[istart:iend]
113 | epoch = int(epoch) + 1
114 | if cfg.TRAIN.B_NET_D:
115 | Gname = cfg.TRAIN.NET_G
116 | for i in range(len(netsD)):
117 | s_tmp = Gname[:Gname.rfind('/')]
118 | Dname = '%s/netD%d.pth' % (s_tmp, i)
119 | print('Load D from: ', Dname)
120 | state_dict = \
121 | torch.load(Dname, map_location=lambda storage, loc: storage)
122 | netsD[i].load_state_dict(state_dict)
123 | # ########################################################### #
124 | if cfg.CUDA:
125 | text_encoder = text_encoder.cuda()
126 | image_encoder = image_encoder.cuda()
127 | netG.cuda()
128 | for i in range(len(netsD)):
129 | netsD[i].cuda()
130 | return [text_encoder, image_encoder, netG, netsD, epoch]
131 |
132 | def define_optimizers(self, netG, netsD):
133 | optimizersD = []
134 | num_Ds = len(netsD)
135 | for i in range(num_Ds):
136 | opt = optim.Adam(netsD[i].parameters(),
137 | lr=cfg.TRAIN.DISCRIMINATOR_LR,
138 | betas=(0.5, 0.999))
139 | optimizersD.append(opt)
140 |
141 | optimizerG = optim.Adam(netG.parameters(),
142 | lr=cfg.TRAIN.GENERATOR_LR,
143 | betas=(0.5, 0.999))
144 |
145 | return optimizerG, optimizersD
146 |
147 | def prepare_labels(self):
148 | batch_size = self.batch_size
149 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
150 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
151 | match_labels = Variable(torch.LongTensor(range(batch_size)))
152 | if cfg.CUDA:
153 | real_labels = real_labels.cuda()
154 | fake_labels = fake_labels.cuda()
155 | match_labels = match_labels.cuda()
156 |
157 | return real_labels, fake_labels, match_labels
158 |
159 | def save_model(self, netG, avg_param_G, netsD, epoch):
160 | backup_para = copy_G_params(netG)
161 | load_params(netG, avg_param_G)
162 | torch.save(netG.state_dict(),
163 | '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
164 | load_params(netG, backup_para)
165 | #
166 | for i in range(len(netsD)):
167 | netD = netsD[i]
168 | torch.save(netD.state_dict(),
169 | '%s/netD%d.pth' % (self.model_dir, i))
170 | print('Save G/Ds models.')
171 |
172 | def set_requires_grad_value(self, models_list, brequires):
173 | for i in range(len(models_list)):
174 | for p in models_list[i].parameters():
175 | p.requires_grad = brequires
176 |
177 | def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
178 | image_encoder, captions, cap_lens,
179 | gen_iterations, name='current'):
180 | # Save images
181 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
182 | for i in range(len(attention_maps)):
183 | if len(fake_imgs) > 1:
184 | img = fake_imgs[i + 1].detach().cpu()
185 | lr_img = fake_imgs[i].detach().cpu()
186 | else:
187 | img = fake_imgs[0].detach().cpu()
188 | lr_img = None
189 | attn_maps = attention_maps[i]
190 | att_sze = attn_maps.size(2)
191 | img_set, _ = \
192 | build_super_images(img, captions, self.ixtoword,
193 | attn_maps, att_sze, lr_imgs=lr_img)
194 | if img_set is not None:
195 | im = Image.fromarray(img_set)
196 | fullpath = '%s/G_%s_%d_%d.png'\
197 | % (self.image_dir, name, gen_iterations, i)
198 | im.save(fullpath)
199 |
200 | # for i in range(len(netsD)):
201 | i = -1
202 | img = fake_imgs[i].detach()
203 | region_features, _ = image_encoder(img)
204 | att_sze = region_features.size(2)
205 | _, _, att_maps = words_loss(region_features.detach(),
206 | words_embs.detach(),
207 | None, cap_lens,
208 | None, self.batch_size)
209 | img_set, _ = \
210 | build_super_images(fake_imgs[i].detach().cpu(),
211 | captions, self.ixtoword, att_maps, att_sze)
212 | if img_set is not None:
213 | im = Image.fromarray(img_set)
214 | fullpath = '%s/D_%s_%d.png'\
215 | % (self.image_dir, name, gen_iterations)
216 | im.save(fullpath)
217 |
218 | def train(self):
219 | text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
220 | avg_param_G = copy_G_params(netG)
221 | optimizerG, optimizersD = self.define_optimizers(netG, netsD)
222 | real_labels, fake_labels, match_labels = self.prepare_labels()
223 |
224 | batch_size = self.batch_size
225 | nz = cfg.GAN.Z_DIM
226 | noise = Variable(torch.FloatTensor(batch_size, nz))
227 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
228 | if cfg.CUDA:
229 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
230 |
231 | gen_iterations = 0
232 | # gen_iterations = start_epoch * self.num_batches
233 | for epoch in range(start_epoch, self.max_epoch):
234 | start_t = time.time()
235 |
236 | data_iter = iter(self.data_loader)
237 | step = 0
238 | while step < self.num_batches:
239 | # reset requires_grad to be trainable for all Ds
240 | # self.set_requires_grad_value(netsD, True)
241 |
242 | ######################################################
243 | # (1) Prepare training data and Compute text embeddings
244 | ######################################################
245 | data = data_iter.next()
246 | imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
247 |
248 | hidden = text_encoder.init_hidden(batch_size)
249 | # words_embs: batch_size x nef x seq_len
250 | # sent_emb: batch_size x nef
251 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
252 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
253 | mask = (captions == 0)
254 | num_words = words_embs.size(2)
255 | if mask.size(1) > num_words:
256 | mask = mask[:, :num_words]
257 |
258 | #######################################################
259 | # (2) Generate fake images
260 | ######################################################
261 | noise.data.normal_(0, 1)
262 | fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)
263 |
264 | #######################################################
265 | # (3) Update D network
266 | ######################################################
267 | errD_total = 0
268 | D_logs = ''
269 | for i in range(len(netsD)):
270 | netsD[i].zero_grad()
271 | errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
272 | sent_emb, real_labels, fake_labels)
273 | # backward and update parameters
274 | errD.backward()
275 | optimizersD[i].step()
276 | errD_total += errD
277 | D_logs += 'errD%d: %.2f ' % (i, errD.data[0])
278 |
279 | #######################################################
280 | # (4) Update G network: maximize log(D(G(z)))
281 | ######################################################
282 | # compute total loss for training G
283 | step += 1
284 | gen_iterations += 1
285 |
286 | # do not need to compute gradient for Ds
287 | # self.set_requires_grad_value(netsD, False)
288 | netG.zero_grad()
289 | errG_total, G_logs = \
290 | generator_loss(netsD, image_encoder, fake_imgs, real_labels,
291 | words_embs, sent_emb, match_labels, cap_lens, class_ids)
292 | kl_loss = KL_loss(mu, logvar)
293 | errG_total += kl_loss
294 | G_logs += 'kl_loss: %.2f ' % kl_loss.data[0]
295 | # backward and update parameters
296 | errG_total.backward()
297 | optimizerG.step()
298 | for p, avg_p in zip(netG.parameters(), avg_param_G):
299 | avg_p.mul_(0.999).add_(0.001, p.data)
300 |
301 | if gen_iterations % 100 == 0:
302 | print(D_logs + '\n' + G_logs)
303 | # save images
304 | if gen_iterations % 1000 == 0:
305 | backup_para = copy_G_params(netG)
306 | load_params(netG, avg_param_G)
307 | self.save_img_results(netG, fixed_noise, sent_emb,
308 | words_embs, mask, image_encoder,
309 | captions, cap_lens, epoch, name='average')
310 | load_params(netG, backup_para)
311 | #
312 | # self.save_img_results(netG, fixed_noise, sent_emb,
313 | # words_embs, mask, image_encoder,
314 | # captions, cap_lens,
315 | # epoch, name='current')
316 | end_t = time.time()
317 |
318 | print('''[%d/%d][%d]
319 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
320 | % (epoch, self.max_epoch, self.num_batches,
321 | errD_total.data[0], errG_total.data[0],
322 | end_t - start_t))
323 |
324 | if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: # and epoch != 0:
325 | self.save_model(netG, avg_param_G, netsD, epoch)
326 |
327 | self.save_model(netG, avg_param_G, netsD, self.max_epoch)
328 |
329 | def save_singleimages(self, images, filenames, save_dir,
330 | split_dir, sentenceID=0):
331 | for i in range(images.size(0)):
332 | s_tmp = '%s/single_samples/%s/%s' %\
333 | (save_dir, split_dir, filenames[i])
334 | folder = s_tmp[:s_tmp.rfind('/')]
335 | if not os.path.isdir(folder):
336 | print('Make a new folder: ', folder)
337 | mkdir_p(folder)
338 |
339 | fullpath = '%s_%d.jpg' % (s_tmp, sentenceID)
340 | # range from [-1, 1] to [0, 1]
341 | # img = (images[i] + 1.0) / 2
342 | img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
343 | # range from [0, 1] to [0, 255]
344 | ndarr = img.permute(1, 2, 0).data.cpu().numpy()
345 | im = Image.fromarray(ndarr)
346 | im.save(fullpath)
347 |
348 | def sampling(self, split_dir):
349 | if cfg.TRAIN.NET_G == '':
350 |             print('Error: the path for models is not found!')
351 | else:
352 | if split_dir == 'test':
353 | split_dir = 'valid'
354 | # Build and load the generator
355 | if cfg.GAN.B_DCGAN:
356 | netG = G_DCGAN()
357 | else:
358 | netG = G_NET()
359 | netG.apply(weights_init)
360 | netG.cuda()
361 | netG.eval()
362 | #
363 | text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
364 | state_dict = \
365 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
366 | text_encoder.load_state_dict(state_dict)
367 | print('Load text encoder from:', cfg.TRAIN.NET_E)
368 | text_encoder = text_encoder.cuda()
369 | text_encoder.eval()
370 |
371 | batch_size = self.batch_size
372 | nz = cfg.GAN.Z_DIM
373 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
374 | noise = noise.cuda()
375 |
376 | model_dir = cfg.TRAIN.NET_G
377 | state_dict = \
378 | torch.load(model_dir, map_location=lambda storage, loc: storage)
379 | # state_dict = torch.load(cfg.TRAIN.NET_G)
380 | netG.load_state_dict(state_dict)
381 | print('Load G from: ', model_dir)
382 |
383 | # the path to save generated images
384 | s_tmp = model_dir[:model_dir.rfind('.pth')]
385 | save_dir = '%s/%s' % (s_tmp, split_dir)
386 | mkdir_p(save_dir)
387 |
388 | cnt = 0
389 |
390 | for _ in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE):
391 | for step, data in enumerate(self.data_loader, 0):
392 | cnt += batch_size
393 | if step % 100 == 0:
394 | print('step: ', step)
395 | # if step > 50:
396 | # break
397 |
398 | imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
399 |
400 | hidden = text_encoder.init_hidden(batch_size)
401 | # words_embs: batch_size x nef x seq_len
402 | # sent_emb: batch_size x nef
403 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
404 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
405 | mask = (captions == 0)
406 | num_words = words_embs.size(2)
407 | if mask.size(1) > num_words:
408 | mask = mask[:, :num_words]
409 |
410 | #######################################################
411 | # (2) Generate fake images
412 | ######################################################
413 | noise.data.normal_(0, 1)
414 | fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
415 | for j in range(batch_size):
416 | s_tmp = '%s/single/%s' % (save_dir, keys[j])
417 | folder = s_tmp[:s_tmp.rfind('/')]
418 | if not os.path.isdir(folder):
419 | print('Make a new folder: ', folder)
420 | mkdir_p(folder)
421 | k = -1
422 | # for k in range(len(fake_imgs)):
423 | im = fake_imgs[k][j].data.cpu().numpy()
424 | # [-1, 1] --> [0, 255]
425 | im = (im + 1.0) * 127.5
426 | im = im.astype(np.uint8)
427 | im = np.transpose(im, (1, 2, 0))
428 | im = Image.fromarray(im)
429 | fullpath = '%s_s%d.png' % (s_tmp, k)
430 | im.save(fullpath)
431 |
432 | def gen_example(self, data_dic):
433 | if cfg.TRAIN.NET_G == '':
434 |             print('Error: the path for models is not found!')
435 | else:
436 | # Build and load the generator
437 | text_encoder = \
438 | RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
439 | state_dict = \
440 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
441 | text_encoder.load_state_dict(state_dict)
442 | print('Load text encoder from:', cfg.TRAIN.NET_E)
443 | text_encoder = text_encoder.cuda()
444 | text_encoder.eval()
445 |
446 | # the path to save generated images
447 | if cfg.GAN.B_DCGAN:
448 | netG = G_DCGAN()
449 | else:
450 | netG = G_NET()
451 | s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
452 | model_dir = cfg.TRAIN.NET_G
453 | state_dict = \
454 | torch.load(model_dir, map_location=lambda storage, loc: storage)
455 | netG.load_state_dict(state_dict)
456 | print('Load G from: ', model_dir)
457 | netG.cuda()
458 | netG.eval()
459 | for key in data_dic:
460 | save_dir = '%s/%s' % (s_tmp, key)
461 | mkdir_p(save_dir)
462 | captions, cap_lens, sorted_indices = data_dic[key]
463 |
464 | batch_size = captions.shape[0]
465 | nz = cfg.GAN.Z_DIM
466 | captions = Variable(torch.from_numpy(captions), volatile=True)
467 | cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
468 |
469 | captions = captions.cuda()
470 | cap_lens = cap_lens.cuda()
471 | for i in range(1): # 16
472 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
473 | noise = noise.cuda()
474 | #######################################################
475 | # (1) Extract text embeddings
476 | ######################################################
477 | hidden = text_encoder.init_hidden(batch_size)
478 | # words_embs: batch_size x nef x seq_len
479 | # sent_emb: batch_size x nef
480 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
481 | mask = (captions == 0)
482 | #######################################################
483 | # (2) Generate fake images
484 | ######################################################
485 | noise.data.normal_(0, 1)
486 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
487 | # G attention
488 | cap_lens_np = cap_lens.cpu().data.numpy()
489 | for j in range(batch_size):
490 | save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
491 | for k in range(len(fake_imgs)):
492 | im = fake_imgs[k][j].data.cpu().numpy()
493 | im = (im + 1.0) * 127.5
494 | im = im.astype(np.uint8)
495 | # print('im', im.shape)
496 | im = np.transpose(im, (1, 2, 0))
497 | # print('im', im.shape)
498 | im = Image.fromarray(im)
499 | fullpath = '%s_g%d.png' % (save_name, k)
500 | im.save(fullpath)
501 |
502 | for k in range(len(attention_maps)):
503 | if len(fake_imgs) > 1:
504 | im = fake_imgs[k + 1].detach().cpu()
505 | else:
506 | im = fake_imgs[0].detach().cpu()
507 | attn_maps = attention_maps[k]
508 | att_sze = attn_maps.size(2)
509 | img_set, sentences = \
510 | build_super_images2(im[j].unsqueeze(0),
511 | captions[j].unsqueeze(0),
512 | [cap_lens_np[j]], self.ixtoword,
513 | [attn_maps[j]], att_sze)
514 | if img_set is not None:
515 | im = Image.fromarray(img_set)
516 | fullpath = '%s_a%d.png' % (save_name, k)
517 | im.save(fullpath)
518 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/eval/FreeMono.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/FreeMono.ttf
--------------------------------------------------------------------------------
/eval/GlobalAttention.py:
--------------------------------------------------------------------------------
1 | """
2 | Global attention takes a matrix and a query vector.
3 | For each query vector q, it computes a parameterized convex combination
4 | of the rows of the matrix.
5 | H_1 H_2 H_3 ... H_n
6 | q q q q
7 | | | | |
8 | \ | | /
9 | .....
10 | \ | /
11 | a
12 | Constructs a unit mapping.
13 | $$(H_1, ..., H_n, q) => (a)$$
14 | Where H is of `batch x n x dim` and q is of `batch x dim`.
15 |
16 | References:
17 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules
18 | http://www.aclweb.org/anthology/D15-1166
19 | """
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 |
25 | def conv1x1(in_planes, out_planes):
26 |     "1x1 convolution"
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
28 | padding=0, bias=False)
29 |
30 |
31 | def func_attention(query, context, gamma1):
32 | """
33 | query: batch x ndf x queryL
34 | context: batch x ndf x ih x iw (sourceL=ihxiw)
35 | mask: batch_size x sourceL
36 | """
37 | batch_size, queryL = query.size(0), query.size(2)
38 | ih, iw = context.size(2), context.size(3)
39 | sourceL = ih * iw
40 |
41 | # --> batch x sourceL x ndf
42 | context = context.view(batch_size, -1, sourceL)
43 | contextT = torch.transpose(context, 1, 2).contiguous()
44 |
45 | # Get attention
46 | # (batch x sourceL x ndf)(batch x ndf x queryL)
47 | # -->batch x sourceL x queryL
48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper
49 | # --> batch*sourceL x queryL
50 | attn = attn.view(batch_size*sourceL, queryL)
51 | attn = nn.Softmax()(attn) # Eq. (8)
52 |
53 | # --> batch x sourceL x queryL
54 | attn = attn.view(batch_size, sourceL, queryL)
55 | # --> batch*queryL x sourceL
56 | attn = torch.transpose(attn, 1, 2).contiguous()
57 | attn = attn.view(batch_size*queryL, sourceL)
58 | # Eq. (9)
59 | attn = attn * gamma1
60 | attn = nn.Softmax()(attn)
61 | attn = attn.view(batch_size, queryL, sourceL)
62 | # --> batch x sourceL x queryL
63 | attnT = torch.transpose(attn, 1, 2).contiguous()
64 |
65 | # (batch x ndf x sourceL)(batch x sourceL x queryL)
66 | # --> batch x ndf x queryL
67 | weightedContext = torch.bmm(context, attnT)
68 |
69 | return weightedContext, attn.view(batch_size, -1, ih, iw)
70 |
71 |
72 | class GlobalAttentionGeneral(nn.Module):
73 | def __init__(self, idf, cdf):
74 | super(GlobalAttentionGeneral, self).__init__()
75 | self.conv_context = conv1x1(cdf, idf)
76 | self.sm = nn.Softmax()
77 | self.mask = None
78 |
79 | def applyMask(self, mask):
80 | self.mask = mask # batch x sourceL
81 |
82 | def forward(self, input, context):
83 | """
84 | input: batch x idf x ih x iw (queryL=ihxiw)
85 | context: batch x cdf x sourceL
86 | """
87 | ih, iw = input.size(2), input.size(3)
88 | queryL = ih * iw
89 | batch_size, sourceL = context.size(0), context.size(2)
90 |
91 | # --> batch x queryL x idf
92 | target = input.view(batch_size, -1, queryL)
93 | targetT = torch.transpose(target, 1, 2).contiguous()
94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1
95 | sourceT = context.unsqueeze(3)
96 | # --> batch x idf x sourceL
97 | sourceT = self.conv_context(sourceT).squeeze(3)
98 |
99 | # Get attention
100 | # (batch x queryL x idf)(batch x idf x sourceL)
101 | # -->batch x queryL x sourceL
102 | attn = torch.bmm(targetT, sourceT)
103 | # --> batch*queryL x sourceL
104 | attn = attn.view(batch_size*queryL, sourceL)
105 | if self.mask is not None:
106 | # batch_size x sourceL --> batch_size*queryL x sourceL
107 | mask = self.mask.repeat(queryL, 1)
108 | attn.data.masked_fill_(mask.data, -float('inf'))
109 | attn = self.sm(attn) # Eq. (2)
110 | # --> batch x queryL x sourceL
111 | attn = attn.view(batch_size, queryL, sourceL)
112 | # --> batch x sourceL x queryL
113 | attn = torch.transpose(attn, 1, 2).contiguous()
114 |
115 | # (batch x idf x sourceL)(batch x sourceL x queryL)
116 | # --> batch x idf x queryL
117 | weightedContext = torch.bmm(sourceT, attn)
118 | weightedContext = weightedContext.view(batch_size, -1, ih, iw)
119 | attn = attn.view(batch_size, -1, ih, iw)
120 |
121 | return weightedContext, attn
122 |
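123 | # A minimal usage sketch (shapes follow the forward() docstring above; the
124 | # sizes idf=32, cdf=256, 17x17 regions and 12 words are arbitrary choices):
125 | #
126 | #   att = GlobalAttentionGeneral(idf=32, cdf=256)
127 | #   img_feats = torch.autograd.Variable(torch.randn(4, 32, 17, 17))   # batch x idf x ih x iw
128 | #   word_feats = torch.autograd.Variable(torch.randn(4, 256, 12))     # batch x cdf x sourceL
129 | #   weighted, attn = att(img_feats, word_feats)
130 | #   # weighted: 4 x 32 x 17 x 17,  attn: 4 x 12 x 17 x 17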
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
1 | # AttnGAN Eval API
2 | The model evaluation code is extracted here to provide a separate inference mode for the project. It is embedded in a Flask app that accepts API requests.
3 | There are two docker files:
4 | 1. [dockerfile.cpu](dockerfile.cpu) - creates a CPU-bound container
5 | 2. [dockerfile.gpu](dockerfile.gpu) - creates a GPU-bound container
6 |
7 | # Requirements
8 | The app uses Azure Blob Storage as an image repository as well as Application Insights for logging telemetry.
9 |
10 | # Running the Application
11 | Running the application and generating bird images is a three-step process.
12 | 1. Build the container (choosing either the CPU or the GPU dockerfile):
13 | ```
14 | docker build -t "attngan" -f dockerfile.cpu .
15 | ```
16 | 2. Run the container (replace the keys with your Blob Storage key and Application Insights key):
17 | ```
18 | docker run -it -e BLOB_KEY=KEY -e TELEMETRY=TELEMETRY_KEY -p 5678:8080 attngan
19 | ```
20 | 3. Call the API:
21 | ```
22 | curl -H "Content-Type: application/json" -X POST -d '{"caption":"the bird has a yellow crown and a black eyering that is round"}' http://localhost:5678/api/v1.0/bird
23 | ```
24 |
25 | # Images
26 | You should have your very own image generator.
27 |
28 |
29 |
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
--------------------------------------------------------------------------------
/eval/data/bird_AttnGAN2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/data/bird_AttnGAN2.pth
--------------------------------------------------------------------------------
/eval/data/captions.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/data/captions.pickle
--------------------------------------------------------------------------------
/eval/data/text_encoder200.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/eval/data/text_encoder200.pth
--------------------------------------------------------------------------------
/eval/dockerfile.cpu:
--------------------------------------------------------------------------------
1 | FROM python:2
2 |
3 | RUN mkdir -p /usr/src/app
4 | WORKDIR /usr/src/app
5 |
6 | COPY requirements.txt /usr/src/app/
7 | RUN pip install --upgrade pip
8 | RUN pip install --no-cache-dir -r requirements.txt
9 | RUN pip install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
10 | RUN pip install torchvision
11 |
12 | COPY . /usr/src/app
13 |
14 | ENV GPU False
15 | ENV EXPORT_MODEL True
16 |
17 | EXPOSE 8080
18 |
19 | CMD ["python", "main.py"]
20 |
21 |
22 |
--------------------------------------------------------------------------------
/eval/dockerfile.gpu:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda
2 |
3 | RUN apt-get update \
4 | && apt-get upgrade -y \
5 | && apt-get install -y \
6 | python-pip \
7 | python2.7 \
8 | && apt-get autoremove \
9 | && apt-get clean
10 |
11 | RUN mkdir -p /usr/src/app
12 | WORKDIR /usr/src/app
13 |
14 | COPY requirements.txt /usr/src/app/
15 | RUN pip install --upgrade pip
16 | RUN pip install --no-cache-dir -r requirements.txt
17 | RUN pip install http://download.pytorch.org/whl/cu90/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
18 | RUN pip install torchvision
19 |
20 | COPY . /usr/src/app
21 |
22 | ENV NVIDIA_VISIBLE_DEVICES all
23 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
24 | ENV GPU True
25 | ENV EXPORT_MODEL False
26 |
27 | EXPOSE 8080
28 |
29 | CMD ["python", "main.py"]
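30 |
31 | # Usage note (an assumption about the deployment environment, not part of the
32 | # original build): on a GPU host this image is typically started with the NVIDIA
33 | # container runtime, e.g.
34 | #   docker run --runtime=nvidia -e BLOB_KEY=KEY -e TELEMETRY=TELEMETRY_KEY -p 5678:8080 attngan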
--------------------------------------------------------------------------------
/eval/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import torch
6 | import io
7 | import time
8 | import numpy as np
9 | from PIL import Image
10 | import torch.onnx
11 | from datetime import datetime
12 | from torch.autograd import Variable
13 | from miscc.config import cfg
14 | from miscc.utils import build_super_images2
15 | from model import RNN_ENCODER, G_NET
16 | from azure.storage.blob import BlockBlobService
17 |
18 | if sys.version_info[0] == 2:
19 | import cPickle as pickle
20 | else:
21 | import pickle
22 |
23 | from werkzeug.contrib.cache import SimpleCache
24 | cache = SimpleCache()
25 |
26 | def vectorize_caption(wordtoix, caption, copies=2):
27 | # create caption vector
28 | tokens = caption.split(' ')
29 | cap_v = []
30 | for t in tokens:
31 | t = t.strip().encode('ascii', 'ignore').decode('ascii')
32 | if len(t) > 0 and t in wordtoix:
33 | cap_v.append(wordtoix[t])
34 |
35 | # expected state for single generation
36 | captions = np.zeros((copies, len(cap_v)))
37 | for i in range(copies):
38 | captions[i,:] = np.array(cap_v)
39 | cap_lens = np.zeros(copies) + len(cap_v)
40 |
41 | #print(captions.astype(int), cap_lens.astype(int))
42 | #captions, cap_lens = np.array([cap_v, cap_v]), np.array([len(cap_v), len(cap_v)])
43 | #print(captions, cap_lens)
44 | #return captions, cap_lens
45 |
46 | return captions.astype(int), cap_lens.astype(int)
47 |
48 | def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copies=2):
49 | # load word vector
50 | captions, cap_lens = vectorize_caption(wordtoix, caption, copies)
51 | n_words = len(wordtoix)
52 |
53 | # only one to generate
54 | batch_size = captions.shape[0]
55 |
56 | nz = cfg.GAN.Z_DIM
57 | captions = Variable(torch.from_numpy(captions), volatile=True)
58 | cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
59 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
60 |
61 | if cfg.CUDA:
62 | captions = captions.cuda()
63 | cap_lens = cap_lens.cuda()
64 | noise = noise.cuda()
65 |
66 |
67 |
68 | #######################################################
69 | # (1) Extract text embeddings
70 | #######################################################
71 | hidden = text_encoder.init_hidden(batch_size)
72 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
73 | mask = (captions == 0)
74 |
75 |
76 | #######################################################
77 | # (2) Generate fake images
78 | #######################################################
79 | noise.data.normal_(0, 1)
80 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
81 |
82 | # ONNX EXPORT
83 | #export = os.environ["EXPORT_MODEL"].lower() == 'true'
84 | if False:
85 | print("saving text_encoder.onnx")
86 | text_encoder_out = torch.onnx._export(text_encoder, (captions, cap_lens, hidden), "text_encoder.onnx", export_params=True)
87 | print("uploading text_encoder.onnx")
88 | blob_service.create_blob_from_path('models', "text_encoder.onnx", os.path.abspath("text_encoder.onnx"))
89 | print("done")
90 |
91 | print("saving netg.onnx")
92 | netg_out = torch.onnx._export(netG, (noise, sent_emb, words_embs, mask), "netg.onnx", export_params=True)
93 | print("uploading netg.onnx")
94 | blob_service.create_blob_from_path('models', "netg.onnx", os.path.abspath("netg.onnx"))
95 | print("done")
96 | return
97 |
98 | # G attention
99 | cap_lens_np = cap_lens.cpu().data.numpy()
100 |
101 | # storing to blob storage
102 | container_name = "images"
103 | full_path = "https://attgan.blob.core.windows.net/images/%s"
104 | prefix = datetime.now().strftime('%Y/%B/%d/%H_%M_%S_%f')
105 | urls = []
106 | # only look at first one
107 | #j = 0
108 | for j in range(batch_size):
109 | for k in range(len(fake_imgs)):
110 | im = fake_imgs[k][j].data.cpu().numpy()
111 | im = (im + 1.0) * 127.5
112 | im = im.astype(np.uint8)
113 | im = np.transpose(im, (1, 2, 0))
114 | im = Image.fromarray(im)
115 |
116 | # save image to stream
117 | stream = io.BytesIO()
118 | im.save(stream, format="png")
119 | stream.seek(0)
120 | if copies > 2:
121 | blob_name = '%s/%d/%s_g%d.png' % (prefix, j, "bird", k)
122 | else:
123 | blob_name = '%s/%s_g%d.png' % (prefix, "bird", k)
124 | blob_service.create_blob_from_stream(container_name, blob_name, stream)
125 | urls.append(full_path % blob_name)
126 |
127 | if copies == 2:
128 | for k in range(len(attention_maps)):
129 | #if False:
130 | if len(fake_imgs) > 1:
131 | im = fake_imgs[k + 1].detach().cpu()
132 | else:
133 | im = fake_imgs[0].detach().cpu()
134 |
135 | attn_maps = attention_maps[k]
136 | att_sze = attn_maps.size(2)
137 |
138 | img_set, sentences = \
139 | build_super_images2(im[j].unsqueeze(0),
140 | captions[j].unsqueeze(0),
141 | [cap_lens_np[j]], ixtoword,
142 | [attn_maps[j]], att_sze)
143 |
144 | if img_set is not None:
145 | im = Image.fromarray(img_set)
146 | stream = io.BytesIO()
147 | im.save(stream, format="png")
148 | stream.seek(0)
149 |
150 | blob_name = '%s/%s_a%d.png' % (prefix, "attmaps", k)
151 | blob_service.create_blob_from_stream(container_name, blob_name, stream)
152 | urls.append(full_path % blob_name)
153 | if copies == 2:
154 | break
155 |
156 | #print(len(urls), urls)
157 | return urls
158 |
159 | def word_index():
160 | ixtoword = cache.get('ixtoword')
161 | wordtoix = cache.get('wordtoix')
162 | if ixtoword is None or wordtoix is None:
163 | #print("ix and word not cached")
164 | # load word to index dictionary
165 | x = pickle.load(open('data/captions.pickle', 'rb'))
166 | ixtoword = x[2]
167 | wordtoix = x[3]
168 | del x
169 | cache.set('ixtoword', ixtoword, timeout=60 * 60 * 24)
170 | cache.set('wordtoix', wordtoix, timeout=60 * 60 * 24)
171 |
172 | return wordtoix, ixtoword
173 |
174 | def models(word_len):
175 | #print(word_len)
176 | text_encoder = cache.get('text_encoder')
177 | if text_encoder is None:
178 | #print("text_encoder not cached")
179 | text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
180 | state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
181 | text_encoder.load_state_dict(state_dict)
182 | if cfg.CUDA:
183 | text_encoder.cuda()
184 | text_encoder.eval()
185 | cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)
186 |
187 | netG = cache.get('netG')
188 | if netG is None:
189 | #print("netG not cached")
190 | netG = G_NET()
191 | state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
192 | netG.load_state_dict(state_dict)
193 | if cfg.CUDA:
194 | netG.cuda()
195 | netG.eval()
196 | cache.set('netG', netG, timeout=60 * 60 * 24)
197 |
198 | return text_encoder, netG
199 |
200 | def eval(caption):
201 | # load word dictionaries
202 | wordtoix, ixtoword = word_index()
203 | # load models
204 | text_encoder, netG = models(len(wordtoix))
205 | # load blob service
206 | blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"])
207 |
208 | t0 = time.time()
209 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
210 | t1 = time.time()
211 |
212 | response = {
213 | 'small': urls[0],
214 | 'medium': urls[1],
215 | 'large': urls[2],
216 | 'map1': urls[3],
217 | 'map2': urls[4],
218 | 'caption': caption,
219 | 'elapsed': t1 - t0
220 | }
221 |
222 | return response
223 |
224 | if __name__ == "__main__":
225 | caption = "the bird has a yellow crown and a black eyering that is round"
226 |
227 | # load configuration
228 | #cfg_from_file('eval_bird.yml')
229 | # load word dictionaries
230 | wordtoix, ixtoword = word_index()
231 | # load models
232 | text_encoder, netG = models(len(wordtoix))
233 | # load blob service
234 | blob_service = BlockBlobService(account_name='attgan', account_key='[REDACTED]')
235 |
236 | t0 = time.time()
237 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
238 | t1 = time.time()
239 | print(t1-t0)
240 | print(urls)
--------------------------------------------------------------------------------
/eval/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | from eval import *
5 | from flask import Flask, jsonify, request, abort
6 | from applicationinsights import TelemetryClient
7 | from applicationinsights.requests import WSGIApplication
8 | from applicationinsights.exceptions import enable
9 | from miscc.config import cfg
10 | #from werkzeug.contrib.profiler import ProfilerMiddleware
11 |
12 | enable(os.environ["TELEMETRY"])
13 | app = Flask(__name__)
14 | app.wsgi_app = WSGIApplication(os.environ["TELEMETRY"], app.wsgi_app)
15 |
16 | @app.route('/api/v1.0/bird', methods=['POST'])
17 | def create_bird():
18 | if not request.json or not 'caption' in request.json:
19 | abort(400)
20 |
21 | caption = request.json['caption']
22 |
23 | t0 = time.time()
24 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
25 | t1 = time.time()
26 |
27 | response = {
28 | 'small': urls[0],
29 | 'medium': urls[1],
30 | 'large': urls[2],
31 | 'map1': urls[3],
32 | 'map2': urls[4],
33 | 'caption': caption,
34 | 'elapsed': t1 - t0
35 | }
36 | return jsonify({'bird': response}), 201
37 |
38 | @app.route('/api/v1.0/birds', methods=['POST'])
39 | def create_birds():
40 | if not request.json or not 'caption' in request.json:
41 | abort(400)
42 |
43 | caption = request.json['caption']
44 |
45 | t0 = time.time()
46 | urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copies=6)
47 | t1 = time.time()
48 |
49 | response = {
50 | 'bird1' : { 'small': urls[0], 'medium': urls[1], 'large': urls[2] },
51 | 'bird2' : { 'small': urls[3], 'medium': urls[4], 'large': urls[5] },
52 | 'bird3' : { 'small': urls[6], 'medium': urls[7], 'large': urls[8] },
53 | 'bird4' : { 'small': urls[9], 'medium': urls[10], 'large': urls[11] },
54 | 'bird5' : { 'small': urls[12], 'medium': urls[13], 'large': urls[14] },
55 | 'bird6' : { 'small': urls[15], 'medium': urls[16], 'large': urls[17] },
56 | 'caption': caption,
57 | 'elapsed': t1 - t0
58 | }
59 | return jsonify({'bird': response}), 201
60 |
61 | @app.route('/', methods=['GET'])
62 | def get_bird():
63 | return 'Version 1'
64 |
65 | if __name__ == '__main__':
66 | t0 = time.time()
67 | tc = TelemetryClient(os.environ["TELEMETRY"])
68 |
69 | # gpu based
70 | cfg.CUDA = os.environ["GPU"].lower() == 'true'
71 | tc.track_event('container initializing', {"CUDA": str(cfg.CUDA)})
72 |
73 | # load word dictionaries
74 | wordtoix, ixtoword = word_index()
75 | # load models
76 | text_encoder, netG = models(len(wordtoix))
77 | # load blob service
78 | blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"])
79 |
80 | seed = 100
81 | random.seed(seed)
82 | np.random.seed(seed)
83 | torch.manual_seed(seed)
84 | if cfg.CUDA:
85 | torch.cuda.manual_seed_all(seed)
86 |
87 | #app.config['PROFILE'] = True
88 | #app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30])
89 | #app.run(host='0.0.0.0', port=8080, debug = True)
90 |
91 | t1 = time.time()
92 | tc.track_event('container start', {"starttime": str(t1-t0)})
93 | app.run(host='0.0.0.0', port=8080)
94 |
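95 | # Environment variables consumed above: TELEMETRY (Application Insights
96 | # instrumentation key), BLOB_KEY (Azure storage account key) and GPU
97 | # ("true"/"false", which toggles cfg.CUDA). All three must be set before the app starts.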
--------------------------------------------------------------------------------
/eval/miscc/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
--------------------------------------------------------------------------------
/eval/miscc/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import os.path as osp
5 | import numpy as np
6 | from easydict import EasyDict as edict
7 |
8 |
9 | __C = edict()
10 | cfg = __C
11 |
12 | # Dataset name: flowers, birds
13 | __C.DATASET_NAME = 'birds'
14 | __C.CONFIG_NAME = 'attn2'
15 | __C.DATA_DIR = ''
16 | __C.GPU_ID = 0
17 | __C.CUDA = False
18 | __C.WORKERS = 1
19 |
20 | __C.RNN_TYPE = 'LSTM' # 'GRU'
21 | __C.B_VALIDATION = False
22 |
23 | __C.TREE = edict()
24 | __C.TREE.BRANCH_NUM = 3
25 | __C.TREE.BASE_SIZE = 64
26 |
27 |
28 | # Training options
29 | __C.TRAIN = edict()
30 | __C.TRAIN.BATCH_SIZE = 64
31 | __C.TRAIN.MAX_EPOCH = 600
32 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000
33 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4
34 | __C.TRAIN.GENERATOR_LR = 2e-4
35 | __C.TRAIN.ENCODER_LR = 2e-4
36 | __C.TRAIN.RNN_GRAD_CLIP = 0.25
37 | __C.TRAIN.FLAG = False
38 | __C.TRAIN.NET_E = 'data/text_encoder200.pth'
39 | __C.TRAIN.NET_G = 'data/bird_AttnGAN2.pth'
40 | __C.TRAIN.B_NET_D = False
41 |
42 | __C.TRAIN.SMOOTH = edict()
43 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0
44 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0
45 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0
46 | __C.TRAIN.SMOOTH.LAMBDA = 1.0
47 |
48 |
49 | # Model options
50 | __C.GAN = edict()
51 | __C.GAN.DF_DIM = 64
52 | __C.GAN.GF_DIM = 32
53 | __C.GAN.Z_DIM = 100
54 | __C.GAN.CONDITION_DIM = 100
55 | __C.GAN.R_NUM = 2
56 | __C.GAN.B_ATTENTION = True
57 | __C.GAN.B_DCGAN = False
58 |
59 |
60 | __C.TEXT = edict()
61 | __C.TEXT.CAPTIONS_PER_IMAGE = 10
62 | __C.TEXT.EMBEDDING_DIM = 256
63 | __C.TEXT.WORDS_NUM = 25
64 |
65 |
--------------------------------------------------------------------------------
/eval/miscc/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import numpy as np
4 | from torch.nn import init
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from PIL import Image, ImageDraw, ImageFont
10 | from copy import deepcopy
11 | import skimage.transform
12 |
13 |
14 |
15 | # For visualization ################################################
16 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232],
17 | 2:[70, 70, 70], 3:[102,102,156],
18 | 4:[190,153,153], 5:[153,153,153],
19 | 6:[250,170, 30], 7:[220, 220, 0],
20 | 8:[107,142, 35], 9:[152,251,152],
21 | 10:[70,130,180], 11:[220,20, 60],
22 | 12:[255, 0, 0], 13:[0, 0, 142],
23 | 14:[119,11, 32], 15:[0, 60,100],
24 | 16:[0, 80, 100], 17:[0, 0, 230],
25 | 18:[0, 0, 70], 19:[0, 0, 0]}
26 | FONT_MAX = 50
27 |
28 |
29 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2):
30 | num = captions.size(0)
31 | img_txt = Image.fromarray(convas)
32 | # get a font
33 | # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
34 | #fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
35 | fnt = ImageFont.truetype('FreeMono.ttf', 50)
36 | # get a drawing context
37 | d = ImageDraw.Draw(img_txt)
38 | sentence_list = []
39 | for i in range(num):
40 | cap = captions[i].data.cpu().numpy()
41 | sentence = []
42 | for j in range(len(cap)):
43 | if cap[j] == 0:
44 | break
45 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
46 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]),
47 | font=fnt, fill=(255, 255, 255, 255))
48 | sentence.append(word)
49 | sentence_list.append(sentence)
50 | return img_txt, sentence_list
51 |
52 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
53 | attn_maps, att_sze, vis_size=256, topK=5):
54 | batch_size = real_imgs.size(0)
55 | max_word_num = np.max(cap_lens)
56 | text_convas = np.ones([batch_size * FONT_MAX,
57 | max_word_num * (vis_size + 2), 3],
58 | dtype=np.uint8)
59 |
60 | real_imgs = \
61 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
62 | # [-1, 1] --> [0, 1]
63 | real_imgs.add_(1).div_(2).mul_(255)
64 | real_imgs = real_imgs.data.numpy()
65 | # b x c x h x w --> b x h x w x c
66 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
67 | pad_sze = real_imgs.shape
68 | middle_pad = np.zeros([pad_sze[2], 2, 3])
69 |
70 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17
71 | img_set = []
72 | num = len(attn_maps)
73 |
74 | text_map, sentences = \
75 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0)
76 | text_map = np.asarray(text_map).astype(np.uint8)
77 |
78 | bUpdate = 1
79 | for i in range(num):
80 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
81 | #
82 | attn = attn.view(-1, 1, att_sze, att_sze)
83 | attn = attn.repeat(1, 3, 1, 1).data.numpy()
84 | # n x c x h x w --> n x h x w x c
85 | attn = np.transpose(attn, (0, 2, 3, 1))
86 | num_attn = cap_lens[i]
87 | thresh = 2./float(num_attn)
88 | #
89 | img = real_imgs[i]
90 | row = []
91 | row_merge = []
92 | row_txt = []
93 | row_beforeNorm = []
94 | conf_score = []
95 | for j in range(num_attn):
96 | one_map = attn[j]
97 | mask0 = one_map > (2. * thresh)
98 | conf_score.append(np.sum(one_map * mask0))
99 | mask = one_map > thresh
100 | one_map = one_map * mask
101 | if (vis_size // att_sze) > 1:
102 | one_map = \
103 | skimage.transform.pyramid_expand(one_map, sigma=20,
104 | upscale=vis_size // att_sze)
105 | minV = one_map.min()
106 | maxV = one_map.max()
107 | one_map = (one_map - minV) / (maxV - minV)
108 | row_beforeNorm.append(one_map)
109 | sorted_indices = np.argsort(conf_score)[::-1]
110 |
111 | for j in range(num_attn):
112 | one_map = row_beforeNorm[j]
113 | one_map *= 255
114 | #
115 | PIL_im = Image.fromarray(np.uint8(img))
116 | PIL_att = Image.fromarray(np.uint8(one_map))
117 | merged = \
118 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
119 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210)
120 | merged.paste(PIL_im, (0, 0))
121 | merged.paste(PIL_att, (0, 0), mask)
122 | merged = np.array(merged)[:, :, :3]
123 |
124 | row.append(np.concatenate([one_map, middle_pad], 1))
125 | #
126 | row_merge.append(np.concatenate([merged, middle_pad], 1))
127 | #
128 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX,
129 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :]
130 | row_txt.append(txt)
131 | # reorder
132 | row_new = []
133 | row_merge_new = []
134 | txt_new = []
135 | for j in range(num_attn):
136 | idx = sorted_indices[j]
137 | row_new.append(row[idx])
138 | row_merge_new.append(row_merge[idx])
139 | txt_new.append(row_txt[idx])
140 | row = np.concatenate(row_new[:topK], 1)
141 | row_merge = np.concatenate(row_merge_new[:topK], 1)
142 | txt = np.concatenate(txt_new[:topK], 1)
143 | if txt.shape[1] != row.shape[1]:
144 | print('Warnings: txt', txt.shape, 'row', row.shape,
145 | 'row_merge_new', row_merge_new.shape)
146 | bUpdate = 0
147 | break
148 | row = np.concatenate([txt, row_merge], 0)
149 | img_set.append(row)
150 | if bUpdate:
151 | img_set = np.concatenate(img_set, 0)
152 | img_set = img_set.astype(np.uint8)
153 | return img_set, sentences
154 | else:
155 | return None, sentences  # return a 2-tuple so callers can still unpack (img_set, sentences)
156 |
157 |
158 | ####################################################################
159 | def weights_init(m):
160 | classname = m.__class__.__name__
161 | if classname.find('Conv') != -1:
162 | nn.init.orthogonal(m.weight.data, 1.0)
163 | elif classname.find('BatchNorm') != -1:
164 | m.weight.data.normal_(1.0, 0.02)
165 | m.bias.data.fill_(0)
166 | elif classname.find('Linear') != -1:
167 | nn.init.orthogonal(m.weight.data, 1.0)
168 | if m.bias is not None:
169 | m.bias.data.fill_(0.0)
170 |
171 |
172 | def load_params(model, new_param):
173 | for p, new_p in zip(model.parameters(), new_param):
174 | p.data.copy_(new_p)
175 |
176 |
177 | def copy_G_params(model):
178 | flatten = deepcopy(list(p.data for p in model.parameters()))
179 | return flatten
180 |
181 |
182 | def mkdir_p(path):
183 | try:
184 | os.makedirs(path)
185 | except OSError as exc: # Python >2.5
186 | if exc.errno == errno.EEXIST and os.path.isdir(path):
187 | pass
188 | else:
189 | raise
190 |
--------------------------------------------------------------------------------
/eval/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | from torch.autograd import Variable
5 | from torchvision import models
6 | import torch.utils.model_zoo as model_zoo
7 | import torch.nn.functional as F
8 |
9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
10 |
11 | from miscc.config import cfg
12 | from GlobalAttention import GlobalAttentionGeneral as ATT_NET
13 |
14 |
15 | # ############## Text2Image Encoder-Decoder #######
16 | class RNN_ENCODER(nn.Module):
17 | def __init__(self, ntoken, ninput=300, drop_prob=0.5,
18 | nhidden=128, nlayers=1, bidirectional=True):
19 | super(RNN_ENCODER, self).__init__()
20 |
21 | self.n_steps = cfg.TEXT.WORDS_NUM
22 | self.rnn_type = cfg.RNN_TYPE
23 |
24 | self.ntoken = ntoken # size of the dictionary
25 | self.ninput = ninput # size of each embedding vector
26 | self.drop_prob = drop_prob # probability of an element to be zeroed
27 | self.nlayers = nlayers # Number of recurrent layers
28 | self.bidirectional = bidirectional
29 |
30 | if bidirectional:
31 | self.num_directions = 2
32 | else:
33 | self.num_directions = 1
34 | # number of features in the hidden state
35 | self.nhidden = nhidden // self.num_directions
36 |
37 | self.define_module()
38 | self.init_weights()
39 |
40 | def define_module(self):
41 | self.encoder = nn.Embedding(self.ntoken, self.ninput)
42 | self.drop = nn.Dropout(self.drop_prob)
43 | if self.rnn_type == 'LSTM':
44 | # dropout: If non-zero, introduces a dropout layer on
45 | # the outputs of each RNN layer except the last layer
46 | self.rnn = nn.LSTM(self.ninput, self.nhidden,
47 | self.nlayers, batch_first=True,
48 | dropout=self.drop_prob,
49 | bidirectional=self.bidirectional)
50 | elif self.rnn_type == 'GRU':
51 | self.rnn = nn.GRU(self.ninput, self.nhidden,
52 | self.nlayers, batch_first=True,
53 | dropout=self.drop_prob,
54 | bidirectional=self.bidirectional)
55 | else:
56 | raise NotImplementedError
57 |
58 | def init_weights(self):
59 | initrange = 0.1
60 | self.encoder.weight.data.uniform_(-initrange, initrange)
61 | # Do not need to initialize RNN parameters, which have been initialized
62 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM
63 | # self.decoder.weight.data.uniform_(-initrange, initrange)
64 | # self.decoder.bias.data.fill_(0)
65 |
66 | def init_hidden(self, bsz):
67 | weight = next(self.parameters()).data
68 | if self.rnn_type == 'LSTM':
69 | return (Variable(weight.new(self.nlayers * self.num_directions,
70 | bsz, self.nhidden).zero_()),
71 | Variable(weight.new(self.nlayers * self.num_directions,
72 | bsz, self.nhidden).zero_()))
73 | else:
74 | return Variable(weight.new(self.nlayers * self.num_directions,
75 | bsz, self.nhidden).zero_())
76 |
77 | def forward(self, captions, cap_lens, hidden, mask=None):
78 | # input: torch.LongTensor of size batch x n_steps
79 | # --> emb: batch x n_steps x ninput
80 | emb = self.drop(self.encoder(captions))
81 | #
82 | # Returns: a PackedSequence object
83 | cap_lens = cap_lens.data.tolist()
84 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True)
85 | # #hidden and memory (num_layers * num_directions, batch, hidden_size):
86 | # tensor containing the initial hidden state for each element in batch.
87 | # #output (batch, seq_len, hidden_size * num_directions)
88 | # #or a PackedSequence object:
89 | # tensor containing output features (h_t) from the last layer of RNN
90 | output, hidden = self.rnn(emb, hidden)
91 | # PackedSequence object
92 | # --> (batch, seq_len, hidden_size * num_directions)
93 | output = pad_packed_sequence(output, batch_first=True)[0]
94 | # output = self.drop(output)
95 | # --> batch x hidden_size*num_directions x seq_len
96 | words_emb = output.transpose(1, 2)
97 | # --> batch x num_directions*hidden_size
98 | if self.rnn_type == 'LSTM':
99 | sent_emb = hidden[0].transpose(0, 1).contiguous()
100 | else:
101 | sent_emb = hidden.transpose(0, 1).contiguous()
102 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)
103 | return words_emb, sent_emb
104 |
105 |
106 |
107 |
108 | class G_NET(nn.Module):
109 | def __init__(self):
110 | super(G_NET, self).__init__()
111 |
112 | ngf = cfg.GAN.GF_DIM
113 | nef = cfg.TEXT.EMBEDDING_DIM
114 | ncf = cfg.GAN.CONDITION_DIM
115 |
116 |
117 | self.ca_net = CA_NET()
118 |
119 | if cfg.TREE.BRANCH_NUM > 0:
120 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
121 | self.img_net1 = GET_IMAGE_G(ngf)
122 | # gf x 64 x 64
123 | if cfg.TREE.BRANCH_NUM > 1:
124 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
125 | self.img_net2 = GET_IMAGE_G(ngf)
126 | if cfg.TREE.BRANCH_NUM > 2:
127 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)
128 | self.img_net3 = GET_IMAGE_G(ngf)
129 |
130 | def forward(self, z_code, sent_emb, word_embs, mask):
131 | """
132 | :param z_code: batch x cfg.GAN.Z_DIM
133 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
134 | :param word_embs: batch x cdf x seq_len
135 | :param mask: batch x seq_len
136 | :return:
137 | """
138 | fake_imgs = []
139 | att_maps = []
140 | c_code, mu, logvar = self.ca_net(sent_emb)
141 |
142 | if cfg.TREE.BRANCH_NUM > 0:
143 | h_code1 = self.h_net1(z_code, c_code)
144 | fake_img1 = self.img_net1(h_code1)
145 | fake_imgs.append(fake_img1)
146 | if cfg.TREE.BRANCH_NUM > 1:
147 | h_code2, att1 = \
148 | self.h_net2(h_code1, c_code, word_embs, mask)
149 | fake_img2 = self.img_net2(h_code2)
150 | fake_imgs.append(fake_img2)
151 | if att1 is not None:
152 | att_maps.append(att1)
153 | if cfg.TREE.BRANCH_NUM > 2:
154 | h_code3, att2 = \
155 | self.h_net3(h_code2, c_code, word_embs, mask)
156 | fake_img3 = self.img_net3(h_code3)
157 | fake_imgs.append(fake_img3)
158 | if att2 is not None:
159 | att_maps.append(att2)
160 |
161 | return fake_imgs, att_maps, mu, logvar
162 |
163 |
164 | # ############## G networks ###################
165 | class CA_NET(nn.Module):
166 | # some code is modified from vae examples
167 | # (https://github.com/pytorch/examples/blob/master/vae/main.py)
168 | def __init__(self):
169 | super(CA_NET, self).__init__()
170 |
171 | self.t_dim = cfg.TEXT.EMBEDDING_DIM
172 | self.c_dim = cfg.GAN.CONDITION_DIM
173 |
174 | self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
175 | self.relu = GLU()
176 |
177 | def encode(self, text_embedding):
178 | x = self.relu(self.fc(text_embedding))
179 | mu = x[:, :self.c_dim]
180 | logvar = x[:, self.c_dim:]
181 | return mu, logvar
182 |
183 | def reparametrize(self, mu, logvar):
184 | std = logvar.mul(0.5).exp_()
185 | if cfg.CUDA:
186 | eps = torch.cuda.FloatTensor(std.size()).normal_()
187 | else:
188 | eps = torch.FloatTensor(std.size()).normal_()
189 | eps = Variable(eps)
190 | return eps.mul(std).add_(mu)
191 |
192 | def forward(self, text_embedding):
193 | mu, logvar = self.encode(text_embedding)
194 | c_code = self.reparametrize(mu, logvar)
195 | return c_code, mu, logvar
196 |
197 |
198 |
199 | class GLU(nn.Module):
200 | def __init__(self):
201 | super(GLU, self).__init__()
202 |
203 | def forward(self, x):
204 | nc = x.size(1)
205 | assert nc % 2 == 0, 'channels dont divide 2!'
206 | nc = int(nc/2)
207 | return x[:, :nc] * F.sigmoid(x[:, nc:])
208 |
209 |
210 | def conv1x1(in_planes, out_planes, bias=False):
211 | "1x1 convolution"
212 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
213 | padding=0, bias=bias)
214 |
215 |
216 | def conv3x3(in_planes, out_planes):
217 | "3x3 convolution with padding"
218 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
219 | padding=1, bias=False)
220 |
221 |
222 | # Upscale the spatial size by a factor of 2
223 | def upBlock(in_planes, out_planes):
224 | block = nn.Sequential(
225 | nn.Upsample(scale_factor=2, mode='nearest'),
226 | conv3x3(in_planes, out_planes * 2),
227 | nn.BatchNorm2d(out_planes * 2),
228 | GLU())
229 | return block
230 |
231 |
232 | # Keep the spatial size
233 | def Block3x3_relu(in_planes, out_planes):
234 | block = nn.Sequential(
235 | conv3x3(in_planes, out_planes * 2),
236 | nn.BatchNorm2d(out_planes * 2),
237 | GLU())
238 | return block
239 |
240 |
241 | class ResBlock(nn.Module):
242 | def __init__(self, channel_num):
243 | super(ResBlock, self).__init__()
244 | self.block = nn.Sequential(
245 | conv3x3(channel_num, channel_num * 2),
246 | nn.BatchNorm2d(channel_num * 2),
247 | GLU(),
248 | conv3x3(channel_num, channel_num),
249 | nn.BatchNorm2d(channel_num))
250 |
251 | def forward(self, x):
252 | residual = x
253 | out = self.block(x)
254 | out += residual
255 | return out
256 |
257 |
258 |
259 | class CNN_ENCODER(nn.Module):
260 | def __init__(self, nef):
261 | super(CNN_ENCODER, self).__init__()
262 | if cfg.TRAIN.FLAG:
263 | self.nef = nef
264 | else:
265 | self.nef = 256 # define a uniform ranker
266 |
267 | model = models.inception_v3()
268 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
269 | model.load_state_dict(model_zoo.load_url(url))
270 | for param in model.parameters():
271 | param.requires_grad = False
272 | print('Load pretrained model from ', url)
273 | # print(model)
274 |
275 | self.define_module(model)
276 | self.init_trainable_weights()
277 |
278 | def define_module(self, model):
279 | self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
280 | self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3
281 | self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3
282 | self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1
283 | self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3
284 | self.Mixed_5b = model.Mixed_5b
285 | self.Mixed_5c = model.Mixed_5c
286 | self.Mixed_5d = model.Mixed_5d
287 | self.Mixed_6a = model.Mixed_6a
288 | self.Mixed_6b = model.Mixed_6b
289 | self.Mixed_6c = model.Mixed_6c
290 | self.Mixed_6d = model.Mixed_6d
291 | self.Mixed_6e = model.Mixed_6e
292 | self.Mixed_7a = model.Mixed_7a
293 | self.Mixed_7b = model.Mixed_7b
294 | self.Mixed_7c = model.Mixed_7c
295 |
296 | self.emb_features = conv1x1(768, self.nef)
297 | self.emb_cnn_code = nn.Linear(2048, self.nef)
298 |
299 | def init_trainable_weights(self):
300 | initrange = 0.1
301 | self.emb_features.weight.data.uniform_(-initrange, initrange)
302 | self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)
303 |
304 | def forward(self, x):
305 | features = None
306 | # --> fixed-size input: batch x 3 x 299 x 299
307 | x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
308 | # 299 x 299 x 3
309 | x = self.Conv2d_1a_3x3(x)
310 | # 149 x 149 x 32
311 | x = self.Conv2d_2a_3x3(x)
312 | # 147 x 147 x 32
313 | x = self.Conv2d_2b_3x3(x)
314 | # 147 x 147 x 64
315 | x = F.max_pool2d(x, kernel_size=3, stride=2)
316 | # 73 x 73 x 64
317 | x = self.Conv2d_3b_1x1(x)
318 | # 73 x 73 x 80
319 | x = self.Conv2d_4a_3x3(x)
320 | # 71 x 71 x 192
321 |
322 | x = F.max_pool2d(x, kernel_size=3, stride=2)
323 | # 35 x 35 x 192
324 | x = self.Mixed_5b(x)
325 | # 35 x 35 x 256
326 | x = self.Mixed_5c(x)
327 | # 35 x 35 x 288
328 | x = self.Mixed_5d(x)
329 | # 35 x 35 x 288
330 |
331 | x = self.Mixed_6a(x)
332 | # 17 x 17 x 768
333 | x = self.Mixed_6b(x)
334 | # 17 x 17 x 768
335 | x = self.Mixed_6c(x)
336 | # 17 x 17 x 768
337 | x = self.Mixed_6d(x)
338 | # 17 x 17 x 768
339 | x = self.Mixed_6e(x)
340 | # 17 x 17 x 768
341 |
342 | # image region features
343 | features = x
344 | # 17 x 17 x 768
345 |
346 | x = self.Mixed_7a(x)
347 | # 8 x 8 x 1280
348 | x = self.Mixed_7b(x)
349 | # 8 x 8 x 2048
350 | x = self.Mixed_7c(x)
351 | # 8 x 8 x 2048
352 | x = F.avg_pool2d(x, kernel_size=8)
353 | # 1 x 1 x 2048
354 | # x = F.dropout(x, training=self.training)
355 | # 1 x 1 x 2048
356 | x = x.view(x.size(0), -1)
357 | # 2048
358 |
359 | # global image features
360 | cnn_code = self.emb_cnn_code(x)
361 | # 512
362 | if features is not None:
363 | features = self.emb_features(features)
364 | return features, cnn_code
365 |
366 |
367 | class INIT_STAGE_G(nn.Module):
368 | def __init__(self, ngf, ncf):
369 | super(INIT_STAGE_G, self).__init__()
370 | self.gf_dim = ngf
371 | self.in_dim = cfg.GAN.Z_DIM + ncf # cfg.TEXT.EMBEDDING_DIM
372 |
373 | self.define_module()
374 |
375 | def define_module(self):
376 | nz, ngf = self.in_dim, self.gf_dim
377 | self.fc = nn.Sequential(
378 | nn.Linear(nz, ngf * 4 * 4 * 2, bias=False),
379 | # removing for single instance caption
380 | nn.BatchNorm1d(ngf * 4 * 4 * 2),
381 | GLU())
382 |
383 | self.upsample1 = upBlock(ngf, ngf // 2)
384 | self.upsample2 = upBlock(ngf // 2, ngf // 4)
385 | self.upsample3 = upBlock(ngf // 4, ngf // 8)
386 | self.upsample4 = upBlock(ngf // 8, ngf // 16)
387 |
388 | def forward(self, z_code, c_code):
389 | """
390 | :param z_code: batch x cfg.GAN.Z_DIM
391 | :param c_code: batch x cfg.TEXT.EMBEDDING_DIM
392 | :return: batch x ngf/16 x 64 x 64
393 | """
394 | c_z_code = torch.cat((c_code, z_code), 1)
395 | # state size ngf x 4 x 4
396 | out_code = self.fc(c_z_code)
397 | out_code = out_code.view(-1, self.gf_dim, 4, 4)
398 | # state size ngf/2 x 8 x 8
399 | out_code = self.upsample1(out_code)
400 | # state size ngf/4 x 16 x 16
401 | out_code = self.upsample2(out_code)
402 | # state size ngf/8 x 32 x 32
403 | out_code32 = self.upsample3(out_code)
404 | # state size ngf/16 x 64 x 64
405 | out_code64 = self.upsample4(out_code32)
406 |
407 | return out_code64
408 |
409 |
410 | class NEXT_STAGE_G(nn.Module):
411 | def __init__(self, ngf, nef, ncf):
412 | super(NEXT_STAGE_G, self).__init__()
413 | self.gf_dim = ngf
414 | self.ef_dim = nef
415 | self.cf_dim = ncf
416 | self.num_residual = cfg.GAN.R_NUM
417 | self.define_module()
418 |
419 | def _make_layer(self, block, channel_num):
420 | layers = []
421 | for i in range(cfg.GAN.R_NUM):
422 | layers.append(block(channel_num))
423 | return nn.Sequential(*layers)
424 |
425 | def define_module(self):
426 | ngf = self.gf_dim
427 | self.att = ATT_NET(ngf, self.ef_dim)
428 | self.residual = self._make_layer(ResBlock, ngf * 2)
429 | self.upsample = upBlock(ngf * 2, ngf)
430 |
431 | def forward(self, h_code, c_code, word_embs, mask):
432 | """
433 | h_code1(query): batch x idf x ih x iw (queryL=ihxiw)
434 | word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
435 | c_code1: batch x idf x queryL
436 | att1: batch x sourceL x queryL
437 | """
438 | self.att.applyMask(mask)
439 | c_code, att = self.att(h_code, word_embs)
440 | h_c_code = torch.cat((h_code, c_code), 1)
441 | out_code = self.residual(h_c_code)
442 |
443 | # state size ngf/2 x 2in_size x 2in_size
444 | out_code = self.upsample(out_code)
445 |
446 | return out_code, att
447 |
448 |
449 | class GET_IMAGE_G(nn.Module):
450 | def __init__(self, ngf):
451 | super(GET_IMAGE_G, self).__init__()
452 | self.gf_dim = ngf
453 | self.img = nn.Sequential(
454 | conv3x3(ngf, 3),
455 | nn.Tanh()
456 | )
457 |
458 | def forward(self, h_code):
459 | out_img = self.img(h_code)
460 | return out_img
461 |
462 |
463 | class G_DCGAN(nn.Module):
464 | def __init__(self):
465 | super(G_DCGAN, self).__init__()
466 | ngf = cfg.GAN.GF_DIM
467 | nef = cfg.TEXT.EMBEDDING_DIM
468 | ncf = cfg.GAN.CONDITION_DIM
469 | self.ca_net = CA_NET()
470 |
471 | # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64
472 | if cfg.TREE.BRANCH_NUM > 0:
473 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
474 | # gf x 64 x 64
475 | if cfg.TREE.BRANCH_NUM > 1:
476 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
477 | if cfg.TREE.BRANCH_NUM > 2:
478 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)
479 | self.img_net = GET_IMAGE_G(ngf)
480 |
481 | def forward(self, z_code, sent_emb, word_embs, mask):
482 | """
483 | :param z_code: batch x cfg.GAN.Z_DIM
484 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
485 | :param word_embs: batch x cdf x seq_len
486 | :param mask: batch x seq_len
487 | :return:
488 | """
489 | att_maps = []
490 | c_code, mu, logvar = self.ca_net(sent_emb)
491 | if cfg.TREE.BRANCH_NUM > 0:
492 | h_code = self.h_net1(z_code, c_code)
493 | if cfg.TREE.BRANCH_NUM > 1:
494 | h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask)
495 | if att1 is not None:
496 | att_maps.append(att1)
497 | if cfg.TREE.BRANCH_NUM > 2:
498 | h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask)
499 | if att2 is not None:
500 | att_maps.append(att2)
501 |
502 | fake_imgs = self.img_net(h_code)
503 | return [fake_imgs], att_maps, mu, logvar
504 |
505 |
506 | # ############## D networks ##########################
507 | def Block3x3_leakRelu(in_planes, out_planes):
508 | block = nn.Sequential(
509 | conv3x3(in_planes, out_planes),
510 | nn.BatchNorm2d(out_planes),
511 | nn.LeakyReLU(0.2, inplace=True)
512 | )
513 | return block
514 |
515 |
516 | # Downscale the spatial size by a factor of 2
517 | def downBlock(in_planes, out_planes):
518 | block = nn.Sequential(
519 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
520 | nn.BatchNorm2d(out_planes),
521 | nn.LeakyReLU(0.2, inplace=True)
522 | )
523 | return block
524 |
525 |
526 | # Downscale the spatial size by a factor of 16
527 | def encode_image_by_16times(ndf):
528 | encode_img = nn.Sequential(
529 | # --> state size. ndf x in_size/2 x in_size/2
530 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
531 | nn.LeakyReLU(0.2, inplace=True),
532 | # --> state size 2ndf x in_size/4 x in_size/4
533 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
534 | nn.BatchNorm2d(ndf * 2),
535 | nn.LeakyReLU(0.2, inplace=True),
536 | # --> state size 4ndf x in_size/8 x in_size/8
537 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
538 | nn.BatchNorm2d(ndf * 4),
539 | nn.LeakyReLU(0.2, inplace=True),
540 | # --> state size 8ndf x in_size/16 x in_size/16
541 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
542 | nn.BatchNorm2d(ndf * 8),
543 | nn.LeakyReLU(0.2, inplace=True)
544 | )
545 | return encode_img
546 |
547 |
548 | class D_GET_LOGITS(nn.Module):
549 | def __init__(self, ndf, nef, bcondition=False):
550 | super(D_GET_LOGITS, self).__init__()
551 | self.df_dim = ndf
552 | self.ef_dim = nef
553 | self.bcondition = bcondition
554 | if self.bcondition:
555 | self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8)
556 |
557 | self.outlogits = nn.Sequential(
558 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
559 | nn.Sigmoid())
560 |
561 | def forward(self, h_code, c_code=None):
562 | if self.bcondition and c_code is not None:
563 | # conditioning output
564 | c_code = c_code.view(-1, self.ef_dim, 1, 1)
565 | c_code = c_code.repeat(1, 1, 4, 4)
566 | # state size (ngf+egf) x 4 x 4
567 | h_c_code = torch.cat((h_code, c_code), 1)
568 | # state size ngf x in_size x in_size
569 | h_c_code = self.jointConv(h_c_code)
570 | else:
571 | h_c_code = h_code
572 |
573 | output = self.outlogits(h_c_code)
574 | return output.view(-1)
575 |
576 |
577 | # For 64 x 64 images
578 | class D_NET64(nn.Module):
579 | def __init__(self, b_jcu=True):
580 | super(D_NET64, self).__init__()
581 | ndf = cfg.GAN.DF_DIM
582 | nef = cfg.TEXT.EMBEDDING_DIM
583 | self.img_code_s16 = encode_image_by_16times(ndf)
584 | if b_jcu:
585 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
586 | else:
587 | self.UNCOND_DNET = None
588 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)
589 |
590 | def forward(self, x_var):
591 | x_code4 = self.img_code_s16(x_var) # 4 x 4 x 8df
592 | return x_code4
593 |
594 |
595 | # For 128 x 128 images
596 | class D_NET128(nn.Module):
597 | def __init__(self, b_jcu=True):
598 | super(D_NET128, self).__init__()
599 | ndf = cfg.GAN.DF_DIM
600 | nef = cfg.TEXT.EMBEDDING_DIM
601 | self.img_code_s16 = encode_image_by_16times(ndf)
602 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
603 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8)
604 | #
605 | if b_jcu:
606 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
607 | else:
608 | self.UNCOND_DNET = None
609 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)
610 |
611 | def forward(self, x_var):
612 | x_code8 = self.img_code_s16(x_var) # 8 x 8 x 8df
613 | x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df
614 | x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df
615 | return x_code4
616 |
617 |
618 | # For 256 x 256 images
619 | class D_NET256(nn.Module):
620 | def __init__(self, b_jcu=True):
621 | super(D_NET256, self).__init__()
622 | ndf = cfg.GAN.DF_DIM
623 | nef = cfg.TEXT.EMBEDDING_DIM
624 | self.img_code_s16 = encode_image_by_16times(ndf)
625 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
626 | self.img_code_s64 = downBlock(ndf * 16, ndf * 32)
627 | self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16)
628 | self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8)
629 | if b_jcu:
630 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
631 | else:
632 | self.UNCOND_DNET = None
633 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)
634 |
635 | def forward(self, x_var):
636 | x_code16 = self.img_code_s16(x_var)
637 | x_code8 = self.img_code_s32(x_code16)
638 | x_code4 = self.img_code_s64(x_code8)
639 | x_code4 = self.img_code_s64_1(x_code4)
640 | x_code4 = self.img_code_s64_2(x_code4)
641 | return x_code4
642 |
--------------------------------------------------------------------------------
/eval/requirements.txt:
--------------------------------------------------------------------------------
1 | Flask
2 | python-dateutil
3 | easydict
4 | scikit-image
5 | azure-storage-blob
6 | applicationinsights
7 | libmc
--------------------------------------------------------------------------------
/example_bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/example_bird.png
--------------------------------------------------------------------------------
/example_coco.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/example_coco.png
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taoxugit/AttnGAN/0d000e652b407e976cb88fab299e8566f3de8a37/framework.png
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !README.md
3 | !.gitignore
--------------------------------------------------------------------------------