├── DAMSMencoders └── .gitkeep ├── miscc ├── __init__.py ├── config.py ├── losses.py └── utils.py ├── data └── birds │ ├── captions.pickle │ └── bert_captions.pickle ├── cfg ├── eval_bird.yaml ├── DAMSM │ └── bird.yaml ├── STREAM │ └── bird.yaml ├── bird_attn2.yaml ├── test │ ├── bird_attn2.yaml │ └── bird_cycle.yaml └── bird_cycle.yaml ├── README.md ├── LICENSE ├── .gitignore ├── inception.py ├── GlobalAttention.py ├── main.py ├── pretrain_DAMSM.py ├── pretrain_STREAM.py ├── datasets.py ├── model.py └── trainer.py /DAMSMencoders/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /data/birds/captions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suetAndTie/cycle-image-gan/HEAD/data/birds/captions.pickle -------------------------------------------------------------------------------- /data/birds/bert_captions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suetAndTie/cycle-image-gan/HEAD/data/birds/bert_captions.pickle -------------------------------------------------------------------------------- /cfg/eval_bird.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds' 5 | GPU_ID: 3 6 | WORKERS: 1 7 | 8 | B_VALIDATION: False # True # False 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | FLAG: False 15 | NET_G: 'models/bird_AttnGAN2.pth' 16 | B_NET_D: False 17 | BATCH_SIZE: 100 18 | NET_E: 'DAMSMencoders/bird/text_encoder200.pth' 19 | 20 | 21 | GAN: 22 | DF_DIM: 64 23 | GF_DIM: 32 24 | Z_DIM: 100 25 | R_NUM: 2 26 | 27 | TEXT: 28 | EMBEDDING_DIM: 256 29 | CAPTIONS_PER_IMAGE: 10 30 | WORDS_NUM: 25 31 | -------------------------------------------------------------------------------- /cfg/DAMSM/bird.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'DAMSM' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 1 11 | BASE_SIZE: 299 12 | 13 | 14 | TRAIN: 15 | FLAG: True 16 | NET_E: '' # '../DAMSMencoders/bird/text_encoder200.pth' 17 | BATCH_SIZE: 48 18 | MAX_EPOCH: 600 19 | SNAPSHOT_INTERVAL: 50 20 | ENCODER_LR: 0.002 # 0.0002best; 0.002good; scott: 0.0007 with 0.98decay 21 | RNN_GRAD_CLIP: 0.25 22 | SMOOTH: 23 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 24 | GAMMA2: 5.0 25 | GAMMA3: 10.0 # 10good 1&100bad 26 | 27 | 28 | 29 | TEXT: 30 | EMBEDDING_DIM: 256 31 | CAPTIONS_PER_IMAGE: 10 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cycle-image-gan 2 | Based on https://github.com/taoxugit/AttnGAN/tree/master/code 3 | 4 | Paper https://arxiv.org/abs/2003.12137 5 | 6 | ## Spring 2019 CS 224U Project 7 | * BERT encoder 8 | * Cycle-GAN 9 | * Image2Text encoder 10 | 11 | ## Download Data 12 | 1. 
Download AttnGAN preprocessed data and captions [birds](https://drive.google.com/open?id=1O_LtUP9sch09QH3s_EBAgLEctBQ5JBSJ) 13 | 2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data. Extract them to `data/birds/` 14 | 15 | ## Instructions 16 | * pretrain STREAM 17 | ``` 18 | python pretrain_STREAM.py --cfg cfg/STREAM/bird.yaml --gpu 0 19 | ``` 20 | * train CycleGAN 21 | ``` 22 | python main.py --cfg cfg/bird_cycle.yaml --gpu 0 23 | ``` 24 | -------------------------------------------------------------------------------- /cfg/STREAM/bird.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'STREAM' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 1 11 | BASE_SIZE: 299 12 | 13 | 14 | TRAIN: 15 | FLAG: True 16 | NET_E: '' # '../DAMSMencoders/bird/text_encoder200.pth' 17 | BATCH_SIZE: 48 18 | MAX_EPOCH: 600 19 | SNAPSHOT_INTERVAL: 50 20 | ENCODER_LR: 0.002 # 0.0002best; 0.002good; scott: 0.0007 with 0.98decay 21 | RNN_GRAD_CLIP: 0.25 22 | SMOOTH: 23 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 24 | GAMMA2: 5.0 25 | GAMMA3: 10.0 # 10good 1&100bad 26 | 27 | CNN_RNN: 28 | HIDDEN_DIM: 256 29 | 30 | TEXT: 31 | EMBEDDING_DIM: 768 32 | CAPTIONS_PER_IMAGE: 10 33 | -------------------------------------------------------------------------------- /cfg/bird_attn2.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | TRAINER: 'condGANTrainer' 15 | FLAG: True 16 | NET_G: '' # '../models/bird_AttnGAN2.pth' 17 | B_NET_D: True 18 | BATCH_SIZE: 20 # 22 19 | MAX_EPOCH: 600 20 | SNAPSHOT_INTERVAL: 50 21 | DISCRIMINATOR_LR: 0.0002 22 | GENERATOR_LR: 0.0002 23 | # 24 | NET_E: 'DAMSMencoders/birdattn/text_encoder200.pth' 25 | SMOOTH: 26 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 27 | GAMMA2: 5.0 28 | GAMMA3: 10.0 # 10good 1&100bad 29 | LAMBDA: 5.0 30 | 31 | 32 | GAN: 33 | DF_DIM: 64 34 | GF_DIM: 32 35 | Z_DIM: 100 36 | R_NUM: 2 37 | 38 | TEXT: 39 | EMBEDDING_DIM: 256 40 | CAPTIONS_PER_IMAGE: 10 41 | -------------------------------------------------------------------------------- /cfg/test/bird_attn2.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | B_VALIDATION: True 8 | 9 | 10 | TREE: 11 | BRANCH_NUM: 3 12 | 13 | 14 | TRAIN: 15 | TRAINER: 'condGANTrainer' 16 | FLAG: False 17 | NET_G: 'models/attn/netG_epoch_100.pth' 18 | B_NET_D: True 19 | BATCH_SIZE: 20 # 22 20 | MAX_EPOCH: 600 21 | SNAPSHOT_INTERVAL: 50 22 | DISCRIMINATOR_LR: 0.0002 23 | GENERATOR_LR: 0.0002 24 | # 25 | NET_E: 'models/DAMSM/text_encoder200.pth' 26 | SMOOTH: 27 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 28 | GAMMA2: 5.0 29 | GAMMA3: 10.0 # 10good 1&100bad 30 | LAMBDA: 5.0 31 | 32 | 33 | GAN: 34 | DF_DIM: 64 35 | GF_DIM: 32 36 | Z_DIM: 100 37 | R_NUM: 2 38 | 39 | TEXT: 40 | EMBEDDING_DIM: 256 41 | CAPTIONS_PER_IMAGE: 10 42 | -------------------------------------------------------------------------------- /cfg/test/bird_cycle.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'cycle' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 
B_VALIDATION: True 8 | 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | TRAINER: 'CycleGANTrainer' 15 | FLAG: False 16 | NET_G: 'models/cycle/netG_epoch_100.pth' 17 | B_NET_D: True 18 | BATCH_SIZE: 20 # 22 19 | MAX_EPOCH: 600 20 | SNAPSHOT_INTERVAL: 25 21 | DISCRIMINATOR_LR: 0.0002 22 | GENERATOR_LR: 0.0002 23 | # 24 | NET_E: 'models/STREAM/text_encoder100.pth' 25 | #NET_E: 'STREAMencoders/birdcycle/text_encoder200.pth' 26 | SMOOTH: 27 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 28 | GAMMA2: 5.0 29 | GAMMA3: 10.0 # 10good 1&100bad 30 | LAMBDA: 5.0 31 | 32 | CNN_RNN: 33 | HIDDEN_DIM: 256 34 | 35 | GAN: 36 | DF_DIM: 64 37 | GF_DIM: 32 38 | Z_DIM: 100 39 | R_NUM: 2 40 | 41 | TEXT: 42 | EMBEDDING_DIM: 768 43 | CAPTIONS_PER_IMAGE: 10 44 | -------------------------------------------------------------------------------- /cfg/bird_cycle.yaml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'cycle' 2 | 3 | DATASET_NAME: 'birds' 4 | DATA_DIR: 'data/birds/test' 5 | GPU_ID: 0 6 | WORKERS: 4 7 | 8 | 9 | TREE: 10 | BRANCH_NUM: 3 11 | 12 | 13 | TRAIN: 14 | TRAINER: 'CycleGANTrainer' 15 | FLAG: True 16 | NET_G: '' # '../models/bird_AttnGAN2.pth' 17 | B_NET_D: True 18 | BATCH_SIZE: 20 # 22 19 | MAX_EPOCH: 600 20 | SNAPSHOT_INTERVAL: 25 21 | DISCRIMINATOR_LR: 0.0002 22 | GENERATOR_LR: 0.0002 23 | # 24 | NET_E: 'output/birds_STREAM_2019_06_07_18_55_55/Model/text_encoder100.pth' 25 | #NET_E: 'STREAMencoders/birdcycle/text_encoder200.pth' 26 | SMOOTH: 27 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 28 | GAMMA2: 5.0 29 | GAMMA3: 10.0 # 10good 1&100bad 30 | LAMBDA: 5.0 31 | 32 | CNN_RNN: 33 | HIDDEN_DIM: 256 34 | 35 | GAN: 36 | DF_DIM: 64 37 | GF_DIM: 32 38 | Z_DIM: 100 39 | R_NUM: 2 40 | 41 | TEXT: 42 | EMBEDDING_DIM: 768 43 | CAPTIONS_PER_IMAGE: 10 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Trevor Tsue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | DAMSMencoders/* 2 | output/* 3 | models/* 4 | 5 | .DS_Store 6 | Desktop.ini 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | # input data, saved log, checkpoints 111 | data/ 112 | input/ 113 | saved/ 114 | datasets/ 115 | 116 | # editor, os cache directory 117 | .vscode/ 118 | .idea/ 119 | __MACOSX/ 120 | -------------------------------------------------------------------------------- /miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'birds' 14 | __C.CONFIG_NAME = '' 15 | __C.DATA_DIR = '' 16 | __C.GPU_ID = 0 17 | __C.CUDA = True 18 | __C.WORKERS = 6 19 | 20 | __C.RNN_TYPE = 'LSTM' # 'GRU' 21 | __C.B_VALIDATION = False 22 | 23 | __C.TREE = edict() 24 | __C.TREE.BRANCH_NUM = 3 25 | __C.TREE.BASE_SIZE = 64 26 | 27 | 28 | # Training options 29 | __C.TRAIN = edict() 30 | __C.TRAIN.TRAINER = 'condGANTrainer' 31 | __C.TRAIN.BATCH_SIZE = 64 32 | __C.TRAIN.MAX_EPOCH = 600 33 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000 34 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 35 | __C.TRAIN.GENERATOR_LR = 2e-4 36 | __C.TRAIN.ENCODER_LR = 2e-4 37 | __C.TRAIN.RNN_GRAD_CLIP = 0.25 38 | __C.TRAIN.FLAG = True 39 | __C.TRAIN.NET_E = '' 40 | __C.TRAIN.NET_G = '' 41 | __C.TRAIN.B_NET_D = True 42 | 43 | __C.TRAIN.SMOOTH = edict() 44 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0 45 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0 46 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0 47 | __C.TRAIN.SMOOTH.LAMBDA = 1.0 48 | 49 | 50 | # Modal options 51 | __C.GAN = edict() 52 | __C.GAN.DF_DIM = 64 53 
| __C.GAN.GF_DIM = 128 54 | __C.GAN.Z_DIM = 100 55 | __C.GAN.CONDITION_DIM = 100 56 | __C.GAN.R_NUM = 2 57 | __C.GAN.B_ATTENTION = True 58 | __C.GAN.B_DCGAN = False 59 | 60 | __C.CNN_RNN = edict() 61 | __C.CNN_RNN.HIDDEN_DIM = 256 62 | 63 | 64 | __C.TEXT = edict() 65 | __C.TEXT.CAPTIONS_PER_IMAGE = 10 66 | __C.TEXT.EMBEDDING_DIM = 256 67 | __C.TEXT.WORDS_NUM = 18 68 | 69 | 70 | def _merge_a_into_b(a, b): 71 | """Merge config dictionary a into config dictionary b, clobbering the 72 | options in b whenever they are also specified in a. 73 | """ 74 | if type(a) is not edict: 75 | return 76 | 77 | for k, v in a.items(): 78 | # a must specify keys that are in b 79 | if k not in b: 80 | raise KeyError('{} is not a valid config key'.format(k)) 81 | 82 | # the types must match, too 83 | old_type = type(b[k]) 84 | if old_type is not type(v): 85 | if isinstance(b[k], np.ndarray): 86 | v = np.array(v, dtype=b[k].dtype) 87 | else: 88 | raise ValueError(('Type mismatch ({} vs. {}) ' 89 | 'for config key: {}').format(type(b[k]), 90 | type(v), k)) 91 | 92 | # recursively merge dicts 93 | if type(v) is edict: 94 | try: 95 | _merge_a_into_b(a[k], b[k]) 96 | except: 97 | print('Error under config key: {}'.format(k)) 98 | raise 99 | else: 100 | b[k] = v 101 | 102 | 103 | def cfg_from_file(filename): 104 | """Load a config file and merge it into the default options.""" 105 | import yaml 106 | with open(filename, 'r') as f: 107 | yaml_cfg = edict(yaml.load(f)) 108 | 109 | _merge_a_into_b(yaml_cfg, __C) 110 | -------------------------------------------------------------------------------- /inception.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from torch.nn import functional as F 6 | import torch.utils.data 7 | from torchvision.models.inception import inception_v3 8 | 9 | import numpy as np 10 | from scipy.stats import entropy 11 | from torch.utils.data.dataset import Dataset 12 | 13 | from skimage import io 14 | 15 | 16 | class GeneratedDataset(Dataset): 17 | def __init__(self, root_dir, transform=None): 18 | self.root_dir = root_dir 19 | self.transform = transform 20 | 21 | self.all_images = [] 22 | for dir in list(os.listdir(self.root_dir)): 23 | if dir == '.DS_Store': continue 24 | for filename in list(os.listdir(os.path.join(self.root_dir, dir))): 25 | self.all_images.append(os.path.join(dir, filename)) 26 | 27 | 28 | def __len__(self): 29 | return len(self.all_images) 30 | 31 | def __getitem__(self, idx): 32 | img_name = os.path.join(self.root_dir,self.all_images[idx]) 33 | image = io.imread(img_name) 34 | sample = image 35 | 36 | if self.transform: 37 | sample = self.transform(sample) 38 | 39 | return sample 40 | 41 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): 42 | """Computes the inception score of the generated images imgs 43 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] 44 | cuda -- whether or not to run on GPU 45 | batch_size -- batch size for feeding into Inception v3 46 | splits -- number of splits 47 | """ 48 | N = len(imgs) 49 | 50 | assert batch_size > 0 51 | assert N > batch_size 52 | 53 | # Set up dtype 54 | if cuda: 55 | dtype = torch.cuda.FloatTensor 56 | else: 57 | if torch.cuda.is_available(): 58 | print("WARNING: You have a CUDA device, so you should probably set cuda=True") 59 | dtype = torch.FloatTensor 60 | 61 | # Set up dataloader 62 | dataloader = 
torch.utils.data.DataLoader(imgs, batch_size=batch_size) 63 | 64 | # Load inception model 65 | inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype) 66 | inception_model.eval(); 67 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) 68 | def get_pred(x): 69 | if resize: 70 | x = up(x) 71 | x = inception_model(x) 72 | return F.softmax(x).data.cpu().numpy() 73 | 74 | # Get predictions 75 | preds = np.zeros((N, 1000)) 76 | 77 | for i, batch in enumerate(dataloader, 0): 78 | batch = batch.transpose(1, 3) 79 | 80 | batch = batch.type(dtype) 81 | batchv = Variable(batch) 82 | batch_size_i = batch.size()[0] 83 | 84 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 85 | 86 | # Now compute the mean kl-div 87 | split_scores = [] 88 | 89 | for k in range(splits): 90 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 91 | py = np.mean(part, axis=0) 92 | scores = [] 93 | for i in range(part.shape[0]): 94 | pyx = part[i, :] 95 | scores.append(entropy(pyx, py)) 96 | split_scores.append(np.exp(np.mean(scores))) 97 | 98 | return np.mean(split_scores), np.std(split_scores) 99 | 100 | 101 | data_path = 'models/attn/netG_epoch_150/single' 102 | imgs = GeneratedDataset(data_path) 103 | 104 | print(inception_score(imgs)) 105 | -------------------------------------------------------------------------------- /GlobalAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query metrix. 3 | Based on each query vector q, it computes a parameterized convex combination of the matrix 4 | based. 5 | H_1 H_2 H_3 ... H_n 6 | q q q q 7 | | | | | 8 | \ | | / 9 | ..... 10 | \ | / 11 | a 12 | Constructs a unit mapping. 13 | $$(H_1 + H_n, q) => (a)$$ 14 | Where H is of `batch x n x dim` and q is of `batch x dim`. 15 | References: 16 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules 17 | http://www.aclweb.org/anthology/D15-1166 18 | https://github.com/taoxugit/AttnGAN/blob/master/code/GlobalAttention.py 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | def conv1x1(in_planes, out_planes): 26 | "1x1 convolution with padding" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | 30 | 31 | def func_attention(query, context, gamma1): 32 | """ 33 | query: batch x ndf x queryL 34 | context: batch x ndf x ih x iw (sourceL=ihxiw) 35 | mask: batch_size x sourceL 36 | """ 37 | batch_size, queryL = query.size(0), query.size(2) 38 | ih, iw = context.size(2), context.size(3) 39 | sourceL = ih * iw 40 | 41 | # --> batch x sourceL x ndf 42 | context = context.view(batch_size, -1, sourceL) 43 | contextT = torch.transpose(context, 1, 2).contiguous() 44 | 45 | # Get attention 46 | # (batch x sourceL x ndf)(batch x ndf x queryL) 47 | # -->batch x sourceL x queryL 48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper 49 | # --> batch*sourceL x queryL 50 | attn = attn.view(batch_size*sourceL, queryL) 51 | attn = nn.Softmax(dim=1)(attn) # Eq. (8) 52 | 53 | # --> batch x sourceL x queryL 54 | attn = attn.view(batch_size, sourceL, queryL) 55 | # --> batch*queryL x sourceL 56 | attn = torch.transpose(attn, 1, 2).contiguous() 57 | attn = attn.view(batch_size*queryL, sourceL) 58 | # Eq. 
(9) 59 | attn = attn * gamma1 60 | attn = nn.Softmax(dim=1)(attn) 61 | attn = attn.view(batch_size, queryL, sourceL) 62 | # --> batch x sourceL x queryL 63 | attnT = torch.transpose(attn, 1, 2).contiguous() 64 | 65 | # (batch x ndf x sourceL)(batch x sourceL x queryL) 66 | # --> batch x ndf x queryL 67 | weightedContext = torch.bmm(context, attnT) 68 | 69 | return weightedContext, attn.view(batch_size, -1, ih, iw) 70 | 71 | 72 | class GlobalAttentionGeneral(nn.Module): 73 | def __init__(self, idf, cdf): 74 | super(GlobalAttentionGeneral, self).__init__() 75 | self.conv_context = conv1x1(cdf, idf) 76 | self.sm = nn.Softmax(dim=1) 77 | self.mask = None 78 | 79 | def applyMask(self, mask): 80 | self.mask = mask # batch x sourceL 81 | 82 | def forward(self, input, context): 83 | """ 84 | input: batch x idf x ih x iw (queryL=ihxiw) 85 | context: batch x cdf x sourceL 86 | """ 87 | ih, iw = input.size(2), input.size(3) 88 | queryL = ih * iw 89 | batch_size, sourceL = context.size(0), context.size(2) 90 | 91 | # --> batch x queryL x idf 92 | target = input.view(batch_size, -1, queryL) 93 | targetT = torch.transpose(target, 1, 2).contiguous() 94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1 95 | sourceT = context.unsqueeze(3) 96 | # --> batch x idf x sourceL 97 | sourceT = self.conv_context(sourceT).squeeze(3) 98 | 99 | # Get attention 100 | # (batch x queryL x idf)(batch x idf x sourceL) 101 | # -->batch x queryL x sourceL 102 | attn = torch.bmm(targetT, sourceT) 103 | # --> batch*queryL x sourceL 104 | attn = attn.view(batch_size*queryL, sourceL) 105 | if self.mask is not None: 106 | # batch_size x sourceL --> batch_size*queryL x sourceL 107 | mask = self.mask.repeat(queryL, 1) 108 | attn.data.masked_fill_(mask.data, -float('inf')) 109 | attn = self.sm(attn) # Eq. 
(2) 110 | # --> batch x queryL x sourceL 111 | attn = attn.view(batch_size, queryL, sourceL) 112 | # --> batch x sourceL x queryL 113 | attn = torch.transpose(attn, 1, 2).contiguous() 114 | 115 | # (batch x idf x sourceL)(batch x sourceL x queryL) 116 | # --> batch x idf x queryL 117 | weightedContext = torch.bmm(sourceT, attn) 118 | weightedContext = weightedContext.view(batch_size, -1, ih, iw) 119 | attn = attn.view(batch_size, -1, ih, iw) 120 | 121 | return weightedContext, attn 122 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from miscc.config import cfg, cfg_from_file 4 | from datasets import TextDataset 5 | import trainer 6 | 7 | import os 8 | import sys 9 | import time 10 | import random 11 | import pprint 12 | import datetime 13 | import dateutil.tz 14 | import argparse 15 | import numpy as np 16 | 17 | import torch 18 | import torchvision.transforms as transforms 19 | 20 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 21 | sys.path.append(dir_path) 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='Train a AttnGAN network') 26 | parser.add_argument('--cfg', dest='cfg_file', 27 | help='optional config file', 28 | default='cfg/bird_attn2.yaml', type=str) 29 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=0) 30 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='data/birds') 31 | parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def gen_example(wordtoix, algo): 37 | '''generate images from example sentences''' 38 | from nltk.tokenize import RegexpTokenizer 39 | filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR) 40 | data_dic = {} 41 | with open(filepath, "r") as f: 42 | filenames = f.read().decode('utf8').split('\n') 43 | for name in filenames: 44 | if len(name) == 0: 45 | continue 46 | filepath = '%s/%s.txt' % (cfg.DATA_DIR, name) 47 | with open(filepath, "r") as f: 48 | print('Load from:', name) 49 | sentences = f.read().decode('utf8').split('\n') 50 | # a list of indices for a sentence 51 | captions = [] 52 | cap_lens = [] 53 | for sent in sentences: 54 | if len(sent) == 0: 55 | continue 56 | sent = sent.replace("\ufffd\ufffd", " ") 57 | tokenizer = RegexpTokenizer(r'\w+') 58 | tokens = tokenizer.tokenize(sent.lower()) 59 | if len(tokens) == 0: 60 | print('sent', sent) 61 | continue 62 | 63 | rev = [] 64 | for t in tokens: 65 | t = t.encode('ascii', 'ignore').decode('ascii') 66 | if len(t) > 0 and t in wordtoix: 67 | rev.append(wordtoix[t]) 68 | captions.append(rev) 69 | cap_lens.append(len(rev)) 70 | max_len = np.max(cap_lens) 71 | 72 | sorted_indices = np.argsort(cap_lens)[::-1] 73 | cap_lens = np.asarray(cap_lens) 74 | cap_lens = cap_lens[sorted_indices] 75 | cap_array = np.zeros((len(captions), max_len), dtype='int64') 76 | for i in range(len(captions)): 77 | idx = sorted_indices[i] 78 | cap = captions[idx] 79 | c_len = len(cap) 80 | cap_array[i, :c_len] = cap 81 | key = name[(name.rfind('/') + 1):] 82 | data_dic[key] = [cap_array, cap_lens, sorted_indices] 83 | algo.gen_example(data_dic) 84 | 85 | 86 | if __name__ == "__main__": 87 | args = parse_args() 88 | if args.cfg_file is not None: 89 | cfg_from_file(args.cfg_file) 90 | 91 | if args.gpu_id != -1: 92 | cfg.GPU_ID = args.gpu_id 93 | else: 94 | cfg.CUDA = 
False 95 | 96 | if args.data_dir != '': 97 | cfg.DATA_DIR = args.data_dir 98 | print('Using config:') 99 | pprint.pprint(cfg) 100 | 101 | if not cfg.TRAIN.FLAG: 102 | args.manualSeed = 100 103 | elif args.manualSeed is None: 104 | args.manualSeed = random.randint(1, 10000) 105 | random.seed(args.manualSeed) 106 | np.random.seed(args.manualSeed) 107 | torch.manual_seed(args.manualSeed) 108 | if cfg.CUDA: 109 | torch.cuda.manual_seed_all(args.manualSeed) 110 | 111 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 112 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 113 | output_dir = 'output/%s_%s_%s' % \ 114 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 115 | 116 | split_dir, bshuffle = 'train', True 117 | if not cfg.TRAIN.FLAG: 118 | # bshuffle = False 119 | split_dir = 'test' 120 | 121 | # Get data loader 122 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1)) 123 | image_transform = transforms.Compose([ 124 | transforms.Scale(int(imsize * 76 / 64)), 125 | transforms.RandomCrop(imsize), 126 | transforms.RandomHorizontalFlip()]) 127 | dataset = TextDataset(cfg.DATA_DIR, split_dir, 128 | base_size=cfg.TREE.BASE_SIZE, 129 | transform=image_transform) 130 | assert dataset 131 | dataloader = torch.utils.data.DataLoader( 132 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 133 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) 134 | 135 | # Define models and go to train/evaluate 136 | trainer_ = getattr(trainer, cfg.TRAIN.TRAINER) 137 | algo = trainer_(output_dir, dataloader, dataset.n_words, dataset.ixtoword) 138 | 139 | start_t = time.time() 140 | if cfg.TRAIN.FLAG: 141 | algo.train() 142 | else: 143 | '''generate images from pre-extracted embeddings''' 144 | if cfg.B_VALIDATION: 145 | algo.sampling(split_dir) # generate images for the whole valid dataset 146 | else: 147 | gen_example(dataset.wordtoix, algo) # generate images for customized captions 148 | end_t = time.time() 149 | print('Total time for training:', end_t - start_t) 150 | -------------------------------------------------------------------------------- /miscc/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from miscc.config import cfg 7 | 8 | from GlobalAttention import func_attention 9 | 10 | ####################Loss for image2text ################### 11 | def image_to_text_loss(output, target): 12 | # bs x T x vocab_size - > bs * T x vocab_size 13 | bs, T, vocab_size = output.shape 14 | output = output.view(-1, vocab_size) 15 | # bs x T -> bs * T 16 | target = target.view(-1) 17 | return F.cross_entropy(output, target) 18 | 19 | 20 | ####################Loss for matching text-image################### 21 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 22 | """Returns cosine similarity between x1 and x2, computed along dim. 
23 | """ 24 | w12 = torch.sum(x1 * x2, dim) 25 | w1 = torch.norm(x1, 2, dim) 26 | w2 = torch.norm(x2, 2, dim) 27 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 28 | 29 | 30 | def sent_loss(cnn_code, rnn_code, labels, class_ids, 31 | batch_size, eps=1e-8): 32 | # ### Mask mis-match samples ### 33 | # that come from the same class as the real sample ### 34 | masks = [] 35 | if class_ids is not None: 36 | for i in range(batch_size): 37 | mask = (class_ids == class_ids[i]).astype(np.uint8) 38 | mask[i] = 0 39 | masks.append(mask.reshape((1, -1))) 40 | masks = np.concatenate(masks, 0) 41 | # masks: batch_size x batch_size 42 | masks = torch.ByteTensor(masks) 43 | if cfg.CUDA: 44 | masks = masks.cuda() 45 | 46 | # --> seq_len x batch_size x nef 47 | if cnn_code.dim() == 2: 48 | cnn_code = cnn_code.unsqueeze(0) 49 | rnn_code = rnn_code.unsqueeze(0) 50 | 51 | # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1 52 | cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True) 53 | rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True) 54 | # scores* / norm*: seq_len x batch_size x batch_size 55 | scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2)) 56 | norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2)) 57 | scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3 58 | 59 | # --> batch_size x batch_size 60 | scores0 = scores0.squeeze() 61 | if class_ids is not None: 62 | scores0.data.masked_fill_(masks, -float('inf')) 63 | scores1 = scores0.transpose(0, 1) 64 | if labels is not None: 65 | loss0 = nn.CrossEntropyLoss()(scores0, labels) 66 | loss1 = nn.CrossEntropyLoss()(scores1, labels) 67 | else: 68 | loss0, loss1 = None, None 69 | return loss0, loss1 70 | 71 | 72 | def words_loss(img_features, words_emb, labels, 73 | cap_lens, class_ids, batch_size): 74 | """ 75 | words_emb(query): batch x nef x seq_len 76 | img_features(context): batch x nef x 17 x 17 77 | """ 78 | masks = [] 79 | att_maps = [] 80 | similarities = [] 81 | cap_lens = cap_lens.data.tolist() 82 | for i in range(batch_size): 83 | if class_ids is not None: 84 | mask = (class_ids == class_ids[i]).astype(np.uint8) 85 | mask[i] = 0 86 | masks.append(mask.reshape((1, -1))) 87 | # Get the i-th text description 88 | words_num = cap_lens[i] 89 | # -> 1 x nef x words_num 90 | word = words_emb[i, :, :words_num].unsqueeze(0).contiguous() 91 | # -> batch_size x nef x words_num 92 | word = word.repeat(batch_size, 1, 1) 93 | # batch x nef x 17*17 94 | context = img_features 95 | """ 96 | word(query): batch x nef x words_num 97 | context: batch x nef x 17 x 17 98 | weiContext: batch x nef x words_num 99 | attn: batch x words_num x 17 x 17 100 | """ 101 | weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1) 102 | att_maps.append(attn[i].unsqueeze(0).contiguous()) 103 | # --> batch_size x words_num x nef 104 | word = word.transpose(1, 2).contiguous() 105 | weiContext = weiContext.transpose(1, 2).contiguous() 106 | # --> batch_size*words_num x nef 107 | word = word.view(batch_size * words_num, -1) 108 | weiContext = weiContext.view(batch_size * words_num, -1) 109 | # 110 | # -->batch_size*words_num 111 | row_sim = cosine_similarity(word, weiContext) 112 | # --> batch_size x words_num 113 | row_sim = row_sim.view(batch_size, words_num) 114 | 115 | # Eq. 
(10) 116 | row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_() 117 | row_sim = row_sim.sum(dim=1, keepdim=True) 118 | row_sim = torch.log(row_sim) 119 | 120 | # --> 1 x batch_size 121 | # similarities(i, j): the similarity between the i-th image and the j-th text description 122 | similarities.append(row_sim) 123 | 124 | # batch_size x batch_size 125 | similarities = torch.cat(similarities, 1) 126 | if class_ids is not None: 127 | masks = np.concatenate(masks, 0) 128 | # masks: batch_size x batch_size 129 | masks = torch.ByteTensor(masks) 130 | if cfg.CUDA: 131 | masks = masks.cuda() 132 | 133 | similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 134 | if class_ids is not None: 135 | similarities.data.masked_fill_(masks, -float('inf')) 136 | similarities1 = similarities.transpose(0, 1) 137 | if labels is not None: 138 | loss0 = nn.CrossEntropyLoss()(similarities, labels) 139 | loss1 = nn.CrossEntropyLoss()(similarities1, labels) 140 | else: 141 | loss0, loss1 = None, None 142 | return loss0, loss1, att_maps 143 | 144 | 145 | # ##################Loss for G and Ds############################## 146 | def discriminator_loss(netD, real_imgs, fake_imgs, conditions, 147 | real_labels, fake_labels): 148 | # Forward 149 | real_features = netD(real_imgs) 150 | fake_features = netD(fake_imgs.detach()) 151 | # loss 152 | # 153 | cond_real_logits = netD.COND_DNET(real_features, conditions) 154 | cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels) 155 | cond_fake_logits = netD.COND_DNET(fake_features, conditions) 156 | cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels) 157 | # 158 | batch_size = real_features.size(0) 159 | cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size]) 160 | cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size]) 161 | 162 | if netD.UNCOND_DNET is not None: 163 | real_logits = netD.UNCOND_DNET(real_features) 164 | fake_logits = netD.UNCOND_DNET(fake_features) 165 | real_errD = nn.BCELoss()(real_logits, real_labels) 166 | fake_errD = nn.BCELoss()(fake_logits, fake_labels) 167 | errD = ((real_errD + cond_real_errD) / 2. + 168 | (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.) 169 | else: 170 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2. 
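# errD sums three conditional BCE terms -- real image + matching caption,
# fake image + caption, and real image + mismatched caption ("wrong" pairs) --
# and, when an unconditional head (UNCOND_DNET) is present, averages them with
# the plain real/fake GAN terms computed above.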
171 | return errD 172 | 173 | 174 | def generator_loss(netsD, image_encoder, fake_imgs, real_labels, 175 | words_embs, sent_emb, match_labels, 176 | cap_lens, class_ids): 177 | numDs = len(netsD) 178 | batch_size = real_labels.size(0) 179 | logs = '' 180 | # Forward 181 | errG_total = 0 182 | for i in range(numDs): 183 | features = netsD[i](fake_imgs[i]) 184 | cond_logits = netsD[i].COND_DNET(features, sent_emb) 185 | cond_errG = nn.BCELoss()(cond_logits, real_labels) 186 | if netsD[i].UNCOND_DNET is not None: 187 | logits = netsD[i].UNCOND_DNET(features) 188 | errG = nn.BCELoss()(logits, real_labels) 189 | g_loss = errG + cond_errG 190 | else: 191 | g_loss = cond_errG 192 | errG_total += g_loss 193 | # err_img = errG_total.item() 194 | logs += 'g_loss%d: %.2f ' % (i, g_loss.item()) 195 | 196 | # Ranking loss 197 | if i == (numDs - 1): 198 | # words_features: batch_size x nef x 17 x 17 199 | # sent_code: batch_size x nef 200 | region_features, cnn_code = image_encoder(fake_imgs[i]) 201 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs, 202 | match_labels, cap_lens, 203 | class_ids, batch_size) 204 | w_loss = (w_loss0 + w_loss1) * \ 205 | cfg.TRAIN.SMOOTH.LAMBDA 206 | # err_words = err_words + w_loss.item() 207 | 208 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, 209 | match_labels, class_ids, batch_size) 210 | s_loss = (s_loss0 + s_loss1) * \ 211 | cfg.TRAIN.SMOOTH.LAMBDA 212 | # err_sent = err_sent + s_loss.item() 213 | 214 | errG_total += w_loss + s_loss 215 | logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item()) 216 | return errG_total, logs 217 | 218 | def cycle_generator_loss(netsD, image_encoder, fake_imgs, real_labels, captions, 219 | words_embs, sent_emb, match_labels, 220 | cap_lens, class_ids): 221 | numDs = len(netsD) 222 | batch_size = real_labels.size(0) 223 | logs = '' 224 | # Forward 225 | errG_total = 0 226 | for i in range(numDs): 227 | features = netsD[i](fake_imgs[i]) 228 | cond_logits = netsD[i].COND_DNET(features, sent_emb) 229 | cond_errG = nn.BCELoss()(cond_logits, real_labels) 230 | if netsD[i].UNCOND_DNET is not None: 231 | logits = netsD[i].UNCOND_DNET(features) 232 | errG = nn.BCELoss()(logits, real_labels) 233 | g_loss = errG + cond_errG 234 | else: 235 | g_loss = cond_errG 236 | errG_total += g_loss 237 | # err_img = errG_total.item() 238 | logs += 'g_loss%d: %.2f ' % (i, g_loss.item()) 239 | 240 | # Ranking loss 241 | if i == (numDs - 1): 242 | # words_features: batch_size x nef x 17 x 17 243 | # sent_code: batch_size x nef 244 | region_features, cnn_code, word_logits = image_encoder(fake_imgs[i], captions) 245 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs, 246 | match_labels, cap_lens, 247 | class_ids, batch_size) 248 | w_loss = (w_loss0 + w_loss1) * \ 249 | cfg.TRAIN.SMOOTH.LAMBDA 250 | # err_words = err_words + w_loss.item() 251 | 252 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, 253 | match_labels, class_ids, batch_size) 254 | s_loss = (s_loss0 + s_loss1) * \ 255 | cfg.TRAIN.SMOOTH.LAMBDA 256 | # err_sent = err_sent + s_loss.item() 257 | 258 | t_loss = image_to_text_loss(word_logits, captions) * cfg.TRAIN.SMOOTH.LAMBDA 259 | 260 | errG_total += w_loss + s_loss + t_loss 261 | logs += 'w_loss: %.2f s_loss: %.2f t_loss: %.2f' % (w_loss.item(), s_loss.item(), t_loss.item()) 262 | return errG_total, logs 263 | 264 | 265 | ################################################################## 266 | def KL_loss(mu, logvar): 267 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 268 | KLD_element = 
mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 269 | KLD = torch.mean(KLD_element).mul_(-0.5) 270 | return KLD 271 | -------------------------------------------------------------------------------- /pretrain_DAMSM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on 3 | https://github.com/taoxugit/AttnGAN/blob/master/code/pretrain_DAMSM.py 4 | ''' 5 | 6 | from __future__ import print_function 7 | 8 | from miscc.utils import mkdir_p 9 | from miscc.utils import build_super_images 10 | from miscc.losses import sent_loss, words_loss 11 | from miscc.config import cfg, cfg_from_file 12 | 13 | from datasets import TextDataset 14 | from datasets import prepare_data 15 | 16 | from model import RNN_ENCODER, CNN_ENCODER 17 | 18 | import os 19 | import sys 20 | import time 21 | import random 22 | import pprint 23 | import datetime 24 | import dateutil.tz 25 | import argparse 26 | import numpy as np 27 | from PIL import Image 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.optim as optim 32 | from torch.autograd import Variable 33 | import torch.backends.cudnn as cudnn 34 | import torchvision.transforms as transforms 35 | 36 | 37 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 38 | sys.path.append(dir_path) 39 | 40 | 41 | UPDATE_INTERVAL = 200 42 | def parse_args(): 43 | parser = argparse.ArgumentParser(description='Train a DAMSM network') 44 | parser.add_argument('--cfg', dest='cfg_file', 45 | help='optional config file', 46 | default='cfg/DAMSM/bird.yaml', type=str) 47 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=0) 48 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='data/birds') 49 | parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | def train(dataloader, cnn_model, rnn_model, batch_size, 55 | labels, optimizer, epoch, ixtoword, image_dir): 56 | cnn_model.train() 57 | rnn_model.train() 58 | s_total_loss0 = 0 59 | s_total_loss1 = 0 60 | w_total_loss0 = 0 61 | w_total_loss1 = 0 62 | count = (epoch + 1) * len(dataloader) 63 | start_time = time.time() 64 | for step, data in enumerate(dataloader, 0): 65 | # print('step', step) 66 | rnn_model.zero_grad() 67 | cnn_model.zero_grad() 68 | 69 | imgs, captions, cap_lens, \ 70 | class_ids, keys = prepare_data(data) 71 | 72 | 73 | # words_features: batch_size x nef x 17 x 17 74 | # sent_code: batch_size x nef 75 | words_features, sent_code = cnn_model(imgs[-1]) 76 | # --> batch_size x nef x 17*17 77 | nef, att_sze = words_features.size(1), words_features.size(2) 78 | # words_features = words_features.view(batch_size, nef, -1) 79 | 80 | hidden = rnn_model.init_hidden(batch_size) 81 | # words_emb: batch_size x nef x seq_len 82 | # sent_emb: batch_size x nef 83 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) 84 | 85 | w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, 86 | cap_lens, class_ids, batch_size) 87 | w_total_loss0 += w_loss0.data 88 | w_total_loss1 += w_loss1.data 89 | loss = w_loss0 + w_loss1 90 | 91 | s_loss0, s_loss1 = \ 92 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) 93 | loss += s_loss0 + s_loss1 94 | s_total_loss0 += s_loss0.data 95 | s_total_loss1 += s_loss1.data 96 | # 97 | loss.backward() 98 | # 99 | # `clip_grad_norm` helps prevent 100 | # the exploding gradient problem in RNNs / LSTMs. 
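# Note that only the text encoder (RNN) parameters are clipped, to a max norm of
# cfg.TRAIN.RNN_GRAD_CLIP (0.25 in cfg/DAMSM/bird.yaml); the CNN encoder's
# gradients pass through the shared optimizer unclipped.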
101 | torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), 102 | cfg.TRAIN.RNN_GRAD_CLIP) 103 | optimizer.step() 104 | 105 | if step % UPDATE_INTERVAL == 0: 106 | count = epoch * len(dataloader) + step 107 | 108 | s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL 109 | s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL 110 | 111 | w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL 112 | w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL 113 | 114 | elapsed = time.time() - start_time 115 | print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | ' 116 | 's_loss {:5.2f} {:5.2f} | ' 117 | 'w_loss {:5.2f} {:5.2f}' 118 | .format(epoch, step, len(dataloader), 119 | elapsed * 1000. / UPDATE_INTERVAL, 120 | s_cur_loss0, s_cur_loss1, 121 | w_cur_loss0, w_cur_loss1)) 122 | s_total_loss0 = 0 123 | s_total_loss1 = 0 124 | w_total_loss0 = 0 125 | w_total_loss1 = 0 126 | start_time = time.time() 127 | # attention Maps 128 | img_set, _ = \ 129 | build_super_images(imgs[-1].cpu(), captions, 130 | ixtoword, attn_maps, att_sze) 131 | if img_set is not None: 132 | im = Image.fromarray(img_set) 133 | fullpath = '%s/attention_maps%d.png' % (image_dir, step) 134 | im.save(fullpath) 135 | return count 136 | 137 | 138 | def evaluate(dataloader, cnn_model, rnn_model, batch_size): 139 | cnn_model.eval() 140 | rnn_model.eval() 141 | s_total_loss = 0 142 | w_total_loss = 0 143 | for step, data in enumerate(dataloader, 0): 144 | real_imgs, captions, cap_lens, \ 145 | class_ids, keys = prepare_data(data) 146 | 147 | words_features, sent_code = cnn_model(real_imgs[-1]) 148 | # nef = words_features.size(1) 149 | # words_features = words_features.view(batch_size, nef, -1) 150 | 151 | hidden = rnn_model.init_hidden(batch_size) 152 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) 153 | 154 | w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, 155 | cap_lens, class_ids, batch_size) 156 | w_total_loss += (w_loss0 + w_loss1).data 157 | 158 | s_loss0, s_loss1 = \ 159 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) 160 | s_total_loss += (s_loss0 + s_loss1).data 161 | 162 | if step == 50: 163 | break 164 | 165 | s_cur_loss = s_total_loss.item() / step 166 | w_cur_loss = w_total_loss.item() / step 167 | 168 | return s_cur_loss, w_cur_loss 169 | 170 | 171 | def build_models(): 172 | # build model ############################################################ 173 | text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 174 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) 175 | labels = Variable(torch.LongTensor(range(batch_size))) 176 | start_epoch = 0 177 | if cfg.TRAIN.NET_E != '': 178 | state_dict = torch.load(cfg.TRAIN.NET_E) 179 | text_encoder.load_state_dict(state_dict) 180 | print('Load ', cfg.TRAIN.NET_E) 181 | # 182 | name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 183 | state_dict = torch.load(name) 184 | image_encoder.load_state_dict(state_dict) 185 | print('Load ', name) 186 | 187 | istart = cfg.TRAIN.NET_E.rfind('_') + 8 188 | iend = cfg.TRAIN.NET_E.rfind('.') 189 | start_epoch = cfg.TRAIN.NET_E[istart:iend] 190 | start_epoch = int(start_epoch) + 1 191 | print('start_epoch', start_epoch) 192 | if cfg.CUDA: 193 | text_encoder = text_encoder.cuda() 194 | image_encoder = image_encoder.cuda() 195 | labels = labels.cuda() 196 | 197 | return text_encoder, image_encoder, labels, start_epoch 198 | 199 | 200 | if __name__ == "__main__": 201 | args = parse_args() 202 | if args.cfg_file is not None: 203 | 
cfg_from_file(args.cfg_file) 204 | 205 | if args.gpu_id == -1: 206 | cfg.CUDA = False 207 | else: 208 | cfg.GPU_ID = args.gpu_id 209 | 210 | if args.data_dir != '': 211 | cfg.DATA_DIR = args.data_dir 212 | print('Using config:') 213 | pprint.pprint(cfg) 214 | 215 | if not cfg.TRAIN.FLAG: 216 | args.manualSeed = 100 217 | elif args.manualSeed is None: 218 | args.manualSeed = random.randint(1, 10000) 219 | random.seed(args.manualSeed) 220 | np.random.seed(args.manualSeed) 221 | torch.manual_seed(args.manualSeed) 222 | if cfg.CUDA: 223 | torch.cuda.manual_seed_all(args.manualSeed) 224 | 225 | ########################################################################## 226 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 227 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 228 | output_dir = 'output/%s_%s_%s' % \ 229 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 230 | 231 | model_dir = os.path.join(output_dir, 'Model') 232 | image_dir = os.path.join(output_dir, 'Image') 233 | mkdir_p(model_dir) 234 | mkdir_p(image_dir) 235 | 236 | torch.cuda.set_device(cfg.GPU_ID) 237 | cudnn.benchmark = True 238 | 239 | # Get data loader ################################################## 240 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) 241 | batch_size = cfg.TRAIN.BATCH_SIZE 242 | image_transform = transforms.Compose([ 243 | transforms.Scale(int(imsize * 76 / 64)), 244 | transforms.RandomCrop(imsize), 245 | transforms.RandomHorizontalFlip()]) 246 | dataset = TextDataset(cfg.DATA_DIR, 'train', 247 | base_size=cfg.TREE.BASE_SIZE, 248 | transform=image_transform) 249 | 250 | print(dataset.n_words, dataset.embeddings_num) 251 | assert dataset 252 | dataloader = torch.utils.data.DataLoader( 253 | dataset, batch_size=batch_size, drop_last=True, 254 | shuffle=True, num_workers=int(cfg.WORKERS)) 255 | 256 | # # validation data # 257 | dataset_val = TextDataset(cfg.DATA_DIR, 'test', 258 | base_size=cfg.TREE.BASE_SIZE, 259 | transform=image_transform) 260 | dataloader_val = torch.utils.data.DataLoader( 261 | dataset_val, batch_size=batch_size, drop_last=True, 262 | shuffle=True, num_workers=int(cfg.WORKERS)) 263 | 264 | # Train ############################################################## 265 | text_encoder, image_encoder, labels, start_epoch = build_models() 266 | para = list(text_encoder.parameters()) 267 | for v in image_encoder.parameters(): 268 | if v.requires_grad: 269 | para.append(v) 270 | # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999)) 271 | # At any point you can hit Ctrl + C to break out of training early. 
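# The Adam optimizer is rebuilt each epoch so the learning rate can decay
# geometrically (lr *= 0.98 per epoch, stopping once it reaches ENCODER_LR / 10),
# and both encoders are snapshotted every SNAPSHOT_INTERVAL epochs. To resume,
# set TRAIN.NET_E to a saved text encoder checkpoint (e.g. 'text_encoder200.pth');
# build_models() parses the epoch number back out of that filename.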
272 | try: 273 | lr = cfg.TRAIN.ENCODER_LR 274 | for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH): 275 | optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999)) 276 | epoch_start_time = time.time() 277 | count = train(dataloader, image_encoder, text_encoder, 278 | batch_size, labels, optimizer, epoch, 279 | dataset.ixtoword, image_dir) 280 | print('-' * 89) 281 | if len(dataloader_val) > 0: 282 | s_loss, w_loss = evaluate(dataloader_val, image_encoder, 283 | text_encoder, batch_size) 284 | print('| end epoch {:3d} | valid loss ' 285 | '{:5.2f} {:5.2f} | lr {:.5f}|' 286 | .format(epoch, s_loss, w_loss, lr)) 287 | print('-' * 89) 288 | if lr > cfg.TRAIN.ENCODER_LR/10.: 289 | lr *= 0.98 290 | 291 | if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or 292 | epoch == cfg.TRAIN.MAX_EPOCH): 293 | torch.save(image_encoder.state_dict(), 294 | '%s/image_encoder%d.pth' % (model_dir, epoch)) 295 | torch.save(text_encoder.state_dict(), 296 | '%s/text_encoder%d.pth' % (model_dir, epoch)) 297 | print('Save G/Ds models.') 298 | except KeyboardInterrupt: 299 | print('-' * 89) 300 | print('Exiting from training early') 301 | -------------------------------------------------------------------------------- /miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from torch.nn import init 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | from copy import deepcopy 11 | import skimage.transform 12 | 13 | from miscc.config import cfg 14 | 15 | 16 | # For visualization ################################################ 17 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 18 | 2:[70, 70, 70], 3:[102,102,156], 19 | 4:[190,153,153], 5:[153,153,153], 20 | 6:[250,170, 30], 7:[220, 220, 0], 21 | 8:[107,142, 35], 9:[152,251,152], 22 | 10:[70,130,180], 11:[220,20, 60], 23 | 12:[255, 0, 0], 13:[0, 0, 142], 24 | 14:[119,11, 32], 15:[0, 60,100], 25 | 16:[0, 80, 100], 17:[0, 0, 230], 26 | 18:[0, 0, 70], 19:[0, 0, 0]} 27 | FONT_MAX = 50 28 | 29 | 30 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): 31 | num = captions.size(0) 32 | img_txt = Image.fromarray(convas) 33 | # get a font 34 | fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 35 | # fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 36 | # get a drawing context 37 | d = ImageDraw.Draw(img_txt) 38 | sentence_list = [] 39 | for i in range(num): 40 | cap = captions[i].data.cpu().numpy() 41 | sentence = [] 42 | for j in range(len(cap)): 43 | if cap[j] == 0: 44 | break 45 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') 46 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), 47 | font=fnt, fill=(255, 255, 255, 255)) 48 | sentence.append(word) 49 | sentence_list.append(sentence) 50 | return img_txt, sentence_list 51 | 52 | 53 | def build_super_images(real_imgs, captions, ixtoword, 54 | attn_maps, att_sze, lr_imgs=None, 55 | batch_size=cfg.TRAIN.BATCH_SIZE, 56 | max_word_num=cfg.TEXT.WORDS_NUM): 57 | nvis = 8 58 | real_imgs = real_imgs[:nvis] 59 | if lr_imgs is not None: 60 | lr_imgs = lr_imgs[:nvis] 61 | if att_sze == 17: 62 | vis_size = att_sze * 16 63 | else: 64 | vis_size = real_imgs.size(2) 65 | 66 | text_convas = \ 67 | np.ones([batch_size * FONT_MAX, 68 | (max_word_num + 2) * (vis_size + 2), 3], 69 | dtype=np.uint8) 70 | 71 | for i in range(max_word_num): 72 | istart = (i + 2) * (vis_size + 2) 73 | iend = (i + 3) * 
(vis_size + 2) 74 | text_convas[:, istart:iend, :] = COLOR_DIC[i] 75 | 76 | 77 | real_imgs = \ 78 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 79 | # [-1, 1] --> [0, 1] 80 | real_imgs.add_(1).div_(2).mul_(255) 81 | real_imgs = real_imgs.data.numpy() 82 | # b x c x h x w --> b x h x w x c 83 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 84 | pad_sze = real_imgs.shape 85 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 86 | post_pad = np.zeros([pad_sze[1], pad_sze[2], 3]) 87 | if lr_imgs is not None: 88 | lr_imgs = \ 89 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs) 90 | # [-1, 1] --> [0, 1] 91 | lr_imgs.add_(1).div_(2).mul_(255) 92 | lr_imgs = lr_imgs.data.numpy() 93 | # b x c x h x w --> b x h x w x c 94 | lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1)) 95 | 96 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 97 | seq_len = max_word_num 98 | img_set = [] 99 | num = nvis # len(attn_maps) 100 | 101 | text_map, sentences = \ 102 | drawCaption(text_convas, captions, ixtoword, vis_size) 103 | text_map = np.asarray(text_map).astype(np.uint8) 104 | 105 | bUpdate = 1 106 | for i in range(num): 107 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 108 | # --> 1 x 1 x 17 x 17 109 | attn_max = attn.max(dim=1, keepdim=True) 110 | attn = torch.cat([attn_max[0], attn], 1) 111 | # 112 | attn = attn.view(-1, 1, att_sze, att_sze) 113 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 114 | # n x c x h x w --> n x h x w x c 115 | attn = np.transpose(attn, (0, 2, 3, 1)) 116 | num_attn = attn.shape[0] 117 | # 118 | img = real_imgs[i] 119 | if lr_imgs is None: 120 | lrI = img 121 | else: 122 | lrI = lr_imgs[i] 123 | row = [lrI, middle_pad] 124 | row_merge = [img, middle_pad] 125 | row_beforeNorm = [] 126 | minVglobal, maxVglobal = 1, 0 127 | for j in range(num_attn): 128 | one_map = attn[j] 129 | if (vis_size // att_sze) > 1: 130 | one_map = \ 131 | skimage.transform.pyramid_expand(one_map, sigma=20, 132 | upscale=vis_size // att_sze) 133 | row_beforeNorm.append(one_map) 134 | minV = one_map.min() 135 | maxV = one_map.max() 136 | if minVglobal > minV: 137 | minVglobal = minV 138 | if maxVglobal < maxV: 139 | maxVglobal = maxV 140 | for j in range(seq_len + 1): 141 | if j < num_attn: 142 | one_map = row_beforeNorm[j] 143 | one_map = (one_map - minVglobal) / (maxVglobal - minVglobal) 144 | one_map *= 255 145 | # 146 | PIL_im = Image.fromarray(np.uint8(img)) 147 | PIL_att = Image.fromarray(np.uint8(one_map)) 148 | merged = \ 149 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 150 | mask = Image.new('L', (vis_size, vis_size), (210)) 151 | merged.paste(PIL_im, (0, 0)) 152 | merged.paste(PIL_att, (0, 0), mask) 153 | merged = np.array(merged)[:, :, :3] 154 | else: 155 | one_map = post_pad 156 | merged = post_pad 157 | row.append(one_map) 158 | row.append(middle_pad) 159 | # 160 | row_merge.append(merged) 161 | row_merge.append(middle_pad) 162 | row = np.concatenate(row, 1) 163 | row_merge = np.concatenate(row_merge, 1) 164 | txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX] 165 | if txt.shape[1] != row.shape[1]: 166 | print('txt', txt.shape, 'row', row.shape) 167 | bUpdate = 0 168 | break 169 | row = np.concatenate([txt, row, row_merge], 0) 170 | img_set.append(row) 171 | if bUpdate: 172 | img_set = np.concatenate(img_set, 0) 173 | img_set = img_set.astype(np.uint8) 174 | return img_set, sentences 175 | else: 176 | return None 177 | 178 | 179 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword, 180 | attn_maps, att_sze, vis_size=256, topK=5): 181 
| batch_size = real_imgs.size(0) 182 | max_word_num = np.max(cap_lens) 183 | text_convas = np.ones([batch_size * FONT_MAX, 184 | max_word_num * (vis_size + 2), 3], 185 | dtype=np.uint8) 186 | 187 | real_imgs = \ 188 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 189 | # [-1, 1] --> [0, 1] 190 | real_imgs.add_(1).div_(2).mul_(255) 191 | real_imgs = real_imgs.data.numpy() 192 | # b x c x h x w --> b x h x w x c 193 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 194 | pad_sze = real_imgs.shape 195 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 196 | 197 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 198 | img_set = [] 199 | num = len(attn_maps) 200 | 201 | text_map, sentences = \ 202 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) 203 | text_map = np.asarray(text_map).astype(np.uint8) 204 | 205 | bUpdate = 1 206 | for i in range(num): 207 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 208 | # 209 | attn = attn.view(-1, 1, att_sze, att_sze) 210 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 211 | # n x c x h x w --> n x h x w x c 212 | attn = np.transpose(attn, (0, 2, 3, 1)) 213 | num_attn = cap_lens[i] 214 | thresh = 2./float(num_attn) 215 | # 216 | img = real_imgs[i] 217 | row = [] 218 | row_merge = [] 219 | row_txt = [] 220 | row_beforeNorm = [] 221 | conf_score = [] 222 | for j in range(num_attn): 223 | one_map = attn[j] 224 | mask0 = one_map > (2. * thresh) 225 | conf_score.append(np.sum(one_map * mask0)) 226 | mask = one_map > thresh 227 | one_map = one_map * mask 228 | if (vis_size // att_sze) > 1: 229 | one_map = \ 230 | skimage.transform.pyramid_expand(one_map, sigma=20, 231 | upscale=vis_size // att_sze) 232 | minV = one_map.min() 233 | maxV = one_map.max() 234 | one_map = (one_map - minV) / (maxV - minV) 235 | row_beforeNorm.append(one_map) 236 | sorted_indices = np.argsort(conf_score)[::-1] 237 | 238 | for j in range(num_attn): 239 | one_map = row_beforeNorm[j] 240 | one_map *= 255 241 | # 242 | PIL_im = Image.fromarray(np.uint8(img)) 243 | PIL_att = Image.fromarray(np.uint8(one_map)) 244 | merged = \ 245 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 246 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210) 247 | merged.paste(PIL_im, (0, 0)) 248 | merged.paste(PIL_att, (0, 0), mask) 249 | merged = np.array(merged)[:, :, :3] 250 | 251 | row.append(np.concatenate([one_map, middle_pad], 1)) 252 | # 253 | row_merge.append(np.concatenate([merged, middle_pad], 1)) 254 | # 255 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, 256 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :] 257 | row_txt.append(txt) 258 | # reorder 259 | row_new = [] 260 | row_merge_new = [] 261 | txt_new = [] 262 | for j in range(num_attn): 263 | idx = sorted_indices[j] 264 | row_new.append(row[idx]) 265 | row_merge_new.append(row_merge[idx]) 266 | txt_new.append(row_txt[idx]) 267 | row = np.concatenate(row_new[:topK], 1) 268 | row_merge = np.concatenate(row_merge_new[:topK], 1) 269 | txt = np.concatenate(txt_new[:topK], 1) 270 | if txt.shape[1] != row.shape[1]: 271 | print('Warnings: txt', txt.shape, 'row', row.shape, 272 | 'row_merge_new', row_merge_new.shape) 273 | bUpdate = 0 274 | break 275 | row = np.concatenate([txt, row_merge], 0) 276 | img_set.append(row) 277 | if bUpdate: 278 | img_set = np.concatenate(img_set, 0) 279 | img_set = img_set.astype(np.uint8) 280 | return img_set, sentences 281 | else: 282 | return None 283 | 284 | 285 | #################################################################### 286 | def weights_init(m): 287 | 
classname = m.__class__.__name__ 288 | if classname.find('Conv') != -1: 289 | nn.init.orthogonal_(m.weight.data, 1.0) 290 | elif classname.find('BatchNorm') != -1: 291 | m.weight.data.normal_(1.0, 0.02) 292 | m.bias.data.fill_(0) 293 | elif classname.find('Linear') != -1: 294 | nn.init.orthogonal_(m.weight.data, 1.0) 295 | if m.bias is not None: 296 | m.bias.data.fill_(0.0) 297 | 298 | 299 | def load_params(model, new_param): 300 | for p, new_p in zip(model.parameters(), new_param): 301 | p.data.copy_(new_p) 302 | 303 | 304 | def copy_G_params(model): 305 | flatten = deepcopy(list(p.data for p in model.parameters())) 306 | return flatten 307 | 308 | 309 | def mkdir_p(path): 310 | try: 311 | os.makedirs(path) 312 | except OSError as exc: # Python >2.5 313 | if exc.errno == errno.EEXIST and os.path.isdir(path): 314 | pass 315 | else: 316 | raise 317 | -------------------------------------------------------------------------------- /pretrain_STREAM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on 3 | https://github.com/taoxugit/AttnGAN/blob/master/code/pretrain_DAMSM.py 4 | ''' 5 | 6 | from __future__ import print_function 7 | 8 | from miscc.utils import mkdir_p 9 | from miscc.utils import build_super_images 10 | from miscc.losses import sent_loss, words_loss, image_to_text_loss 11 | from miscc.config import cfg, cfg_from_file 12 | 13 | from datasets import TextDataset 14 | from datasets import prepare_data 15 | 16 | from model import BERT_RNN_ENCODER, BERT_CNN_ENCODER_RNN_DECODER 17 | 18 | import os 19 | import sys 20 | import time 21 | import random 22 | import pprint 23 | import datetime 24 | import dateutil.tz 25 | import argparse 26 | import numpy as np 27 | from PIL import Image 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.optim as optim 32 | from torch.autograd import Variable 33 | import torch.backends.cudnn as cudnn 34 | import torchvision.transforms as transforms 35 | 36 | 37 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 38 | sys.path.append(dir_path) 39 | 40 | 41 | UPDATE_INTERVAL = 200 42 | def parse_args(): 43 | parser = argparse.ArgumentParser(description='Train a STREAM network') 44 | parser.add_argument('--cfg', dest='cfg_file', 45 | help='optional config file', 46 | default='cfg/STREAM/bird.yaml', type=str) 47 | parser.add_argument('--gpu', dest='gpu_id', type=int, default=0) 48 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='data/birds') 49 | parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | def train(dataloader, cnn_model, rnn_model, batch_size, 55 | labels, optimizer, epoch, ixtoword, image_dir): 56 | cnn_model.train() 57 | rnn_model.train() 58 | s_total_loss0 = 0 59 | s_total_loss1 = 0 60 | w_total_loss0 = 0 61 | w_total_loss1 = 0 62 | t_total_loss = 0 63 | count = (epoch + 1) * len(dataloader) 64 | start_time = time.time() 65 | for step, data in enumerate(dataloader, 0): 66 | # print('step', step) 67 | rnn_model.zero_grad() 68 | cnn_model.zero_grad() 69 | 70 | imgs, captions, cap_lens, \ 71 | class_ids, keys = prepare_data(data) 72 | 73 | # sent_code: batch_size x nef 74 | words_features, sent_code, word_logits = cnn_model(imgs[-1], captions) 75 | # bs x T x vocab_size 76 | 77 | nef, att_sze = words_features.size(1), words_features.size(2) 78 | # words_features = words_features.view(batch_size, nef, -1) 79 | 80 | hidden = 
rnn_model.init_hidden(batch_size) 81 | # words_emb: batch_size x nef x seq_len 82 | # sent_emb: batch_size x nef 83 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) 84 | 85 | w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, 86 | cap_lens, class_ids, batch_size) 87 | w_total_loss0 += w_loss0.data 88 | w_total_loss1 += w_loss1.data 89 | loss = w_loss0 + w_loss1 90 | 91 | s_loss0, s_loss1 = \ 92 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) 93 | loss += s_loss0 + s_loss1 94 | s_total_loss0 += s_loss0.data 95 | s_total_loss1 += s_loss1.data 96 | # 97 | 98 | t_loss = image_to_text_loss(word_logits, captions) 99 | loss += t_loss 100 | t_total_loss += t_loss.data 101 | 102 | loss.backward() 103 | # 104 | # `clip_grad_norm` helps prevent 105 | # the exploding gradient problem in RNNs / LSTMs. 106 | torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), 107 | cfg.TRAIN.RNN_GRAD_CLIP) 108 | optimizer.step() 109 | 110 | if step % UPDATE_INTERVAL == 0: 111 | count = epoch * len(dataloader) + step 112 | 113 | s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL 114 | s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL 115 | 116 | w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL 117 | w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL 118 | 119 | t_curr_loss = t_total_loss.item() / UPDATE_INTERVAL 120 | 121 | elapsed = time.time() - start_time 122 | print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | ' 123 | 's_loss {:5.2f} {:5.2f} | ' 124 | 'w_loss {:5.2f} {:5.2f} | ' 125 | 't_loss {:5.2f}' 126 | .format(epoch, step, len(dataloader), 127 | elapsed * 1000. / UPDATE_INTERVAL, 128 | s_cur_loss0, s_cur_loss1, 129 | w_cur_loss0, w_cur_loss1, 130 | t_curr_loss)) 131 | s_total_loss0 = 0 132 | s_total_loss1 = 0 133 | w_total_loss0 = 0 134 | w_total_loss1 = 0 135 | t_total_loss = 0 136 | start_time = time.time() 137 | # attention Maps 138 | img_set, _ = \ 139 | build_super_images(imgs[-1].cpu(), captions, 140 | ixtoword, attn_maps, att_sze) 141 | if img_set is not None: 142 | im = Image.fromarray(img_set) 143 | fullpath = '%s/attention_maps%d.png' % (image_dir, step) 144 | im.save(fullpath) 145 | return count 146 | 147 | 148 | def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels): 149 | cnn_model.eval() 150 | rnn_model.eval() 151 | s_total_loss = 0 152 | w_total_loss = 0 153 | t_total_loss = 0 154 | for step, data in enumerate(dataloader, 0): 155 | imgs, captions, cap_lens, \ 156 | class_ids, keys = prepare_data(data) 157 | 158 | words_features, sent_code, word_logits = cnn_model(imgs[-1], captions) 159 | # nef = words_features.size(1) 160 | # words_features = words_features.view(batch_size, nef, -1) 161 | 162 | hidden = rnn_model.init_hidden(batch_size) 163 | words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) 164 | 165 | w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, 166 | cap_lens, class_ids, batch_size) 167 | w_total_loss += (w_loss0 + w_loss1).data 168 | 169 | s_loss0, s_loss1 = \ 170 | sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) 171 | s_total_loss += (s_loss0 + s_loss1).data 172 | 173 | t_loss = image_to_text_loss(word_logits, captions) 174 | t_total_loss += t_loss.data 175 | 176 | if step == 50: 177 | break 178 | 179 | s_cur_loss = s_total_loss.item() / step 180 | w_cur_loss = w_total_loss.item() / step 181 | t_cur_loss = t_total_loss.item() / step 182 | 183 | return s_cur_loss, w_cur_loss, t_cur_loss 184 | 185 | 186 | def build_models(): 187 | # build model 
############################################################ 188 | text_encoder = BERT_RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 189 | image_encoder = BERT_CNN_ENCODER_RNN_DECODER(cfg.TEXT.EMBEDDING_DIM, cfg.CNN_RNN.HIDDEN_DIM, 190 | dataset.n_words, rec_unit=cfg.RNN_TYPE) 191 | 192 | labels = Variable(torch.LongTensor(range(batch_size))) 193 | start_epoch = 0 194 | if cfg.TRAIN.NET_E != '': 195 | state_dict = torch.load(cfg.TRAIN.NET_E) 196 | text_encoder.load_state_dict(state_dict) 197 | print('Load ', cfg.TRAIN.NET_E) 198 | # 199 | name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 200 | state_dict = torch.load(name) 201 | image_encoder.load_state_dict(state_dict) 202 | print('Load ', name) 203 | 204 | istart = cfg.TRAIN.NET_E.rfind('_') + 8 205 | iend = cfg.TRAIN.NET_E.rfind('.') 206 | start_epoch = cfg.TRAIN.NET_E[istart:iend] 207 | start_epoch = int(start_epoch) + 1 208 | print('start_epoch', start_epoch) 209 | if cfg.CUDA: 210 | text_encoder = text_encoder.cuda() 211 | image_encoder = image_encoder.cuda() 212 | labels = labels.cuda() 213 | 214 | return text_encoder, image_encoder, labels, start_epoch 215 | 216 | 217 | if __name__ == "__main__": 218 | args = parse_args() 219 | if args.cfg_file is not None: 220 | cfg_from_file(args.cfg_file) 221 | 222 | if args.gpu_id == -1: 223 | cfg.CUDA = False 224 | else: 225 | cfg.GPU_ID = args.gpu_id 226 | 227 | if args.data_dir != '': 228 | cfg.DATA_DIR = args.data_dir 229 | print('Using config:') 230 | pprint.pprint(cfg) 231 | 232 | if not cfg.TRAIN.FLAG: 233 | args.manualSeed = 100 234 | elif args.manualSeed is None: 235 | args.manualSeed = random.randint(1, 10000) 236 | random.seed(args.manualSeed) 237 | np.random.seed(args.manualSeed) 238 | torch.manual_seed(args.manualSeed) 239 | if cfg.CUDA: 240 | torch.cuda.manual_seed_all(args.manualSeed) 241 | 242 | ########################################################################## 243 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 244 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 245 | output_dir = 'output/%s_%s_%s' % \ 246 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 247 | 248 | model_dir = os.path.join(output_dir, 'Model') 249 | image_dir = os.path.join(output_dir, 'Image') 250 | mkdir_p(model_dir) 251 | mkdir_p(image_dir) 252 | 253 | torch.cuda.set_device(cfg.GPU_ID) 254 | cudnn.benchmark = True 255 | 256 | # Get data loader ################################################## 257 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) 258 | batch_size = cfg.TRAIN.BATCH_SIZE 259 | image_transform = transforms.Compose([ 260 | transforms.Scale(int(imsize * 76 / 64)), 261 | transforms.RandomCrop(imsize), 262 | transforms.RandomHorizontalFlip()]) 263 | dataset = TextDataset(cfg.DATA_DIR, 'train', 264 | base_size=cfg.TREE.BASE_SIZE, 265 | transform=image_transform) 266 | 267 | print(dataset.n_words, dataset.embeddings_num) 268 | assert dataset 269 | dataloader = torch.utils.data.DataLoader( 270 | dataset, batch_size=batch_size, drop_last=True, 271 | shuffle=True, num_workers=int(cfg.WORKERS)) 272 | 273 | # # validation data # 274 | dataset_val = TextDataset(cfg.DATA_DIR, 'test', 275 | base_size=cfg.TREE.BASE_SIZE, 276 | transform=image_transform) 277 | dataloader_val = torch.utils.data.DataLoader( 278 | dataset_val, batch_size=batch_size, drop_last=True, 279 | shuffle=True, num_workers=int(cfg.WORKERS)) 280 | 281 | # Train ############################################################## 282 | text_encoder, image_encoder, labels, start_epoch = 
build_models() 283 | para = list(text_encoder.parameters()) 284 | for v in image_encoder.parameters(): 285 | if v.requires_grad: 286 | para.append(v) 287 | # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999)) 288 | # At any point you can hit Ctrl + C to break out of training early. 289 | try: 290 | lr = cfg.TRAIN.ENCODER_LR 291 | for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH): 292 | optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999)) 293 | epoch_start_time = time.time() 294 | count = train(dataloader, image_encoder, text_encoder, 295 | batch_size, labels, optimizer, epoch, 296 | dataset.ixtoword, image_dir) 297 | print('-' * 89) 298 | if len(dataloader_val) > 0: 299 | s_loss, w_loss, t_loss = evaluate(dataloader_val, image_encoder, 300 | text_encoder, batch_size, labels) 301 | print('| end epoch {:3d} | valid loss ' 302 | '{:5.2f} {:5.2f} {:5.2f} | lr {:.5f}|' 303 | .format(epoch, s_loss, w_loss, t_loss, lr)) 304 | print('-' * 89) 305 | if lr > cfg.TRAIN.ENCODER_LR/10.: 306 | lr *= 0.98 307 | 308 | if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or 309 | epoch == cfg.TRAIN.MAX_EPOCH): 310 | torch.save(image_encoder.state_dict(), 311 | '%s/image_encoder%d.pth' % (model_dir, epoch)) 312 | torch.save(text_encoder.state_dict(), 313 | '%s/text_encoder%d.pth' % (model_dir, epoch)) 314 | print('Save G/Ds models.') 315 | except KeyboardInterrupt: 316 | print('-' * 89) 317 | print('Exiting from training early') 318 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from miscc.config import cfg 8 | from collections import defaultdict 9 | from torchvision import transforms 10 | import torch.utils.data as data 11 | from torch.autograd import Variable 12 | from torch.utils.data.dataset import Dataset 13 | from nltk.tokenize import RegexpTokenizer 14 | from pytorch_pretrained_bert import BertTokenizer 15 | from PIL import Image 16 | 17 | 18 | def prepare_data(data): 19 | imgs, captions, captions_lens, class_ids, keys = data 20 | 21 | # sort data by the length in a decreasing order 22 | sorted_cap_lens, sorted_cap_indices = \ 23 | torch.sort(captions_lens, 0, True) 24 | 25 | real_imgs = [] 26 | for i in range(len(imgs)): 27 | imgs[i] = imgs[i][sorted_cap_indices] 28 | if cfg.CUDA: 29 | real_imgs.append(Variable(imgs[i]).cuda()) 30 | else: 31 | real_imgs.append(Variable(imgs[i])) 32 | 33 | captions = captions[sorted_cap_indices].squeeze() 34 | class_ids = class_ids[sorted_cap_indices].numpy() 35 | # sent_indices = sent_indices[sorted_cap_indices] 36 | keys = [keys[i] for i in sorted_cap_indices.numpy()] 37 | # print('keys', type(keys), keys[-1]) # list 38 | if cfg.CUDA: 39 | captions = Variable(captions).cuda() 40 | sorted_cap_lens = Variable(sorted_cap_lens).cuda() 41 | else: 42 | captions = Variable(captions) 43 | sorted_cap_lens = Variable(sorted_cap_lens) 44 | 45 | return [real_imgs, captions, sorted_cap_lens, 46 | class_ids, keys] 47 | 48 | 49 | def get_imgs(img_path, imsize, bbox=None, 50 | transform=None, normalize=None): 51 | img = Image.open(img_path).convert('RGB') 52 | width, height = img.size 53 | if bbox is not None: 54 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 55 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 56 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 57 | y1 = np.maximum(0, center_y - r) 58 | y2 = 
np.minimum(height, center_y + r) 59 | x1 = np.maximum(0, center_x - r) 60 | x2 = np.minimum(width, center_x + r) 61 | img = img.crop([x1, y1, x2, y2]) 62 | 63 | if transform is not None: 64 | img = transform(img) 65 | 66 | ret = [] 67 | if cfg.GAN.B_DCGAN: 68 | ret = [normalize(img)] 69 | else: 70 | for i in range(cfg.TREE.BRANCH_NUM): 71 | # print(imsize[i]) 72 | if i < (cfg.TREE.BRANCH_NUM - 1): 73 | re_img = transforms.Scale(imsize[i])(img) 74 | else: 75 | re_img = img 76 | ret.append(normalize(re_img)) 77 | 78 | return ret 79 | 80 | 81 | 82 | class TextDataset(Dataset): 83 | """ 84 | Text Dataset 85 | Based on: 86 | https://github.com/taoxugit/AttnGAN/blob/master/code/datasets.py 87 | """ 88 | tokenizer = RegexpTokenizer(r'\w+') 89 | def __init__(self, data_dir, split='train', 90 | base_size=64, 91 | transform=None, target_transform=None): 92 | self.transform = transform 93 | self.norm = transforms.Compose([ 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 96 | self.target_transform = target_transform 97 | self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE 98 | 99 | self.imsize = [] 100 | for i in range(cfg.TREE.BRANCH_NUM): 101 | self.imsize.append(base_size) 102 | base_size = base_size * 2 103 | 104 | self.data = [] 105 | self.data_dir = data_dir 106 | if data_dir.find('birds') != -1: 107 | self.bbox = self.load_bbox() 108 | else: 109 | self.bbox = None 110 | split_dir = os.path.join(data_dir, split) 111 | 112 | self.filenames, self.captions, self.ixtoword, \ 113 | self.wordtoix, self.n_words = self.load_text_data(data_dir, split) 114 | 115 | self.class_id = self.load_class_id(split_dir, len(self.filenames)) 116 | self.number_example = len(self.filenames) 117 | 118 | def load_bbox(self): 119 | data_dir = self.data_dir 120 | bbox_path = os.path.join(data_dir, 'CUB_200_2011', 'CUB_200_2011', 'bounding_boxes.txt') 121 | df_bounding_boxes = pd.read_csv(bbox_path, 122 | delim_whitespace=True, 123 | header=None).astype(int) 124 | # 125 | filepath = os.path.join(data_dir, 'CUB_200_2011', 'CUB_200_2011', 'images.txt') 126 | df_filenames = \ 127 | pd.read_csv(filepath, delim_whitespace=True, header=None) 128 | filenames = df_filenames[1].tolist() 129 | print('Total filenames: ', len(filenames), filenames[0]) 130 | # 131 | filename_bbox = {img_file[:-4]: [] for img_file in filenames} 132 | numImgs = len(filenames) 133 | for i in range(0, numImgs): 134 | # bbox = [x-left, y-top, width, height] 135 | bbox = df_bounding_boxes.iloc[i][1:].tolist() 136 | 137 | key = filenames[i][:-4] 138 | filename_bbox[key] = bbox 139 | # 140 | return filename_bbox 141 | 142 | def load_captions(self, data_dir, filenames): 143 | all_captions = [] 144 | for i in range(len(filenames)): 145 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) 146 | with open(cap_path, "r") as f: 147 | captions = f.read().split('\n') 148 | cnt = 0 149 | for cap in captions: 150 | if len(cap) == 0: 151 | continue 152 | # picks out sequences of alphanumeric characters as tokens 153 | # and drops everything else 154 | tokens = self.tokenizer.tokenize(cap.lower()) 155 | # print('tokens', tokens) 156 | if len(tokens) == 0: 157 | print('cap', cap) 158 | continue 159 | 160 | tokens_new = [] 161 | for t in tokens: 162 | t = t.encode('ascii', 'ignore').decode('ascii') 163 | if len(t) > 0: 164 | tokens_new.append(t) 165 | all_captions.append(tokens_new) 166 | cnt += 1 167 | if cnt == self.embeddings_num: 168 | break 169 | if cnt < self.embeddings_num: 170 | print('ERROR: the captions for %s less than 
%d' 171 | % (filenames[i], cnt)) 172 | 173 | return all_captions 174 | 175 | def build_dictionary(self, train_captions, test_captions): 176 | word_counts = defaultdict(float) 177 | captions = train_captions + test_captions 178 | for sent in captions: 179 | for word in sent: 180 | word_counts[word] += 1 181 | 182 | vocab = [w for w in word_counts if word_counts[w] >= 0] 183 | 184 | ixtoword = {} 185 | ixtoword[0] = '' 186 | wordtoix = {} 187 | wordtoix[''] = 0 188 | ix = 1 189 | for w in vocab: 190 | wordtoix[w] = ix 191 | ixtoword[ix] = w 192 | ix += 1 193 | 194 | train_captions_new = [] 195 | for t in train_captions: 196 | rev = [] 197 | for w in t: 198 | if w in wordtoix: 199 | rev.append(wordtoix[w]) 200 | # rev.append(0) # do not need '' token 201 | train_captions_new.append(rev) 202 | 203 | test_captions_new = [] 204 | for t in test_captions: 205 | rev = [] 206 | for w in t: 207 | if w in wordtoix: 208 | rev.append(wordtoix[w]) 209 | # rev.append(0) # do not need '' token 210 | test_captions_new.append(rev) 211 | 212 | return [train_captions_new, test_captions_new, 213 | ixtoword, wordtoix, len(ixtoword)] 214 | 215 | def load_text_data(self, data_dir, split): 216 | train_names = self.load_filenames(data_dir, 'train') 217 | test_names = self.load_filenames(data_dir, 'test') 218 | filepath = os.path.join(data_dir, 'captions.pickle') 219 | if not os.path.isfile(filepath): 220 | train_captions = self.load_captions(data_dir, train_names) 221 | test_captions = self.load_captions(data_dir, test_names) 222 | 223 | train_captions, test_captions, ixtoword, wordtoix, n_words = \ 224 | self.build_dictionary(train_captions, test_captions) 225 | with open(filepath, 'wb') as f: 226 | pickle.dump([train_captions, test_captions, 227 | ixtoword, wordtoix], f, protocol=2) 228 | print('Save to: ', filepath) 229 | else: 230 | with open(filepath, 'rb') as f: 231 | x = pickle.load(f) 232 | train_captions, test_captions = x[0], x[1] 233 | ixtoword, wordtoix = x[2], x[3] 234 | del x 235 | n_words = len(ixtoword) 236 | print('Load from: ', filepath) 237 | if split == 'train': 238 | # a list of list: each list contains 239 | # the indices of words in a sentence 240 | captions = train_captions 241 | filenames = train_names 242 | else: # split=='test' 243 | captions = test_captions 244 | filenames = test_names 245 | return filenames, captions, ixtoword, wordtoix, n_words 246 | 247 | def load_class_id(self, data_dir, total_num): 248 | if os.path.isfile(data_dir + '/class_info.pickle'): 249 | # encoding = latin1 because of incompatability issue b/w python 2 and python 3 250 | class_id = pickle.load(open(data_dir + '/class_info.pickle', 'rb'), encoding='latin1') 251 | else: 252 | class_id = np.arange(total_num) 253 | return class_id 254 | 255 | def load_filenames(self, data_dir, split): 256 | filepath = '%s/%s/filenames.pickle' % (data_dir, split) 257 | if os.path.isfile(filepath): 258 | with open(filepath, 'rb') as f: 259 | filenames = pickle.load(f) 260 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 261 | else: 262 | filenames = [] 263 | return filenames 264 | 265 | def get_caption(self, sent_ix): 266 | # a list of indices for a sentence 267 | sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') 268 | if (sent_caption == 0).sum() > 0: 269 | print('ERROR: do not need END (0) token', sent_caption) 270 | num_words = len(sent_caption) 271 | # pad with 0s (i.e., '') 272 | x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64') 273 | x_len = num_words 274 | if num_words <= 
cfg.TEXT.WORDS_NUM: 275 | x[:num_words, 0] = sent_caption 276 | else: 277 | ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum 278 | np.random.shuffle(ix) 279 | ix = ix[:cfg.TEXT.WORDS_NUM] 280 | ix = np.sort(ix) 281 | x[:, 0] = sent_caption[ix] 282 | x_len = cfg.TEXT.WORDS_NUM 283 | return x, x_len 284 | 285 | def get_imgs(self, img_path, imsize, bbox=None, 286 | transform=None, normalize=None): 287 | img = Image.open(img_path).convert('RGB') 288 | width, height = img.size 289 | if bbox is not None: 290 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 291 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 292 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 293 | y1 = np.maximum(0, center_y - r) 294 | y2 = np.minimum(height, center_y + r) 295 | x1 = np.maximum(0, center_x - r) 296 | x2 = np.minimum(width, center_x + r) 297 | img = img.crop([x1, y1, x2, y2]) 298 | 299 | if transform is not None: 300 | img = transform(img) 301 | 302 | ret = [] 303 | if cfg.GAN.B_DCGAN: 304 | ret = [normalize(img)] 305 | else: 306 | for i in range(cfg.TREE.BRANCH_NUM): 307 | # print(imsize[i]) 308 | if i < (cfg.TREE.BRANCH_NUM - 1): 309 | re_img = transforms.Scale(imsize[i])(img) 310 | else: 311 | re_img = img 312 | ret.append(normalize(re_img)) 313 | 314 | return ret 315 | 316 | def __getitem__(self, index): 317 | # 318 | key = self.filenames[index] 319 | cls_id = self.class_id[index] 320 | # 321 | if self.bbox is not None: 322 | bbox = self.bbox[key] 323 | data_dir = '%s/CUB_200_2011/CUB_200_2011' % self.data_dir 324 | else: 325 | bbox = None 326 | data_dir = self.data_dir 327 | # 328 | img_name = '%s/images/%s.jpg' % (data_dir, key) 329 | imgs = self.get_imgs(img_name, self.imsize, 330 | bbox, self.transform, normalize=self.norm) 331 | # random select a sentence 332 | sent_ix = np.random.randint(0, self.embeddings_num) 333 | new_sent_ix = index * self.embeddings_num + sent_ix 334 | caps, cap_len = self.get_caption(new_sent_ix) 335 | return imgs, caps, cap_len, cls_id, key 336 | 337 | 338 | def __len__(self): 339 | return len(self.filenames) 340 | 341 | 342 | class TextBertDataset(TextDataset): 343 | """ 344 | Text dataset on Bert 345 | https://github.com/huggingface/pytorch-pretrained-BERT 346 | """ 347 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 348 | def __init__(self, *args, **kwargs): 349 | super().__init__(*args, **kwargs) # Load pre-trained model tokenizer (vocabulary) 350 | 351 | def load_captions(self, data_dir, filenames): 352 | all_captions = [] 353 | for i in range(len(filenames)): 354 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) 355 | with open(cap_path, "r") as f: 356 | captions = f.read().split('\n') 357 | cnt = 0 358 | for cap in captions: 359 | if len(cap) == 0: 360 | continue 361 | # picks out sequences of alphanumeric characters as tokens 362 | # and drops everything else 363 | tokens = self.tokenizer.tokenize(cap.lower()) 364 | # print('tokens', tokens) 365 | if len(tokens) == 0: 366 | print('cap', cap) 367 | continue 368 | 369 | tokens_new = [] 370 | for t in tokens: 371 | t = t.encode('ascii', 'ignore').decode('ascii') 372 | if len(t) > 0: 373 | tokens_new.append(t) 374 | all_captions.append(tokens_new) 375 | cnt += 1 376 | if cnt == self.embeddings_num: 377 | break 378 | if cnt < self.embeddings_num: 379 | print('ERROR: the captions for %s less than %d' 380 | % (filenames[i], cnt)) 381 | 382 | return all_captions 383 | 384 | def load_text_data(self, data_dir, split): 385 | train_names = self.load_filenames(data_dir, 'train') 386 | test_names = 
self.load_filenames(data_dir, 'test') 387 | filepath = os.path.join(data_dir, 'bert_captions.pickle') 388 | if not os.path.isfile(filepath): 389 | train_captions = self.load_captions(data_dir, train_names) 390 | test_captions = self.load_captions(data_dir, test_names) 391 | 392 | train_captions, test_captions, ixtoword, wordtoix, n_words = \ 393 | self.build_dictionary(train_captions, test_captions) 394 | with open(filepath, 'wb') as f: 395 | pickle.dump([train_captions, test_captions, 396 | ixtoword, wordtoix], f, protocol=2) 397 | print('Save to: ', filepath) 398 | else: 399 | with open(filepath, 'rb') as f: 400 | x = pickle.load(f) 401 | train_captions, test_captions = x[0], x[1] 402 | ixtoword, wordtoix = x[2], x[3] 403 | del x 404 | n_words = len(ixtoword) 405 | print('Load from: ', filepath) 406 | if split == 'train': 407 | # a list of list: each list contains 408 | # the indices of words in a sentence 409 | captions = train_captions 410 | filenames = train_names 411 | else: # split=='test' 412 | captions = test_captions 413 | filenames = test_names 414 | return filenames, captions, ixtoword, wordtoix, n_words 415 | 416 | 417 | def build_dictionary(self, train_captions, test_captions): 418 | """ 419 | Tokenize according to bert model 420 | """ 421 | captions = train_captions + test_captions 422 | ixtoword = {} 423 | wordtoix = {} 424 | 425 | 426 | train_captions_new = [] 427 | for sent in train_captions: 428 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(sent) 429 | train_captions_new.append(indexed_tokens) 430 | for idx, word in zip(indexed_tokens, sent): 431 | wordtoix[word] = idx 432 | ixtoword[idx] = word 433 | 434 | test_captions_new = [] 435 | for sent in test_captions: 436 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(sent) 437 | test_captions_new.append(indexed_tokens) 438 | for idx, word in zip(indexed_tokens, sent): 439 | wordtoix[word] = idx 440 | ixtoword[idx] = word 441 | 442 | return [train_captions_new, test_captions_new, 443 | ixtoword, wordtoix, len(ixtoword)] 444 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on 3 | https://github.com/taoxugit/AttnGAN/blob/master/code/model.py 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | from torch.autograd import Variable 9 | from torchvision import models 10 | import torch.utils.model_zoo as model_zoo 11 | import torch.nn.functional as F 12 | 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | 15 | from miscc.config import cfg 16 | from GlobalAttention import GlobalAttentionGeneral as ATT_NET 17 | from pytorch_pretrained_bert import BertModel 18 | 19 | 20 | class Upsample(nn.Module): 21 | def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): 22 | super().__init__() 23 | self.size = size 24 | self.scale_factor = float(scale_factor) if scale_factor else None 25 | self.mode = mode 26 | self.align_corners = align_corners 27 | 28 | def forward(self, x): 29 | return F.interpolate(x, self.size, self.scale_factor, self.mode, self.align_corners) 30 | 31 | def extra_repr(self): 32 | if self.scale_factor is not None: 33 | info = 'scale_factor=' + str(self.scale_factor) 34 | else: 35 | info = 'size=' + str(self.size) 36 | info += ', mode=' + self.mode 37 | return info 38 | 39 | class GLU(nn.Module): 40 | def __init__(self): 41 | super(GLU, self).__init__() 42 | 43 | def 
forward(self, x): 44 | nc = x.size(1) 45 | assert nc % 2 == 0, 'channels dont divide 2!' 46 | nc = int(nc/2) 47 | return x[:, :nc] * torch.sigmoid(x[:, nc:]) 48 | 49 | 50 | def conv1x1(in_planes, out_planes, bias=False): 51 | "1x1 convolution with padding" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 53 | padding=0, bias=bias) 54 | 55 | 56 | def conv3x3(in_planes, out_planes): 57 | "3x3 convolution with padding" 58 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | 61 | 62 | # Upsale the spatial size by a factor of 2 63 | def upBlock(in_planes, out_planes): 64 | block = nn.Sequential( 65 | Upsample(scale_factor=2, mode='nearest'), 66 | conv3x3(in_planes, out_planes * 2), 67 | nn.BatchNorm2d(out_planes * 2), 68 | GLU()) 69 | return block 70 | 71 | 72 | # Keep the spatial size 73 | def Block3x3_relu(in_planes, out_planes): 74 | block = nn.Sequential( 75 | conv3x3(in_planes, out_planes * 2), 76 | nn.BatchNorm2d(out_planes * 2), 77 | GLU()) 78 | return block 79 | 80 | 81 | class ResBlock(nn.Module): 82 | def __init__(self, channel_num): 83 | super(ResBlock, self).__init__() 84 | self.block = nn.Sequential( 85 | conv3x3(channel_num, channel_num * 2), 86 | nn.BatchNorm2d(channel_num * 2), 87 | GLU(), 88 | conv3x3(channel_num, channel_num), 89 | nn.BatchNorm2d(channel_num)) 90 | 91 | def forward(self, x): 92 | residual = x 93 | out = self.block(x) 94 | out += residual 95 | return out 96 | 97 | 98 | # ############## Text2Image Encoder-Decoder ####### 99 | class RNN_ENCODER(nn.Module): 100 | def __init__(self, ntoken, ninput=300, drop_prob=0.5, 101 | nhidden=128, nlayers=1, bidirectional=True): 102 | super(RNN_ENCODER, self).__init__() 103 | self.n_steps = cfg.TEXT.WORDS_NUM 104 | self.ntoken = ntoken # size of the dictionary 105 | self.ninput = ninput # size of each embedding vector 106 | self.drop_prob = drop_prob # probability of an element to be zeroed 107 | self.nlayers = nlayers # Number of recurrent layers 108 | self.bidirectional = bidirectional 109 | self.rnn_type = cfg.RNN_TYPE 110 | if bidirectional: 111 | self.num_directions = 2 112 | else: 113 | self.num_directions = 1 114 | # number of features in the hidden state 115 | self.nhidden = nhidden // self.num_directions 116 | 117 | self.define_module() 118 | self.init_weights() 119 | 120 | def define_module(self): 121 | self.encoder = nn.Embedding(self.ntoken, self.ninput) 122 | self.drop = nn.Dropout(self.drop_prob) 123 | if self.rnn_type == 'LSTM': 124 | # dropout: If non-zero, introduces a dropout layer on 125 | # the outputs of each RNN layer except the last layer 126 | self.rnn = nn.LSTM(self.ninput, self.nhidden, 127 | self.nlayers, batch_first=True, 128 | dropout=self.drop_prob, 129 | bidirectional=self.bidirectional) 130 | elif self.rnn_type == 'GRU': 131 | self.rnn = nn.GRU(self.ninput, self.nhidden, 132 | self.nlayers, batch_first=True, 133 | dropout=self.drop_prob, 134 | bidirectional=self.bidirectional) 135 | else: 136 | raise NotImplementedError 137 | 138 | def init_weights(self): 139 | initrange = 0.1 140 | self.encoder.weight.data.uniform_(-initrange, initrange) 141 | # Do not need to initialize RNN parameters, which have been initialized 142 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM 143 | # self.decoder.weight.data.uniform_(-initrange, initrange) 144 | # self.decoder.bias.data.fill_(0) 145 | 146 | def init_hidden(self, bsz): 147 | weight = next(self.parameters()).data 148 | if self.rnn_type == 'LSTM': 149 | return 
(Variable(weight.new(self.nlayers * self.num_directions, 150 | bsz, self.nhidden).zero_()), 151 | Variable(weight.new(self.nlayers * self.num_directions, 152 | bsz, self.nhidden).zero_())) 153 | else: 154 | return Variable(weight.new(self.nlayers * self.num_directions, 155 | bsz, self.nhidden).zero_()) 156 | 157 | def forward(self, captions, cap_lens, hidden, mask=None): 158 | # input: torch.LongTensor of size batch x n_steps 159 | # --> emb: batch x n_steps x ninput 160 | emb = self.drop(self.encoder(captions)) 161 | # 162 | # Returns: a PackedSequence object 163 | cap_lens = cap_lens.data.tolist() 164 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True) 165 | # #hidden and memory (num_layers * num_directions, batch, hidden_size): 166 | # tensor containing the initial hidden state for each element in batch. 167 | # #output (batch, seq_len, hidden_size * num_directions) 168 | # #or a PackedSequence object: 169 | # tensor containing output features (h_t) from the last layer of RNN 170 | output, hidden = self.rnn(emb, hidden) 171 | # PackedSequence object 172 | # --> (batch, seq_len, hidden_size * num_directions) 173 | output = pad_packed_sequence(output, batch_first=True)[0] 174 | # output = self.drop(output) 175 | # --> batch x hidden_size*num_directions x seq_len 176 | words_emb = output.transpose(1, 2) 177 | # --> batch x num_directions*hidden_size 178 | if self.rnn_type == 'LSTM': 179 | sent_emb = hidden[0].transpose(0, 1).contiguous() 180 | else: 181 | sent_emb = hidden.transpose(0, 1).contiguous() 182 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) 183 | return words_emb, sent_emb 184 | 185 | class BERT_RNN_ENCODER(RNN_ENCODER): 186 | def define_module(self): 187 | self.encoder = BertModel.from_pretrained('bert-base-uncased') 188 | for param in self.encoder.parameters(): 189 | param.requires_grad = False 190 | self.bert_linear = nn.Linear(768, self.ninput) 191 | self.drop = nn.Dropout(self.drop_prob) 192 | if self.rnn_type == 'LSTM': 193 | # dropout: If non-zero, introduces a dropout layer on 194 | # the outputs of each RNN layer except the last layer 195 | self.rnn = nn.LSTM(self.ninput, self.nhidden, 196 | self.nlayers, batch_first=True, 197 | dropout=self.drop_prob, 198 | bidirectional=self.bidirectional) 199 | elif self.rnn_type == 'GRU': 200 | self.rnn = nn.GRU(self.ninput, self.nhidden, 201 | self.nlayers, batch_first=True, 202 | dropout=self.drop_prob, 203 | bidirectional=self.bidirectional) 204 | else: 205 | raise NotImplementedError 206 | 207 | def init_weights(self): 208 | initrange = 0.1 209 | self.bert_linear.weight.data.uniform_(-initrange, initrange) 210 | # Do not need to initialize RNN parameters, which have been initialized 211 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM 212 | # self.decoder.weight.data.uniform_(-initrange, initrange) 213 | # self.decoder.bias.data.fill_(0) 214 | 215 | def forward(self, captions, cap_lens, hidden, mask=None): 216 | # input: torch.LongTensor of size batch x n_steps 217 | # --> emb: batch x n_steps x ninput 218 | emb, _ = self.encoder(captions, output_all_encoded_layers=False) 219 | emb = self.bert_linear(emb) 220 | emb = self.drop(emb) 221 | # 222 | # Returns: a PackedSequence object 223 | cap_lens = cap_lens.data.tolist() 224 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True) 225 | # #hidden and memory (num_layers * num_directions, batch, hidden_size): 226 | # tensor containing the initial hidden state for each element in batch. 
227 | # #output (batch, seq_len, hidden_size * num_directions) 228 | # #or a PackedSequence object: 229 | # tensor containing output features (h_t) from the last layer of RNN 230 | output, hidden = self.rnn(emb, hidden) 231 | # PackedSequence object 232 | # --> (batch, seq_len, hidden_size * num_directions) 233 | output = pad_packed_sequence(output, batch_first=True)[0] 234 | # output = self.drop(output) 235 | # --> batch x hidden_size*num_directions x seq_len 236 | words_emb = output.transpose(1, 2) 237 | # --> batch x num_directions*hidden_size 238 | if self.rnn_type == 'LSTM': 239 | sent_emb = hidden[0].transpose(0, 1).contiguous() 240 | else: 241 | sent_emb = hidden.transpose(0, 1).contiguous() 242 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) 243 | return words_emb, sent_emb 244 | 245 | class CNN_ENCODER(nn.Module): 246 | def __init__(self, nef): 247 | super(CNN_ENCODER, self).__init__() 248 | if cfg.TRAIN.FLAG: 249 | self.nef = nef 250 | else: 251 | self.nef = 256 # define a uniform ranker 252 | 253 | model = models.inception_v3() 254 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' 255 | model.load_state_dict(model_zoo.load_url(url)) 256 | for param in model.parameters(): # freeze inception model 257 | param.requires_grad = False 258 | print('Load pretrained model from ', url) 259 | # print(model) 260 | 261 | self.define_module(model) 262 | self.init_trainable_weights() 263 | 264 | def define_module(self, model): 265 | self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3 266 | self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3 267 | self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3 268 | self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1 269 | self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3 270 | self.Mixed_5b = model.Mixed_5b 271 | self.Mixed_5c = model.Mixed_5c 272 | self.Mixed_5d = model.Mixed_5d 273 | self.Mixed_6a = model.Mixed_6a 274 | self.Mixed_6b = model.Mixed_6b 275 | self.Mixed_6c = model.Mixed_6c 276 | self.Mixed_6d = model.Mixed_6d 277 | self.Mixed_6e = model.Mixed_6e 278 | self.Mixed_7a = model.Mixed_7a 279 | self.Mixed_7b = model.Mixed_7b 280 | self.Mixed_7c = model.Mixed_7c 281 | 282 | self.emb_features = conv1x1(768, self.nef) 283 | self.emb_cnn_code = nn.Linear(2048, self.nef) 284 | 285 | def init_trainable_weights(self): 286 | initrange = 0.1 287 | self.emb_features.weight.data.uniform_(-initrange, initrange) 288 | self.emb_cnn_code.weight.data.uniform_(-initrange, initrange) 289 | 290 | def forward(self, x): 291 | features = None 292 | # --> fixed-size input: batch x 3 x 299 x 299 293 | x = Upsample(size=(299, 299), mode='bilinear')(x) 294 | # 299 x 299 x 3 295 | x = self.Conv2d_1a_3x3(x) 296 | # 149 x 149 x 32 297 | x = self.Conv2d_2a_3x3(x) 298 | # 147 x 147 x 32 299 | x = self.Conv2d_2b_3x3(x) 300 | # 147 x 147 x 64 301 | x = F.max_pool2d(x, kernel_size=3, stride=2) 302 | # 73 x 73 x 64 303 | x = self.Conv2d_3b_1x1(x) 304 | # 73 x 73 x 80 305 | x = self.Conv2d_4a_3x3(x) 306 | # 71 x 71 x 192 307 | 308 | x = F.max_pool2d(x, kernel_size=3, stride=2) 309 | # 35 x 35 x 192 310 | x = self.Mixed_5b(x) 311 | # 35 x 35 x 256 312 | x = self.Mixed_5c(x) 313 | # 35 x 35 x 288 314 | x = self.Mixed_5d(x) 315 | # 35 x 35 x 288 316 | 317 | x = self.Mixed_6a(x) 318 | # 17 x 17 x 768 319 | x = self.Mixed_6b(x) 320 | # 17 x 17 x 768 321 | x = self.Mixed_6c(x) 322 | # 17 x 17 x 768 323 | x = self.Mixed_6d(x) 324 | # 17 x 17 x 768 325 | x = self.Mixed_6e(x) 326 | # 17 x 17 x 768 327 | 328 | # image region features 329 | features = x 330 | # 17 x 17 x 768 331 | 332 | x = 
self.Mixed_7a(x) 333 | # 8 x 8 x 1280 334 | x = self.Mixed_7b(x) 335 | # 8 x 8 x 2048 336 | x = self.Mixed_7c(x) 337 | # 8 x 8 x 2048 338 | x = F.avg_pool2d(x, kernel_size=8) 339 | # 1 x 1 x 2048 340 | # x = F.dropout(x, training=self.training) 341 | # 1 x 1 x 2048 342 | x = x.view(x.size(0), -1) 343 | # 2048 344 | 345 | # global image features 346 | cnn_code = self.emb_cnn_code(x) # nef 347 | 348 | if features is not None: 349 | features = self.emb_features(features) # 17 x 17 x nef 350 | return features, cnn_code 351 | 352 | 353 | # ############## Image2text Encoder-Decoder ####### 354 | class CNN_ENCODER_RNN_DECODER(CNN_ENCODER): 355 | def __init__(self, emb_size, hidden_size, vocab_size, nlayers=1, bidirectional=True, rec_unit='LSTM', dropout=0.5): 356 | """ 357 | Based on https://github.com/komiya-m/MirrorGAN/blob/master/model.py 358 | :param emb_size: size of word embeddings 359 | :param hidden_size: size of hidden state of the recurrent unit 360 | :param vocab_size: size of the vocabulary (output of the network) 361 | :param rec_unit: type of recurrent unit (default=gru) 362 | """ 363 | self.dropout = dropout 364 | self.nlayers = nlayers 365 | self.bidirectional = bidirectional 366 | self.num_directions = 2 if self.bidirectional else 1 367 | __rec_units = { 368 | 'GRU': nn.GRU, 369 | 'LSTM': nn.LSTM, 370 | } 371 | assert rec_unit in __rec_units, 'Specified recurrent unit is not available' 372 | 373 | super().__init__(emb_size) 374 | 375 | self.hidden_linear = nn.Linear(emb_size, hidden_size) 376 | self.encoder = nn.Embedding(vocab_size, emb_size) 377 | self.rnn = __rec_units[rec_unit](emb_size, hidden_size, num_layers=self.nlayers, 378 | batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional) 379 | self.out = nn.Linear(self.num_directions * hidden_size, vocab_size) 380 | 381 | def forward(self, x, captions): 382 | # (bs x 17 x 17 x nef), (bs x nef) 383 | features, cnn_code = super().forward(x) 384 | # (bs x nef) 385 | cnn_hidden = self.hidden_linear(cnn_code) 386 | # (bs x hidden_size) 387 | 388 | # (num_layers * num_directions, batch, hidden_size) 389 | h_0 = cnn_hidden.unsqueeze(0).repeat(self.nlayers * self.num_directions, 1, 1) 390 | c_0 = torch.zeros(h_0.shape).to(h_0.device) 391 | 392 | # bs x T x vocab_size 393 | text_embeddings = self.encoder(captions) 394 | # bs x T x nef 395 | output, (hn, cn) = self.rnn(text_embeddings, (h_0, c_0)) 396 | # bs, T, hidden_size 397 | logits = self.out(output) 398 | # bs, T, vocab_size 399 | 400 | return features, cnn_code, logits 401 | 402 | class BERT_CNN_ENCODER_RNN_DECODER(CNN_ENCODER): 403 | def __init__(self, emb_size, hidden_size, vocab_size, nlayers=1, bidirectional=True, rec_unit='LSTM', dropout=0.5): 404 | """ 405 | Based on https://github.com/komiya-m/MirrorGAN/blob/master/model.py 406 | :param emb_size: size of word embeddings 407 | :param hidden_size: size of hidden state of the recurrent unit 408 | :param vocab_size: size of the vocabulary (output of the network) 409 | :param rec_unit: type of recurrent unit (default=gru) 410 | """ 411 | self.dropout = dropout 412 | self.nlayers = nlayers 413 | self.bidirectional = bidirectional 414 | self.num_directions = 2 if self.bidirectional else 1 415 | __rec_units = { 416 | 'GRU': nn.GRU, 417 | 'LSTM': nn.LSTM, 418 | } 419 | assert rec_unit in __rec_units, 'Specified recurrent unit is not available' 420 | 421 | super().__init__(emb_size) 422 | 423 | self.hidden_linear = nn.Linear(emb_size, hidden_size) 424 | self.encoder = BertModel.from_pretrained('bert-base-uncased') 
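        # Descriptive note (added): BERT is used as a frozen feature extractor (see the
        # loop below); the trainable parts of this module are the linear projections,
        # the recurrent decoder, and the CNN embedding heads inherited from CNN_ENCODER.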
425 | for param in self.encoder.parameters(): 426 | param.requires_grad = False 427 | 428 | self.bert_linear = nn.Linear(768, emb_size) 429 | self.rnn = __rec_units[rec_unit](emb_size, hidden_size, num_layers=self.nlayers, 430 | batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional) 431 | 432 | self.out = nn.Linear(self.num_directions * hidden_size, vocab_size) 433 | 434 | def forward(self, x, captions): 435 | # (bs x 17 x 17 x nef), (bs x nef) 436 | features, cnn_code = super().forward(x) 437 | # (bs x nef) 438 | cnn_hidden = self.hidden_linear(cnn_code) 439 | # (bs x hidden_size) 440 | 441 | # (num_layers * num_directions, batch, hidden_size) 442 | h_0 = cnn_hidden.unsqueeze(0).repeat(self.nlayers * self.num_directions, 1, 1) 443 | c_0 = torch.zeros(h_0.shape).to(h_0.device) 444 | 445 | # bs x T x vocab_size 446 | # get last layer of bert encoder 447 | text_embeddings, _ = self.encoder(captions, output_all_encoded_layers=False) 448 | # bs x T x 768 449 | text_embeddings = self.bert_linear(text_embeddings) 450 | # bs x T x emb_size 451 | output, (hn, cn) = self.rnn(text_embeddings, (h_0, c_0)) 452 | # bs, T, hidden_size 453 | logits = self.out(output) 454 | # bs, T, vocab_size 455 | 456 | return features, cnn_code, logits 457 | 458 | 459 | # ############## G networks ################### 460 | class CA_NET(nn.Module): 461 | # some code is modified from vae examples 462 | # (https://github.com/pytorch/examples/blob/master/vae/main.py) 463 | def __init__(self): 464 | super(CA_NET, self).__init__() 465 | self.t_dim = cfg.TEXT.EMBEDDING_DIM 466 | self.c_dim = cfg.GAN.CONDITION_DIM 467 | self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True) 468 | self.relu = GLU() 469 | 470 | def encode(self, text_embedding): 471 | x = self.relu(self.fc(text_embedding)) 472 | mu = x[:, :self.c_dim] 473 | logvar = x[:, self.c_dim:] 474 | return mu, logvar 475 | 476 | def reparametrize(self, mu, logvar): 477 | std = logvar.mul(0.5).exp_() 478 | if cfg.CUDA: 479 | eps = torch.cuda.FloatTensor(std.size()).normal_() 480 | else: 481 | eps = torch.FloatTensor(std.size()).normal_() 482 | eps = Variable(eps) 483 | return eps.mul(std).add_(mu) 484 | 485 | def forward(self, text_embedding): 486 | mu, logvar = self.encode(text_embedding) 487 | c_code = self.reparametrize(mu, logvar) 488 | return c_code, mu, logvar 489 | 490 | 491 | class INIT_STAGE_G(nn.Module): 492 | def __init__(self, ngf, ncf): 493 | super(INIT_STAGE_G, self).__init__() 494 | self.gf_dim = ngf 495 | self.in_dim = cfg.GAN.Z_DIM + ncf # cfg.TEXT.EMBEDDING_DIM 496 | 497 | self.define_module() 498 | 499 | def define_module(self): 500 | nz, ngf = self.in_dim, self.gf_dim 501 | self.fc = nn.Sequential( 502 | nn.Linear(nz, ngf * 4 * 4 * 2, bias=False), 503 | nn.BatchNorm1d(ngf * 4 * 4 * 2), 504 | GLU()) 505 | 506 | self.upsample1 = upBlock(ngf, ngf // 2) 507 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 508 | self.upsample3 = upBlock(ngf // 4, ngf // 8) 509 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 510 | 511 | def forward(self, z_code, c_code): 512 | """ 513 | :param z_code: batch x cfg.GAN.Z_DIM 514 | :param c_code: batch x cfg.TEXT.EMBEDDING_DIM 515 | :return: batch x ngf/16 x 64 x 64 516 | """ 517 | c_z_code = torch.cat((c_code, z_code), 1) 518 | # state size ngf x 4 x 4 519 | out_code = self.fc(c_z_code) 520 | out_code = out_code.view(-1, self.gf_dim, 4, 4) 521 | # state size ngf/3 x 8 x 8 522 | out_code = self.upsample1(out_code) 523 | # state size ngf/4 x 16 x 16 524 | out_code = self.upsample2(out_code) 525 | # state size 
ngf/8 x 32 x 32 526 | out_code32 = self.upsample3(out_code) 527 | # state size ngf/16 x 64 x 64 528 | out_code64 = self.upsample4(out_code32) 529 | 530 | return out_code64 531 | 532 | 533 | class NEXT_STAGE_G(nn.Module): 534 | def __init__(self, ngf, nef, ncf): 535 | super(NEXT_STAGE_G, self).__init__() 536 | self.gf_dim = ngf 537 | self.ef_dim = nef 538 | self.cf_dim = ncf 539 | self.num_residual = cfg.GAN.R_NUM 540 | self.define_module() 541 | 542 | def _make_layer(self, block, channel_num): 543 | layers = [] 544 | for i in range(cfg.GAN.R_NUM): 545 | layers.append(block(channel_num)) 546 | return nn.Sequential(*layers) 547 | 548 | def define_module(self): 549 | ngf = self.gf_dim 550 | self.att = ATT_NET(ngf, self.ef_dim) 551 | self.residual = self._make_layer(ResBlock, ngf * 2) 552 | self.upsample = upBlock(ngf * 2, ngf) 553 | 554 | def forward(self, h_code, c_code, word_embs, mask): 555 | """ 556 | h_code1(query): batch x idf x ih x iw (queryL=ihxiw) 557 | word_embs(context): batch x cdf x sourceL (sourceL=seq_len) 558 | c_code1: batch x idf x queryL 559 | att1: batch x sourceL x queryL 560 | """ 561 | self.att.applyMask(mask) 562 | c_code, att = self.att(h_code, word_embs) 563 | h_c_code = torch.cat((h_code, c_code), 1) 564 | out_code = self.residual(h_c_code) 565 | 566 | # state size ngf/2 x 2in_size x 2in_size 567 | out_code = self.upsample(out_code) 568 | 569 | return out_code, att 570 | 571 | # class SENTENCE_NEXT_STAGE_G(NEXT_STAGE_G): 572 | # def define_module(self): 573 | # ngf = self.gf_dim 574 | # self.w_att = ATT_NET(ngf, self.ef_dim) 575 | # self.s_att = ATT_NET(ngf, self.ef_dim) 576 | # self.residual = self._make_layer(ResBlock, ngf * 2) 577 | # self.upsample = upBlock(ngf * 2, ngf) 578 | # 579 | # def forward(self, h_code, c_code, sent_embs, word_embs, mask): 580 | # """ 581 | # h_code1(query): batch x idf x ih x iw (queryL=ihxiw) 582 | # word_embs(context): batch x cdf x sourceL (sourceL=seq_len) 583 | # c_code1: batch x idf x queryL 584 | # att1: batch x sourceL x queryL 585 | # """ 586 | # self.w_att.applyMask(mask) 587 | # c_code, w_att = self.w_att(h_code, word_embs) 588 | # sc_code, s_att = self.s_att(h_code, sent_embs) 589 | # 590 | # h_c_code = torch.cat((h_code, c_code, sc_code), 1) 591 | # out_code = self.residual(h_c_code) 592 | # 593 | # # state size ngf/2 x 2in_size x 2in_size 594 | # out_code = self.upsample(out_code) 595 | # 596 | # return out_code, s_att, w_att 597 | 598 | 599 | 600 | class GET_IMAGE_G(nn.Module): 601 | def __init__(self, ngf): 602 | super(GET_IMAGE_G, self).__init__() 603 | self.gf_dim = ngf 604 | self.img = nn.Sequential( 605 | conv3x3(ngf, 3), 606 | nn.Tanh() 607 | ) 608 | 609 | def forward(self, h_code): 610 | out_img = self.img(h_code) 611 | return out_img 612 | 613 | 614 | class G_NET(nn.Module): 615 | def __init__(self): 616 | super(G_NET, self).__init__() 617 | ngf = cfg.GAN.GF_DIM 618 | nef = cfg.TEXT.EMBEDDING_DIM 619 | ncf = cfg.GAN.CONDITION_DIM 620 | self.ca_net = CA_NET() 621 | 622 | if cfg.TREE.BRANCH_NUM > 0: 623 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 624 | self.img_net1 = GET_IMAGE_G(ngf) 625 | # gf x 64 x 64 626 | if cfg.TREE.BRANCH_NUM > 1: 627 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) 628 | self.img_net2 = GET_IMAGE_G(ngf) 629 | if cfg.TREE.BRANCH_NUM > 2: 630 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) 631 | self.img_net3 = GET_IMAGE_G(ngf) 632 | 633 | def forward(self, z_code, sent_emb, word_embs, mask): 634 | """ 635 | :param z_code: batch x cfg.GAN.Z_DIM 636 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 
637 | :param word_embs: batch x cdf x seq_len 638 | :param mask: batch x seq_len 639 | :return: 640 | """ 641 | fake_imgs = [] 642 | att_maps = [] 643 | c_code, mu, logvar = self.ca_net(sent_emb) 644 | 645 | if cfg.TREE.BRANCH_NUM > 0: 646 | h_code1 = self.h_net1(z_code, c_code) 647 | fake_img1 = self.img_net1(h_code1) 648 | fake_imgs.append(fake_img1) 649 | if cfg.TREE.BRANCH_NUM > 1: 650 | h_code2, att1 = \ 651 | self.h_net2(h_code1, c_code, word_embs, mask) 652 | fake_img2 = self.img_net2(h_code2) 653 | fake_imgs.append(fake_img2) 654 | if att1 is not None: 655 | att_maps.append(att1) 656 | if cfg.TREE.BRANCH_NUM > 2: 657 | h_code3, att2 = \ 658 | self.h_net3(h_code2, c_code, word_embs, mask) 659 | fake_img3 = self.img_net3(h_code3) 660 | fake_imgs.append(fake_img3) 661 | if att2 is not None: 662 | att_maps.append(att2) 663 | 664 | return fake_imgs, att_maps, mu, logvar 665 | 666 | # 667 | # class SENTENCE_G_NET(G_NET): 668 | # def __init__(self): 669 | # super(G_NET, self).__init__() 670 | # ngf = cfg.GAN.GF_DIM 671 | # nef = cfg.TEXT.EMBEDDING_DIM 672 | # ncf = cfg.GAN.CONDITION_DIM 673 | # self.ca_net = CA_NET() 674 | # 675 | # if cfg.TREE.BRANCH_NUM > 0: 676 | # self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 677 | # self.img_net1 = GET_IMAGE_G(ngf) 678 | # # gf x 64 x 64 679 | # if cfg.TREE.BRANCH_NUM > 1: 680 | # self.h_net2 = SENTENCE_NEXT_STAGE_G(ngf, nef, ncf) 681 | # self.img_net2 = GET_IMAGE_G(ngf) 682 | # if cfg.TREE.BRANCH_NUM > 2: 683 | # self.h_net3 = SENTENCE_NEXT_STAGE_G(ngf, nef, ncf) 684 | # self.img_net3 = GET_IMAGE_G(ngf) 685 | # 686 | # def forward(self, z_code, sent_emb, word_embs, mask): 687 | # """ 688 | # :param z_code: batch x cfg.GAN.Z_DIM 689 | # :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 690 | # :param word_embs: batch x cdf x seq_len 691 | # :param mask: batch x seq_len 692 | # :return: 693 | # """ 694 | # fake_imgs = [] 695 | # w_att_maps = [] 696 | # s_att_maps = [] 697 | # c_code, mu, logvar = self.ca_net(sent_emb) 698 | # 699 | # if cfg.TREE.BRANCH_NUM > 0: 700 | # h_code1 = self.h_net1(z_code, c_code) 701 | # fake_img1 = self.img_net1(h_code1) 702 | # fake_imgs.append(fake_img1) 703 | # if cfg.TREE.BRANCH_NUM > 1: 704 | # h_code2, s_att1, w_att1 = \ 705 | # self.h_net2(h_code1, c_code, sent_emb, word_embs, mask) 706 | # fake_img2 = self.img_net2(h_code2) 707 | # fake_imgs.append(fake_img2) 708 | # if w_att1 is not None: 709 | # w_att_maps.append(w_att1) 710 | # if s_att1 is not None: 711 | # s_att_maps.append(s_att1) 712 | # if cfg.TREE.BRANCH_NUM > 2: 713 | # h_code3, s_att2, w_att2 = \ 714 | # self.h_net3(h_code2, c_code, sent_emb, word_embs, mask) 715 | # fake_img3 = self.img_net3(h_code3) 716 | # fake_imgs.append(fake_img3) 717 | # if w_att2 is not None: 718 | # w_att_maps.append(w_att2) 719 | # if s_att2 is not None: 720 | # s_att_maps.append(s_att2) 721 | # 722 | # return fake_imgs, s_att_maps, w_att_maps, mu, logvar 723 | 724 | 725 | class G_DCGAN(nn.Module): 726 | def __init__(self): 727 | super(G_DCGAN, self).__init__() 728 | ngf = cfg.GAN.GF_DIM 729 | nef = cfg.TEXT.EMBEDDING_DIM 730 | ncf = cfg.GAN.CONDITION_DIM 731 | self.ca_net = CA_NET() 732 | 733 | # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64 734 | if cfg.TREE.BRANCH_NUM > 0: 735 | self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) 736 | # gf x 64 x 64 737 | if cfg.TREE.BRANCH_NUM > 1: 738 | self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) 739 | if cfg.TREE.BRANCH_NUM > 2: 740 | self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) 741 | self.img_net = GET_IMAGE_G(ngf) 742 | 743 | def forward(self, z_code, 
sent_emb, word_embs, mask): 744 | """ 745 | :param z_code: batch x cfg.GAN.Z_DIM 746 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 747 | :param word_embs: batch x cdf x seq_len 748 | :param mask: batch x seq_len 749 | :return: 750 | """ 751 | att_maps = [] 752 | c_code, mu, logvar = self.ca_net(sent_emb) 753 | if cfg.TREE.BRANCH_NUM > 0: 754 | h_code = self.h_net1(z_code, c_code) 755 | if cfg.TREE.BRANCH_NUM > 1: 756 | h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask) 757 | if att1 is not None: 758 | att_maps.append(att1) 759 | if cfg.TREE.BRANCH_NUM > 2: 760 | h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask) 761 | if att2 is not None: 762 | att_maps.append(att2) 763 | 764 | fake_imgs = self.img_net(h_code) 765 | return [fake_imgs], att_maps, mu, logvar 766 | 767 | 768 | # ############## D networks ########################## 769 | def Block3x3_leakRelu(in_planes, out_planes): 770 | block = nn.Sequential( 771 | conv3x3(in_planes, out_planes), 772 | nn.BatchNorm2d(out_planes), 773 | nn.LeakyReLU(0.2, inplace=True) 774 | ) 775 | return block 776 | 777 | 778 | # Downsale the spatial size by a factor of 2 779 | def downBlock(in_planes, out_planes): 780 | block = nn.Sequential( 781 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False), 782 | nn.BatchNorm2d(out_planes), 783 | nn.LeakyReLU(0.2, inplace=True) 784 | ) 785 | return block 786 | 787 | 788 | # Downsale the spatial size by a factor of 16 789 | def encode_image_by_16times(ndf): 790 | encode_img = nn.Sequential( 791 | # --> state size. ndf x in_size/2 x in_size/2 792 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 793 | nn.LeakyReLU(0.2, inplace=True), 794 | # --> state size 2ndf x x in_size/4 x in_size/4 795 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 796 | nn.BatchNorm2d(ndf * 2), 797 | nn.LeakyReLU(0.2, inplace=True), 798 | # --> state size 4ndf x in_size/8 x in_size/8 799 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 800 | nn.BatchNorm2d(ndf * 4), 801 | nn.LeakyReLU(0.2, inplace=True), 802 | # --> state size 8ndf x in_size/16 x in_size/16 803 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 804 | nn.BatchNorm2d(ndf * 8), 805 | nn.LeakyReLU(0.2, inplace=True) 806 | ) 807 | return encode_img 808 | 809 | 810 | class D_GET_LOGITS(nn.Module): 811 | def __init__(self, ndf, nef, bcondition=False): 812 | super(D_GET_LOGITS, self).__init__() 813 | self.df_dim = ndf 814 | self.ef_dim = nef 815 | self.bcondition = bcondition 816 | if self.bcondition: 817 | self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8) 818 | 819 | self.outlogits = nn.Sequential( 820 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 821 | nn.Sigmoid()) 822 | 823 | def forward(self, h_code, c_code=None): 824 | if self.bcondition and c_code is not None: 825 | # conditioning output 826 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 827 | c_code = c_code.repeat(1, 1, 4, 4) 828 | # state size (ngf+egf) x 4 x 4 829 | h_c_code = torch.cat((h_code, c_code), 1) 830 | # state size ngf x in_size x in_size 831 | h_c_code = self.jointConv(h_c_code) 832 | else: 833 | h_c_code = h_code 834 | 835 | output = self.outlogits(h_c_code) 836 | return output.view(-1) 837 | 838 | 839 | # For 64 x 64 images 840 | class D_NET64(nn.Module): 841 | def __init__(self, b_jcu=True): 842 | super(D_NET64, self).__init__() 843 | ndf = cfg.GAN.DF_DIM 844 | nef = cfg.TEXT.EMBEDDING_DIM 845 | self.img_code_s16 = encode_image_by_16times(ndf) 846 | if b_jcu: 847 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 848 | else: 849 | self.UNCOND_DNET = None 850 | 
self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 851 | 852 | def forward(self, x_var): 853 | x_code4 = self.img_code_s16(x_var) # 4 x 4 x 8df 854 | return x_code4 855 | 856 | 857 | # For 128 x 128 images 858 | class D_NET128(nn.Module): 859 | def __init__(self, b_jcu=True): 860 | super(D_NET128, self).__init__() 861 | ndf = cfg.GAN.DF_DIM 862 | nef = cfg.TEXT.EMBEDDING_DIM 863 | self.img_code_s16 = encode_image_by_16times(ndf) 864 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 865 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) 866 | # 867 | if b_jcu: 868 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 869 | else: 870 | self.UNCOND_DNET = None 871 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 872 | 873 | def forward(self, x_var): 874 | x_code8 = self.img_code_s16(x_var) # 8 x 8 x 8df 875 | x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df 876 | x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df 877 | return x_code4 878 | 879 | 880 | # For 256 x 256 images 881 | class D_NET256(nn.Module): 882 | def __init__(self, b_jcu=True): 883 | super(D_NET256, self).__init__() 884 | ndf = cfg.GAN.DF_DIM 885 | nef = cfg.TEXT.EMBEDDING_DIM 886 | self.img_code_s16 = encode_image_by_16times(ndf) 887 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 888 | self.img_code_s64 = downBlock(ndf * 16, ndf * 32) 889 | self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16) 890 | self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8) 891 | if b_jcu: 892 | self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) 893 | else: 894 | self.UNCOND_DNET = None 895 | self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) 896 | 897 | def forward(self, x_var): 898 | x_code16 = self.img_code_s16(x_var) 899 | x_code8 = self.img_code_s32(x_code16) 900 | x_code4 = self.img_code_s64(x_code8) 901 | x_code4 = self.img_code_s64_1(x_code4) 902 | x_code4 = self.img_code_s64_2(x_code4) 903 | return x_code4 904 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from six.moves import range 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | import torch.backends.cudnn as cudnn 9 | 10 | from PIL import Image 11 | 12 | from miscc.config import cfg 13 | from miscc.utils import mkdir_p 14 | from miscc.utils import build_super_images, build_super_images2 15 | from miscc.utils import weights_init, load_params, copy_G_params 16 | from model import G_DCGAN, G_NET 17 | from datasets import prepare_data 18 | from model import RNN_ENCODER, CNN_ENCODER, CNN_ENCODER_RNN_DECODER, \ 19 | BERT_CNN_ENCODER_RNN_DECODER, BERT_RNN_ENCODER 20 | 21 | from miscc.losses import words_loss, cycle_generator_loss 22 | from miscc.losses import discriminator_loss, generator_loss, KL_loss 23 | import os 24 | import time 25 | import numpy as np 26 | import sys 27 | 28 | # ################# Text to image task############################ # 29 | class condGANTrainer(object): 30 | def __init__(self, output_dir, data_loader, n_words, ixtoword): 31 | if cfg.TRAIN.FLAG: 32 | self.model_dir = os.path.join(output_dir, 'Model') 33 | self.image_dir = os.path.join(output_dir, 'Image') 34 | mkdir_p(self.model_dir) 35 | mkdir_p(self.image_dir) 36 | 37 | torch.cuda.set_device(cfg.GPU_ID) 38 | cudnn.benchmark = True 39 | 40 | self.batch_size = cfg.TRAIN.BATCH_SIZE 41 | 
self.max_epoch = cfg.TRAIN.MAX_EPOCH 42 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL 43 | 44 | self.n_words = n_words 45 | self.ixtoword = ixtoword 46 | self.data_loader = data_loader 47 | self.num_batches = len(self.data_loader) 48 | 49 | def build_models(self): 50 | # ###################encoders######################################## # 51 | if cfg.TRAIN.NET_E == '': 52 | print('Error: no pretrained text-image encoders') 53 | return 54 | 55 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) 56 | img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 57 | state_dict = \ 58 | torch.load(img_encoder_path, map_location=lambda storage, loc: storage) 59 | image_encoder.load_state_dict(state_dict) 60 | for p in image_encoder.parameters(): 61 | p.requires_grad = False 62 | print('Load image encoder from:', img_encoder_path) 63 | image_encoder.eval() 64 | 65 | text_encoder = \ 66 | RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 67 | state_dict = \ 68 | torch.load(cfg.TRAIN.NET_E, 69 | map_location=lambda storage, loc: storage) 70 | text_encoder.load_state_dict(state_dict) 71 | for p in text_encoder.parameters(): 72 | p.requires_grad = False 73 | print('Load text encoder from:', cfg.TRAIN.NET_E) 74 | text_encoder.eval() 75 | 76 | # #######################generator and discriminators############## # 77 | netsD = [] 78 | if cfg.GAN.B_DCGAN: 79 | if cfg.TREE.BRANCH_NUM ==1: 80 | from model import D_NET64 as D_NET 81 | elif cfg.TREE.BRANCH_NUM == 2: 82 | from model import D_NET128 as D_NET 83 | else: # cfg.TREE.BRANCH_NUM == 3: 84 | from model import D_NET256 as D_NET 85 | # TODO: elif cfg.TREE.BRANCH_NUM > 3: 86 | netG = G_DCGAN() 87 | netsD = [D_NET(b_jcu=False)] 88 | else: 89 | from model import D_NET64, D_NET128, D_NET256 90 | netG = G_NET() 91 | if cfg.TREE.BRANCH_NUM > 0: 92 | netsD.append(D_NET64()) 93 | if cfg.TREE.BRANCH_NUM > 1: 94 | netsD.append(D_NET128()) 95 | if cfg.TREE.BRANCH_NUM > 2: 96 | netsD.append(D_NET256()) 97 | # TODO: if cfg.TREE.BRANCH_NUM > 3: 98 | netG.apply(weights_init) 99 | # print(netG) 100 | for i in range(len(netsD)): 101 | netsD[i].apply(weights_init) 102 | # print(netsD[i]) 103 | print('# of netsD', len(netsD)) 104 | # 105 | epoch = 0 106 | if cfg.TRAIN.NET_G != '': 107 | state_dict = \ 108 | torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) 109 | netG.load_state_dict(state_dict) 110 | print('Load G from: ', cfg.TRAIN.NET_G) 111 | istart = cfg.TRAIN.NET_G.rfind('_') + 1 112 | iend = cfg.TRAIN.NET_G.rfind('.') 113 | epoch = cfg.TRAIN.NET_G[istart:iend] 114 | epoch = int(epoch) + 1 115 | if cfg.TRAIN.B_NET_D: 116 | Gname = cfg.TRAIN.NET_G 117 | for i in range(len(netsD)): 118 | s_tmp = Gname[:Gname.rfind('/')] 119 | Dname = '%s/netD%d.pth' % (s_tmp, i) 120 | print('Load D from: ', Dname) 121 | state_dict = \ 122 | torch.load(Dname, map_location=lambda storage, loc: storage) 123 | netsD[i].load_state_dict(state_dict) 124 | # ########################################################### # 125 | if cfg.CUDA: 126 | text_encoder = text_encoder.cuda() 127 | image_encoder = image_encoder.cuda() 128 | netG.cuda() 129 | for i in range(len(netsD)): 130 | netsD[i].cuda() 131 | return [text_encoder, image_encoder, netG, netsD, epoch] 132 | 133 | def define_optimizers(self, netG, netsD): 134 | optimizersD = [] 135 | num_Ds = len(netsD) 136 | for i in range(num_Ds): 137 | opt = optim.Adam(netsD[i].parameters(), 138 | lr=cfg.TRAIN.DISCRIMINATOR_LR, 139 | betas=(0.5, 0.999)) 140 | optimizersD.append(opt) 141 | 142 | 
optimizerG = optim.Adam(netG.parameters(), 143 | lr=cfg.TRAIN.GENERATOR_LR, 144 | betas=(0.5, 0.999)) 145 | 146 | return optimizerG, optimizersD 147 | 148 | def prepare_labels(self): 149 | batch_size = self.batch_size 150 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) 151 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) 152 | match_labels = Variable(torch.LongTensor(range(batch_size))) 153 | if cfg.CUDA: 154 | real_labels = real_labels.cuda() 155 | fake_labels = fake_labels.cuda() 156 | match_labels = match_labels.cuda() 157 | 158 | return real_labels, fake_labels, match_labels 159 | 160 | def save_model(self, netG, avg_param_G, netsD, epoch): 161 | backup_para = copy_G_params(netG) 162 | load_params(netG, avg_param_G) 163 | torch.save(netG.state_dict(), 164 | '%s/netG_epoch_%d.pth' % (self.model_dir, epoch)) 165 | load_params(netG, backup_para) 166 | # 167 | for i in range(len(netsD)): 168 | netD = netsD[i] 169 | torch.save(netD.state_dict(), 170 | '%s/netD%d.pth' % (self.model_dir, i)) 171 | print('Save G/Ds models.') 172 | 173 | def set_requires_grad_value(self, models_list, brequires): 174 | for i in range(len(models_list)): 175 | for p in models_list[i].parameters(): 176 | p.requires_grad = brequires 177 | 178 | def save_img_results(self, netG, noise, sent_emb, words_embs, mask, 179 | image_encoder, captions, cap_lens, 180 | gen_iterations, name='current'): 181 | # Save images 182 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 183 | for i in range(len(attention_maps)): 184 | if len(fake_imgs) > 1: 185 | img = fake_imgs[i + 1].detach().cpu() 186 | lr_img = fake_imgs[i].detach().cpu() 187 | else: 188 | img = fake_imgs[0].detach().cpu() 189 | lr_img = None 190 | attn_maps = attention_maps[i] 191 | att_sze = attn_maps.size(2) 192 | img_set, _ = \ 193 | build_super_images(img, captions, self.ixtoword, 194 | attn_maps, att_sze, lr_imgs=lr_img) 195 | if img_set is not None: 196 | im = Image.fromarray(img_set) 197 | fullpath = '%s/G_%s_%d_%d.png'\ 198 | % (self.image_dir, name, gen_iterations, i) 199 | im.save(fullpath) 200 | 201 | # for i in range(len(netsD)): 202 | i = -1 203 | img = fake_imgs[i].detach() 204 | region_features, _ = image_encoder(img) 205 | att_sze = region_features.size(2) 206 | _, _, att_maps = words_loss(region_features.detach(), 207 | words_embs.detach(), 208 | None, cap_lens, 209 | None, self.batch_size) 210 | img_set, _ = \ 211 | build_super_images(fake_imgs[i].detach().cpu(), 212 | captions, self.ixtoword, att_maps, att_sze) 213 | if img_set is not None: 214 | im = Image.fromarray(img_set) 215 | fullpath = '%s/D_%s_%d.png'\ 216 | % (self.image_dir, name, gen_iterations) 217 | im.save(fullpath) 218 | 219 | def train(self): 220 | text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models() 221 | avg_param_G = copy_G_params(netG) 222 | optimizerG, optimizersD = self.define_optimizers(netG, netsD) 223 | real_labels, fake_labels, match_labels = self.prepare_labels() 224 | 225 | batch_size = self.batch_size 226 | nz = cfg.GAN.Z_DIM 227 | noise = Variable(torch.FloatTensor(batch_size, nz)) 228 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1)) 229 | if cfg.CUDA: 230 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 231 | 232 | gen_iterations = 0 233 | # gen_iterations = start_epoch * self.num_batches 234 | for epoch in range(start_epoch, self.max_epoch): 235 | start_t = time.time() 236 | 237 | data_iter = iter(self.data_loader) 238 | step = 0 239 | while step < 
self.num_batches:
240 |                 # reset requires_grad to be trainable for all Ds
241 |                 # self.set_requires_grad_value(netsD, True)
242 | 
243 |                 ######################################################
244 |                 # (1) Prepare training data and Compute text embeddings
245 |                 ######################################################
246 |                 data = next(data_iter)
247 |                 imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
248 | 
249 |                 hidden = text_encoder.init_hidden(batch_size)
250 |                 # words_embs: batch_size x nef x seq_len
251 |                 # sent_emb: batch_size x nef
252 |                 words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
253 |                 words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
254 |                 mask = (captions == 0)
255 |                 num_words = words_embs.size(2)
256 |                 if mask.size(1) > num_words:
257 |                     mask = mask[:, :num_words]
258 | 
259 |                 #######################################################
260 |                 # (2) Generate fake images
261 |                 ######################################################
262 |                 noise.data.normal_(0, 1)
263 |                 fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)
264 | 
265 |                 #######################################################
266 |                 # (3) Update D network
267 |                 ######################################################
268 |                 errD_total = 0
269 |                 D_logs = ''
270 |                 for i in range(len(netsD)):
271 |                     netsD[i].zero_grad()
272 |                     errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
273 |                                               sent_emb, real_labels, fake_labels)
274 |                     # backward and update parameters
275 |                     errD.backward()
276 |                     optimizersD[i].step()
277 |                     errD_total += errD
278 |                     D_logs += 'errD%d: %.2f ' % (i, errD.item())
279 | 
280 |                 #######################################################
281 |                 # (4) Update G network: maximize log(D(G(z)))
282 |                 ######################################################
283 |                 # compute total loss for training G
284 |                 step += 1
285 |                 gen_iterations += 1
286 | 
287 |                 # do not need to compute gradient for Ds
288 |                 # self.set_requires_grad_value(netsD, False)
289 |                 netG.zero_grad()
290 |                 errG_total, G_logs = \
291 |                     generator_loss(netsD, image_encoder, fake_imgs, real_labels,
292 |                                    words_embs, sent_emb, match_labels, cap_lens, class_ids)
293 |                 kl_loss = KL_loss(mu, logvar)
294 |                 errG_total += kl_loss
295 |                 G_logs += 'kl_loss: %.2f ' % kl_loss.item()
296 |                 # backward and update parameters
297 |                 errG_total.backward()
298 |                 optimizerG.step()
299 |                 for p, avg_p in zip(netG.parameters(), avg_param_G):  # moving average of G's weights
300 |                     avg_p.mul_(0.999).add_(0.001, p.data)
301 | 
302 |                 if gen_iterations % 100 == 0:
303 |                     print(D_logs + '\n' + G_logs)
304 |                 # save images
305 |                 if gen_iterations % 1000 == 0:
306 |                     backup_para = copy_G_params(netG)
307 |                     load_params(netG, avg_param_G)
308 |                     self.save_img_results(netG, fixed_noise, sent_emb,
309 |                                           words_embs, mask, image_encoder,
310 |                                           captions, cap_lens, epoch, name='average')
311 |                     load_params(netG, backup_para)
312 |                     #
313 |                     # self.save_img_results(netG, fixed_noise, sent_emb,
314 |                     #                       words_embs, mask, image_encoder,
315 |                     #                       captions, cap_lens,
316 |                     #                       epoch, name='current')
317 |             end_t = time.time()
318 | 
319 |             print('''[%d/%d][%d]
320 |                   Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
321 |                   % (epoch, self.max_epoch, self.num_batches,
322 |                      errD_total.item(), errG_total.item(),
323 |                      end_t - start_t))
324 | 
325 |             if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
326 |                 self.save_model(netG, avg_param_G, netsD, epoch)
327 | 
328 |         self.save_model(netG, avg_param_G, netsD, self.max_epoch)
329 | 
330 |     def save_singleimages(self, images, filenames, save_dir,
331 |                           split_dir, 
sentenceID=0): 332 | for i in range(images.size(0)): 333 | s_tmp = '%s/single_samples/%s/%s' %\ 334 | (save_dir, split_dir, filenames[i]) 335 | folder = s_tmp[:s_tmp.rfind('/')] 336 | if not os.path.isdir(folder): 337 | print('Make a new folder: ', folder) 338 | mkdir_p(folder) 339 | 340 | fullpath = '%s_%d.jpg' % (s_tmp, sentenceID) 341 | # range from [-1, 1] to [0, 1] 342 | # img = (images[i] + 1.0) / 2 343 | img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte() 344 | # range from [0, 1] to [0, 255] 345 | ndarr = img.permute(1, 2, 0).data.cpu().numpy() 346 | im = Image.fromarray(ndarr) 347 | im.save(fullpath) 348 | 349 | def sampling(self, split_dir): 350 | if cfg.TRAIN.NET_G == '': 351 | print('Error: the path for morels is not found!') 352 | else: 353 | if split_dir == 'test': 354 | split_dir = 'valid' 355 | # Build and load the generator 356 | if cfg.GAN.B_DCGAN: 357 | netG = G_DCGAN() 358 | else: 359 | netG = G_NET() 360 | netG.apply(weights_init) 361 | netG.cuda() 362 | netG.eval() 363 | # 364 | text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 365 | state_dict = \ 366 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 367 | text_encoder.load_state_dict(state_dict) 368 | print('Load text encoder from:', cfg.TRAIN.NET_E) 369 | text_encoder = text_encoder.cuda() 370 | text_encoder.eval() 371 | 372 | batch_size = self.batch_size 373 | nz = cfg.GAN.Z_DIM 374 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 375 | noise = noise.cuda() 376 | 377 | model_dir = cfg.TRAIN.NET_G 378 | state_dict = \ 379 | torch.load(model_dir, map_location=lambda storage, loc: storage) 380 | # state_dict = torch.load(cfg.TRAIN.NET_G) 381 | netG.load_state_dict(state_dict) 382 | print('Load G from: ', model_dir) 383 | 384 | # the path to save generated images 385 | s_tmp = model_dir[:model_dir.rfind('.pth')] 386 | save_dir = '%s/%s' % (s_tmp, split_dir) 387 | mkdir_p(save_dir) 388 | 389 | cnt = 0 390 | 391 | for _ in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE): 392 | for step, data in enumerate(self.data_loader, 0): 393 | cnt += batch_size 394 | if step % 100 == 0: 395 | print('step: ', step) 396 | # if step > 50: 397 | # break 398 | 399 | imgs, captions, cap_lens, class_ids, keys = prepare_data(data) 400 | 401 | hidden = text_encoder.init_hidden(batch_size) 402 | # words_embs: batch_size x nef x seq_len 403 | # sent_emb: batch_size x nef 404 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 405 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach() 406 | mask = (captions == 0) 407 | num_words = words_embs.size(2) 408 | if mask.size(1) > num_words: 409 | mask = mask[:, :num_words] 410 | 411 | ####################################################### 412 | # (2) Generate fake images 413 | ###################################################### 414 | noise.data.normal_(0, 1) 415 | fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask) 416 | for j in range(batch_size): 417 | s_tmp = '%s/single/%s' % (save_dir, keys[j]) 418 | folder = s_tmp[:s_tmp.rfind('/')] 419 | if not os.path.isdir(folder): 420 | print('Make a new folder: ', folder) 421 | mkdir_p(folder) 422 | k = -1 423 | # for k in range(len(fake_imgs)): 424 | im = fake_imgs[k][j].data.cpu().numpy() 425 | # [-1, 1] --> [0, 255] 426 | im = (im + 1.0) * 127.5 427 | im = im.astype(np.uint8) 428 | im = np.transpose(im, (1, 2, 0)) 429 | im = Image.fromarray(im) 430 | fullpath = '%s_s%d.png' % (s_tmp, k) 431 | im.save(fullpath) 432 | 433 | def gen_example(self, 
data_dic): 434 | if cfg.TRAIN.NET_G == '': 435 | print('Error: the path for morels is not found!') 436 | else: 437 | # Build and load the generator 438 | text_encoder = \ 439 | RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 440 | state_dict = \ 441 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 442 | text_encoder.load_state_dict(state_dict) 443 | print('Load text encoder from:', cfg.TRAIN.NET_E) 444 | text_encoder = text_encoder.cuda() 445 | text_encoder.eval() 446 | 447 | # the path to save generated images 448 | if cfg.GAN.B_DCGAN: 449 | netG = G_DCGAN() 450 | else: 451 | netG = G_NET() 452 | s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')] 453 | model_dir = cfg.TRAIN.NET_G 454 | state_dict = \ 455 | torch.load(model_dir, map_location=lambda storage, loc: storage) 456 | netG.load_state_dict(state_dict) 457 | print('Load G from: ', model_dir) 458 | netG.cuda() 459 | netG.eval() 460 | for key in data_dic: 461 | save_dir = '%s/%s' % (s_tmp, key) 462 | mkdir_p(save_dir) 463 | captions, cap_lens, sorted_indices = data_dic[key] 464 | 465 | batch_size = captions.shape[0] 466 | nz = cfg.GAN.Z_DIM 467 | captions = Variable(torch.from_numpy(captions), volatile=True) 468 | cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) 469 | 470 | captions = captions.cuda() 471 | cap_lens = cap_lens.cuda() 472 | for i in range(1): # 16 473 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 474 | noise = noise.cuda() 475 | ####################################################### 476 | # (1) Extract text embeddings 477 | ###################################################### 478 | hidden = text_encoder.init_hidden(batch_size) 479 | # words_embs: batch_size x nef x seq_len 480 | # sent_emb: batch_size x nef 481 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 482 | mask = (captions == 0) 483 | ####################################################### 484 | # (2) Generate fake images 485 | ###################################################### 486 | noise.data.normal_(0, 1) 487 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 488 | # G attention 489 | cap_lens_np = cap_lens.cpu().data.numpy() 490 | for j in range(batch_size): 491 | save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j]) 492 | for k in range(len(fake_imgs)): 493 | im = fake_imgs[k][j].data.cpu().numpy() 494 | im = (im + 1.0) * 127.5 495 | im = im.astype(np.uint8) 496 | # print('im', im.shape) 497 | im = np.transpose(im, (1, 2, 0)) 498 | # print('im', im.shape) 499 | im = Image.fromarray(im) 500 | fullpath = '%s_g%d.png' % (save_name, k) 501 | im.save(fullpath) 502 | 503 | for k in range(len(attention_maps)): 504 | if len(fake_imgs) > 1: 505 | im = fake_imgs[k + 1].detach().cpu() 506 | else: 507 | im = fake_imgs[0].detach().cpu() 508 | attn_maps = attention_maps[k] 509 | att_sze = attn_maps.size(2) 510 | img_set, sentences = \ 511 | build_super_images2(im[j].unsqueeze(0), 512 | captions[j].unsqueeze(0), 513 | [cap_lens_np[j]], self.ixtoword, 514 | [attn_maps[j]], att_sze) 515 | if img_set is not None: 516 | im = Image.fromarray(img_set) 517 | fullpath = '%s_a%d.png' % (save_name, k) 518 | im.save(fullpath) 519 | 520 | 521 | ############# CYCLE GAN ########## 522 | class CycleGANTrainer(condGANTrainer): 523 | def build_models(self): 524 | # ###################encoders######################################## # 525 | if cfg.TRAIN.NET_E == '': 526 | print('Error: no pretrained text-image encoders') 527 | return 528 | image_encoder = 
BERT_CNN_ENCODER_RNN_DECODER(cfg.TEXT.EMBEDDING_DIM, cfg.CNN_RNN.HIDDEN_DIM, 529 | self.n_words, rec_unit=cfg.RNN_TYPE) 530 | img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 531 | state_dict = \ 532 | torch.load(img_encoder_path, map_location=lambda storage, loc: storage) 533 | image_encoder.load_state_dict(state_dict) 534 | for p in image_encoder.parameters(): 535 | p.requires_grad = False 536 | print('Load image encoder from:', img_encoder_path) 537 | # image_encoder.eval() 538 | 539 | text_encoder = \ 540 | BERT_RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 541 | state_dict = \ 542 | torch.load(cfg.TRAIN.NET_E, 543 | map_location=lambda storage, loc: storage) 544 | text_encoder.load_state_dict(state_dict) 545 | for p in text_encoder.parameters(): 546 | p.requires_grad = False 547 | print('Load text encoder from:', cfg.TRAIN.NET_E) 548 | text_encoder.eval() 549 | 550 | # #######################generator and discriminators############## # 551 | netsD = [] 552 | if cfg.GAN.B_DCGAN: 553 | if cfg.TREE.BRANCH_NUM ==1: 554 | from model import D_NET64 as D_NET 555 | elif cfg.TREE.BRANCH_NUM == 2: 556 | from model import D_NET128 as D_NET 557 | else: # cfg.TREE.BRANCH_NUM == 3: 558 | from model import D_NET256 as D_NET 559 | # TODO: elif cfg.TREE.BRANCH_NUM > 3: 560 | netG = G_DCGAN() 561 | netsD = [D_NET(b_jcu=False)] 562 | else: 563 | from model import D_NET64, D_NET128, D_NET256 564 | netG = G_NET() 565 | if cfg.TREE.BRANCH_NUM > 0: 566 | netsD.append(D_NET64()) 567 | if cfg.TREE.BRANCH_NUM > 1: 568 | netsD.append(D_NET128()) 569 | if cfg.TREE.BRANCH_NUM > 2: 570 | netsD.append(D_NET256()) 571 | # TODO: if cfg.TREE.BRANCH_NUM > 3: 572 | netG.apply(weights_init) 573 | # print(netG) 574 | for i in range(len(netsD)): 575 | netsD[i].apply(weights_init) 576 | # print(netsD[i]) 577 | print('# of netsD', len(netsD)) 578 | # 579 | epoch = 0 580 | if cfg.TRAIN.NET_G != '': 581 | state_dict = \ 582 | torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) 583 | netG.load_state_dict(state_dict) 584 | print('Load G from: ', cfg.TRAIN.NET_G) 585 | istart = cfg.TRAIN.NET_G.rfind('_') + 1 586 | iend = cfg.TRAIN.NET_G.rfind('.') 587 | epoch = cfg.TRAIN.NET_G[istart:iend] 588 | epoch = int(epoch) + 1 589 | if cfg.TRAIN.B_NET_D: 590 | Gname = cfg.TRAIN.NET_G 591 | for i in range(len(netsD)): 592 | s_tmp = Gname[:Gname.rfind('/')] 593 | Dname = '%s/netD%d.pth' % (s_tmp, i) 594 | print('Load D from: ', Dname) 595 | state_dict = \ 596 | torch.load(Dname, map_location=lambda storage, loc: storage) 597 | netsD[i].load_state_dict(state_dict) 598 | # ########################################################### # 599 | if cfg.CUDA: 600 | text_encoder = text_encoder.cuda() 601 | image_encoder = image_encoder.cuda() 602 | netG.cuda() 603 | for i in range(len(netsD)): 604 | netsD[i].cuda() 605 | return [text_encoder, image_encoder, netG, netsD, epoch] 606 | 607 | def save_img_results(self, netG, noise, sent_emb, words_embs, mask, 608 | image_encoder, captions, cap_lens, 609 | gen_iterations, name='current'): 610 | # Save images 611 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 612 | for i in range(len(attention_maps)): 613 | if len(fake_imgs) > 1: 614 | img = fake_imgs[i + 1].detach().cpu() 615 | lr_img = fake_imgs[i].detach().cpu() 616 | else: 617 | img = fake_imgs[0].detach().cpu() 618 | lr_img = None 619 | attn_maps = attention_maps[i] 620 | att_sze = attn_maps.size(2) 621 | img_set, _ = \ 622 | build_super_images(img, captions, 
self.ixtoword,
623 |                                    attn_maps, att_sze, lr_imgs=lr_img)
624 |             if img_set is not None:
625 |                 im = Image.fromarray(img_set)
626 |                 fullpath = '%s/G_%s_%d_%d.png'\
627 |                     % (self.image_dir, name, gen_iterations, i)
628 |                 im.save(fullpath)
629 | 
630 |         # for i in range(len(netsD)):
631 |         i = -1
632 |         img = fake_imgs[i].detach()
633 |         region_features, _, _ = image_encoder(img, captions)
634 |         att_sze = region_features.size(2)
635 |         _, _, att_maps = words_loss(region_features.detach(),
636 |                                     words_embs.detach(),
637 |                                     None, cap_lens,
638 |                                     None, self.batch_size)
639 |         img_set, _ = \
640 |             build_super_images(fake_imgs[i].detach().cpu(),
641 |                                captions, self.ixtoword, att_maps, att_sze)
642 |         if img_set is not None:
643 |             im = Image.fromarray(img_set)
644 |             fullpath = '%s/D_%s_%d.png'\
645 |                 % (self.image_dir, name, gen_iterations)
646 |             im.save(fullpath)
647 | 
648 |     def train(self):
649 |         text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
650 |         avg_param_G = copy_G_params(netG)
651 |         optimizerG, optimizersD = self.define_optimizers(netG, netsD)
652 |         real_labels, fake_labels, match_labels = self.prepare_labels()
653 | 
654 |         batch_size = self.batch_size
655 |         nz = cfg.GAN.Z_DIM
656 |         noise = Variable(torch.FloatTensor(batch_size, nz))
657 |         fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
658 |         if cfg.CUDA:
659 |             noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
660 | 
661 |         gen_iterations = 0
662 |         # gen_iterations = start_epoch * self.num_batches
663 |         for epoch in range(start_epoch, self.max_epoch):
664 |             start_t = time.time()
665 | 
666 |             data_iter = iter(self.data_loader)
667 |             step = 0
668 |             while step < self.num_batches:
669 |                 # reset requires_grad to be trainable for all Ds
670 |                 # self.set_requires_grad_value(netsD, True)
671 | 
672 |                 ######################################################
673 |                 # (1) Prepare training data and Compute text embeddings
674 |                 ######################################################
675 |                 data = next(data_iter)
676 |                 imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
677 | 
678 |                 hidden = text_encoder.init_hidden(batch_size)
679 |                 # words_embs: batch_size x nef x seq_len
680 |                 # sent_emb: batch_size x nef
681 |                 words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
682 |                 words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
683 |                 mask = (captions == 0)
684 |                 num_words = words_embs.size(2)
685 |                 if mask.size(1) > num_words:
686 |                     mask = mask[:, :num_words]
687 | 
688 |                 #######################################################
689 |                 # (2) Generate fake images
690 |                 ######################################################
691 |                 noise.data.normal_(0, 1)
692 |                 fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)
693 | 
694 |                 #######################################################
695 |                 # (3) Update D network
696 |                 ######################################################
697 |                 errD_total = 0
698 |                 D_logs = ''
699 |                 for i in range(len(netsD)):
700 |                     netsD[i].zero_grad()
701 |                     errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
702 |                                               sent_emb, real_labels, fake_labels)
703 |                     # backward and update parameters
704 |                     errD.backward()
705 |                     optimizersD[i].step()
706 |                     errD_total += errD
707 |                     D_logs += 'errD%d: %.2f ' % (i, errD.item())
708 | 
709 |                 #######################################################
710 |                 # (4) Update G network: maximize log(D(G(z)))
711 |                 ######################################################
712 |                 # compute total loss for training G
713 |                 step += 1
714 | 
gen_iterations += 1 715 | 716 | # do not need to compute gradient for Ds 717 | # self.set_requires_grad_value(netsD, False) 718 | netG.zero_grad() 719 | errG_total, G_logs = \ 720 | cycle_generator_loss(netsD, image_encoder, fake_imgs, real_labels, captions, 721 | words_embs, sent_emb, match_labels, cap_lens, class_ids) 722 | kl_loss = KL_loss(mu, logvar) 723 | errG_total += kl_loss 724 | G_logs += 'kl_loss: %.2f ' % kl_loss.item() 725 | # backward and update parameters 726 | errG_total.backward() 727 | optimizerG.step() 728 | for p, avg_p in zip(netG.parameters(), avg_param_G): 729 | avg_p.mul_(0.999).add_(0.001, p.data) 730 | 731 | if gen_iterations % 100 == 0: 732 | print(D_logs + '\n' + G_logs) 733 | # save images 734 | if gen_iterations % 1000 == 0: 735 | backup_para = copy_G_params(netG) 736 | load_params(netG, avg_param_G) 737 | self.save_img_results(netG, fixed_noise, sent_emb, 738 | words_embs, mask, image_encoder, 739 | captions, cap_lens, epoch, name='average') 740 | load_params(netG, backup_para) 741 | # 742 | # self.save_img_results(netG, fixed_noise, sent_emb, 743 | # words_embs, mask, image_encoder, 744 | # captions, cap_lens, 745 | # epoch, name='current') 746 | end_t = time.time() 747 | 748 | print('''[%d/%d][%d] 749 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' 750 | % (epoch, self.max_epoch, self.num_batches, 751 | errD_total.item(), errG_total.item(), 752 | end_t - start_t)) 753 | 754 | if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: # and epoch != 0: 755 | self.save_model(netG, avg_param_G, netsD, epoch) 756 | 757 | self.save_model(netG, avg_param_G, netsD, self.max_epoch) 758 | 759 | def save_singleimages(self, images, filenames, save_dir, 760 | split_dir, sentenceID=0): 761 | for i in range(images.size(0)): 762 | s_tmp = '%s/single_samples/%s/%s' %\ 763 | (save_dir, split_dir, filenames[i]) 764 | folder = s_tmp[:s_tmp.rfind('/')] 765 | if not os.path.isdir(folder): 766 | print('Make a new folder: ', folder) 767 | mkdir_p(folder) 768 | 769 | fullpath = '%s_%d.jpg' % (s_tmp, sentenceID) 770 | # range from [-1, 1] to [0, 1] 771 | # img = (images[i] + 1.0) / 2 772 | img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte() 773 | # range from [0, 1] to [0, 255] 774 | ndarr = img.permute(1, 2, 0).data.cpu().numpy() 775 | im = Image.fromarray(ndarr) 776 | im.save(fullpath) 777 | 778 | def sampling(self, split_dir): 779 | if cfg.TRAIN.NET_G == '': 780 | print('Error: the path for morels is not found!') 781 | else: 782 | if split_dir == 'test': 783 | split_dir = 'valid' 784 | # Build and load the generator 785 | if cfg.GAN.B_DCGAN: 786 | netG = G_DCGAN() 787 | else: 788 | netG = G_NET() 789 | netG.apply(weights_init) 790 | netG.cuda() 791 | netG.eval() 792 | # 793 | text_encoder = BERT_RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 794 | state_dict = \ 795 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 796 | text_encoder.load_state_dict(state_dict) 797 | print('Load text encoder from:', cfg.TRAIN.NET_E) 798 | text_encoder = text_encoder.cuda() 799 | text_encoder.eval() 800 | 801 | batch_size = self.batch_size 802 | nz = cfg.GAN.Z_DIM 803 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 804 | noise = noise.cuda() 805 | 806 | model_dir = cfg.TRAIN.NET_G 807 | state_dict = \ 808 | torch.load(model_dir, map_location=lambda storage, loc: storage) 809 | # state_dict = torch.load(cfg.TRAIN.NET_G) 810 | netG.load_state_dict(state_dict) 811 | print('Load G from: ', model_dir) 812 | 813 | # the path to save generated images 814 | 
s_tmp = model_dir[:model_dir.rfind('.pth')] 815 | save_dir = '%s/%s' % (s_tmp, split_dir) 816 | mkdir_p(save_dir) 817 | 818 | cnt = 0 819 | 820 | for _ in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE): 821 | for step, data in enumerate(self.data_loader, 0): 822 | cnt += batch_size 823 | if step % 100 == 0: 824 | print('step: ', step) 825 | # if step > 50: 826 | # break 827 | 828 | imgs, captions, cap_lens, class_ids, keys = prepare_data(data) 829 | 830 | hidden = text_encoder.init_hidden(batch_size) 831 | # words_embs: batch_size x nef x seq_len 832 | # sent_emb: batch_size x nef 833 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 834 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach() 835 | mask = (captions == 0) 836 | num_words = words_embs.size(2) 837 | if mask.size(1) > num_words: 838 | mask = mask[:, :num_words] 839 | 840 | ####################################################### 841 | # (2) Generate fake images 842 | ###################################################### 843 | noise.data.normal_(0, 1) 844 | fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask) 845 | for j in range(batch_size): 846 | s_tmp = '%s/single/%s' % (save_dir, keys[j]) 847 | folder = s_tmp[:s_tmp.rfind('/')] 848 | if not os.path.isdir(folder): 849 | print('Make a new folder: ', folder) 850 | mkdir_p(folder) 851 | k = -1 852 | # for k in range(len(fake_imgs)): 853 | im = fake_imgs[k][j].data.cpu().numpy() 854 | # [-1, 1] --> [0, 255] 855 | im = (im + 1.0) * 127.5 856 | im = im.astype(np.uint8) 857 | im = np.transpose(im, (1, 2, 0)) 858 | im = Image.fromarray(im) 859 | fullpath = '%s_s%d.png' % (s_tmp, k) 860 | im.save(fullpath) 861 | 862 | def gen_example(self, data_dic): 863 | if cfg.TRAIN.NET_G == '': 864 | print('Error: the path for morels is not found!') 865 | else: 866 | # Build and load the generator 867 | text_encoder = \ 868 | BERT_RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 869 | state_dict = \ 870 | torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 871 | text_encoder.load_state_dict(state_dict) 872 | print('Load text encoder from:', cfg.TRAIN.NET_E) 873 | text_encoder = text_encoder.cuda() 874 | text_encoder.eval() 875 | 876 | # the path to save generated images 877 | if cfg.GAN.B_DCGAN: 878 | netG = G_DCGAN() 879 | else: 880 | netG = G_NET() 881 | s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')] 882 | model_dir = cfg.TRAIN.NET_G 883 | state_dict = \ 884 | torch.load(model_dir, map_location=lambda storage, loc: storage) 885 | netG.load_state_dict(state_dict) 886 | print('Load G from: ', model_dir) 887 | netG.cuda() 888 | netG.eval() 889 | for key in data_dic: 890 | save_dir = '%s/%s' % (s_tmp, key) 891 | mkdir_p(save_dir) 892 | captions, cap_lens, sorted_indices = data_dic[key] 893 | 894 | batch_size = captions.shape[0] 895 | nz = cfg.GAN.Z_DIM 896 | captions = Variable(torch.from_numpy(captions), volatile=True) 897 | cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) 898 | 899 | captions = captions.cuda() 900 | cap_lens = cap_lens.cuda() 901 | for i in range(1): # 16 902 | noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) 903 | noise = noise.cuda() 904 | ####################################################### 905 | # (1) Extract text embeddings 906 | ###################################################### 907 | hidden = text_encoder.init_hidden(batch_size) 908 | # words_embs: batch_size x nef x seq_len 909 | # sent_emb: batch_size x nef 910 | words_embs, sent_emb = text_encoder(captions, cap_lens, 
hidden) 911 | mask = (captions == 0) 912 | ####################################################### 913 | # (2) Generate fake images 914 | ###################################################### 915 | noise.data.normal_(0, 1) 916 | fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) 917 | # G attention 918 | cap_lens_np = cap_lens.cpu().data.numpy() 919 | for j in range(batch_size): 920 | save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j]) 921 | for k in range(len(fake_imgs)): 922 | im = fake_imgs[k][j].data.cpu().numpy() 923 | im = (im + 1.0) * 127.5 924 | im = im.astype(np.uint8) 925 | # print('im', im.shape) 926 | im = np.transpose(im, (1, 2, 0)) 927 | # print('im', im.shape) 928 | im = Image.fromarray(im) 929 | fullpath = '%s_g%d.png' % (save_name, k) 930 | im.save(fullpath) 931 | 932 | for k in range(len(attention_maps)): 933 | if len(fake_imgs) > 1: 934 | im = fake_imgs[k + 1].detach().cpu() 935 | else: 936 | im = fake_imgs[0].detach().cpu() 937 | attn_maps = attention_maps[k] 938 | att_sze = attn_maps.size(2) 939 | img_set, sentences = \ 940 | build_super_images2(im[j].unsqueeze(0), 941 | captions[j].unsqueeze(0), 942 | [cap_lens_np[j]], self.ixtoword, 943 | [attn_maps[j]], att_sze) 944 | if img_set is not None: 945 | im = Image.fromarray(img_set) 946 | fullpath = '%s_a%d.png' % (save_name, k) 947 | im.save(fullpath) 948 | --------------------------------------------------------------------------------
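For orientation, the sketch below shows one plausible way the two trainers defined in trainer.py (condGANTrainer and CycleGANTrainer) are driven from a top-level script. It is a minimal sketch, not the repository's actual entry point: `TextDataset`, `cfg_from_file`, the transform pipeline, and the output directory name are assumptions carried over from the AttnGAN codebase this project builds on (see datasets.py, miscc/config.py, and main.py for the real wiring).

```python
# Illustrative sketch only -- names marked "assumed" may differ from the repo.
import torch
import torchvision.transforms as transforms

from miscc.config import cfg, cfg_from_file   # assumed config loader
from datasets import TextDataset               # assumed dataset class
from trainer import CycleGANTrainer

# Load the cycle-GAN config for birds.
cfg_from_file('cfg/bird_cycle.yaml')

# Image size grows by 2x per branch (64 -> 128 -> 256 for BRANCH_NUM = 3).
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
image_transform = transforms.Compose([
    transforms.Resize(int(imsize * 76 / 64)),
    transforms.RandomCrop(imsize),
    transforms.RandomHorizontalFlip()])

dataset = TextDataset(cfg.DATA_DIR, 'train',
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=image_transform)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
    drop_last=True, shuffle=True, num_workers=cfg.WORKERS)

# The trainer needs the output dir, the loader, the vocabulary size,
# and the index-to-word mapping used when rendering attention maps.
trainer = CycleGANTrainer('output/birds_cycle', dataloader,
                          dataset.n_words, dataset.ixtoword)
trainer.train()          # or trainer.sampling('test') with TRAIN.FLAG: False
```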