├── .gitmodules
├── LICENSE
├── README.md
├── data-prepro
│   ├── CUB200_preprocess
│   │   ├── CUB_preprocess_token.py
│   │   ├── ECCV16_explanations_splits
│   │   │   ├── test.txt
│   │   │   ├── train_noCub.txt
│   │   │   └── val.txt
│   │   ├── dictionary_5.npz
│   │   ├── download_cub.sh
│   │   ├── get_split.py
│   │   └── prepro_cub_annotation.py
│   └── MSCOCO_preprocess
│       ├── K_cleaned_words.npz
│       ├── K_split.json
│       ├── dictionary_5.npz
│       ├── download_mscoco.sh
│       ├── extract_resnet_coco.py
│       ├── prepro_coco_annotation.py
│       ├── prepro_mscoco_caption.sh
│       ├── preprocess_entity.py
│       ├── preprocess_token.py
│       └── resnet_model
│           └── ResNet_mean.npy
├── images
│   ├── im11063.jpg
│   ├── im22197.jpg
│   ├── im270.jpg
│   ├── im6795.jpg
│   └── teaser.png
└── show-adapt-tell
    ├── cub
    ├── data
    ├── data_loader.py
    ├── highway.py
    ├── main.py
    ├── model.py
    ├── pretrain_CNN_D.py
    ├── pretrain_G.py
    ├── pretrain_LSTM_D.py
    └── utils.py

/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data-prepro/MSCOCO_preprocess/neuraltalk2"]
2 | path = data-prepro/MSCOCO_preprocess/neuraltalk2
3 | url = git@github.com:karpathy/neuraltalk2.git
4 | [submodule "data-prepro/MSCOCO_preprocess/deep-residual-networks"]
5 | path = data-prepro/MSCOCO_preprocess/deep-residual-networks
6 | url = git@github.com:KaimingHe/deep-residual-networks.git
7 | [submodule "show-adapt-tell/coco-caption"]
8 | path = show-adapt-tell/coco-caption
9 | url = git@github.com:peteanderson80/coco-caption.git
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Paul Chen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # show-adapt-and-tell
2 | 
3 | This is the official code for the paper
4 | 
5 | **[Show, Adapt and Tell: Adversarial Training of Cross-domain Image Captioner](https://arxiv.org/pdf/1705.00930.pdf)**
6 | 
7 | [Tseng-Hung Chen](https://tsenghungchen.github.io/), 8 | [Yuan-Hong Liao](https://andrewliao11.github.io/), 9 | [Ching-Yao Chuang](http://jameschuanggg.github.io/), 10 | [Wan-Ting Hsu](https://hsuwanting.github.io/), 11 | [Jianlong Fu](https://www.microsoft.com/en-us/research/people/jianf/), 12 | [Min Sun](http://aliensunmin.github.io/) 13 |
14 | To appear in [ICCV 2017](http://iccv2017.thecvf.com/) 15 | 16 | 17 |
18 | 19 |
20 | 
21 | In this repository we provide:
22 | 
23 | - The cross-domain captioning models [used in the paper](#models-from-the-paper)
24 | - Script for [preprocessing MSCOCO data](#mscoco-captioning-dataset)
25 | - Script for [preprocessing CUB-200-2011 captions](#cub-200-2011-with-descriptions)
26 | - Code for [training the cross-domain captioning models](#training)
27 | 
28 | 
29 | If you find this code useful for your research, please cite
30 | 
31 | ```
32 | @article{chen2017show,
33 |   title={Show, Adapt and Tell: Adversarial Training of Cross-domain Image Captioner},
34 |   author={Chen, Tseng-Hung and Liao, Yuan-Hong and Chuang, Ching-Yao and Hsu, Wan-Ting and Fu, Jianlong and Sun, Min},
35 |   journal={arXiv preprint arXiv:1705.00930},
36 |   year={2017}
37 | }
38 | ```
39 | 
40 | ## Requirements
41 | 
42 | - Python 2.7
43 | - [TensorFlow 0.12.1](https://www.tensorflow.org/versions/r0.12/get_started/os_setup)
44 | - [Caffe](https://github.com/BVLC/caffe)
45 | - OpenCV 2.4.9
46 | 
47 | Note: please clone the repository with the `--recursive` flag:
48 | 
49 | ```Shell
50 | # Make sure to clone with --recursive
51 | git clone --recursive https://github.com/tsenghungchen/show-adapt-and-tell.git
52 | ```
53 | 
54 | ## Data Preprocessing
55 | 
56 | ### MSCOCO Captioning dataset
57 | 
58 | #### Feature Extraction
59 | 1. Download the pretrained [ResNet-101 model](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777) and place it under `data-prepro/MSCOCO_preprocess/resnet_model/`.
60 | 2. Modify the Caffe path in `data-prepro/MSCOCO_preprocess/extract_resnet_coco.py`.
61 | 3. Go to `data-prepro/MSCOCO_preprocess` and run the following script:
62 | `./download_mscoco.sh` to download the images and extract their features.
63 | 
64 | #### Captions Tokenization
65 | 1. Clone the [NeuralTalk2](https://github.com/karpathy/neuraltalk2/tree/bd8c9d879f957e1218a8f9e1f9b663ac70375866) repository, head over to its `coco/` folder, and run the IPython notebook to generate `coco_raw.json`, a JSON file for the Karpathy split.
66 | 2. Run the following script:
67 | `./prepro_mscoco_caption.sh` to download and tokenize the captions.
68 | 3. Run `python prepro_coco_annotation.py` to generate the annotation JSON file for testing.
69 | 
70 | ### CUB-200-2011 with Descriptions
71 | #### Feature Extraction
72 | 1. Run the script `./download_cub.sh` to download the CUB-200-2011 images.
73 | 2. Modify the input/output paths in `data-prepro/MSCOCO_preprocess/extract_resnet_coco.py` to extract and pack the CUB-200-2011 features.
74 | 
75 | #### Captions Tokenization
76 | 1. Download the [description data](https://drive.google.com/open?id=0B0ywwgffWnLLZW9uVHNjb2JmNlE).
77 | 2. Run `python get_split.py` to generate the dataset split following the ECCV16 paper "Generating Visual Explanations".
78 | 3. Run `python prepro_cub_annotation.py` to generate the annotation JSON file for testing.
79 | 4. Run `python CUB_preprocess_token.py` for tokenization; a quick sanity check on the outputs is sketched below.
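Both pipelines write their vocabulary to a `dictionary_5.npz` file and their encoded captions to `tokenized_*_caption.pkl` pickles. As a sanity check after preprocessing, the dictionary can be inspected with a few lines of NumPy. This is a minimal sketch: the path assumes the default `mscoco_data/` output location used by `preprocess_token.py`, and `.item(0)` unwraps the pickled dicts the same way the preprocessing scripts reload them.

```python
import numpy as np

# Load the vocabulary built by preprocess_token.py (frequency threshold = 5).
d = np.load('data-prepro/MSCOCO_preprocess/mscoco_data/dictionary_5.npz')
word2idx = d['word2idx'].item(0)   # word -> integer index
idx2word = d['idx2word'].item(0)   # str(index) -> word
print 'vocabulary size =', len(word2idx)
print 'index of "bird" =', word2idx.get(u'bird')
```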
80 | 81 | 82 | ## Models from the paper 83 | 84 | ### Pretrained Models 85 | Download all pretrained and adaption models: 86 | 87 | - [MSCOCO pretrained model](https://drive.google.com/drive/folders/0B340bHpZlbZzYW91R0UtNDRXUDA?usp=sharing) 88 | - [CUB-200-2011 adaptation model](https://drive.google.com/drive/folders/0B340bHpZlbZzNUZybXNzWVR2VWM?usp=sharing) 89 | - [TGIF adaptation model](https://drive.google.com/drive/folders/0B340bHpZlbZzX0ZWcFZ1YzdrSTg?usp=sharing) 90 | - [Flickr30k adaptation model](https://drive.google.com/drive/folders/0B340bHpZlbZzNldjRmZVX3JXdVk?usp=sharing) 91 | 92 | ### Example Results 93 | Here are some example results where the captions are generated from these models: 94 | 95 | 96 | 97 | 98 | 109 | 120 | 121 |
99 | 100 |
101 | MSCOCO: A large air plane on a run way. 102 |
103 | CUB-200-2011: A large white and black airplane with a large beak. 104 |
105 | TGIF: A plane is flying over a field. 106 |
107 | Flickr30k: A large airplane is sitting on a runway. 108 |
110 | 111 |
112 | MSCOCO: A traffic light is seen in front of a large building. 113 |
114 | CUB-200-2011: A yellow traffic light with a yellow light. 115 |
116 | TGIF: A traffic light is hanging on a pole. 117 |
118 | Flickr30k: A street sign is lit up in the dark 119 |
122 | 123 | 124 | 125 | 136 | 147 | 148 |
126 | 127 |
128 | MSCOCO: A black dog sitting on the ground next to a window. 129 |
130 | CUB-200-2011: A black and white dog with a black head. 131 |
132 | TGIF: A dog is looking at something in the mirror. 133 |
134 | Flickr30k: A black dog is looking out of the window. 135 |
137 | 138 |
139 | MSCOCO: A man riding a skateboard up the side of a ramp. 140 |
141 | CUB-200-2011: A man riding a skateboard on a white ramp. 142 |
143 | TGIF: A man is doing a trick on a skateboard. 144 |
145 | Flickr30k: A man in a blue shirt is doing a trick on a skateboard. 146 |
149 | 150 | 151 | 152 | ## Training 153 | The training codes are under the `show-adapt-tell/` folder. 154 | 155 | Simply run `python main.py` for two steps of training: 156 | 157 | ### Training the source model with paired image-caption data 158 | Please set the Boolean value of `"G_is_pretrain"` to True in `main.py` to start pretraining the generator. 159 | ### Training the cross-domain captioner with unpaired data 160 | After pretraining, set `"G_is_pretrain"` to False to start training the cross-domain model. 161 | 162 | ## License 163 | 164 | Free for personal or research use; for commercial use please contact me. 165 | 166 | -------------------------------------------------------------------------------- /data-prepro/CUB200_preprocess/CUB_preprocess_token.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import numpy as np 4 | from tqdm import tqdm 5 | import pdb 6 | import os 7 | import pickle 8 | import cPickle 9 | import string 10 | 11 | def unpickle(p): 12 | return cPickle.load(open(p,'r')) 13 | 14 | def load_json(p): 15 | return json.load(open(p,'r')) 16 | 17 | def clean_words(data): 18 | dict = {} 19 | freq = {} 20 | # start with 1 21 | idx = 1 22 | sentence_count = 0 23 | eliminate = 0 24 | max_w = 30 25 | for k in tqdm(range(len(data['caption']))): 26 | sen = data['caption'][k] 27 | filename = data['file_name'][k] 28 | # skip the no image description 29 | words = re.split(' ', sen) 30 | # pop the last u'.' 31 | n = len(words) 32 | if n <= max_w: 33 | sentence_count += 1 34 | for word in words: 35 | for p in string.punctuation: 36 | if p in word: 37 | word = word.replace(p,'') 38 | word = word.lower() 39 | if word not in dict.keys(): 40 | dict[word] = idx 41 | idx += 1 42 | freq[word] = 1 43 | else: 44 | freq[word] += 1 45 | else: 46 | eliminate += 1 47 | print 'Threshold(max_words) =', max_w 48 | print 'Eliminate =', eliminate 49 | print 'Total sentence_count =', sentence_count 50 | print 'Number of different words =', len(dict.keys()) 51 | print 'Saving....' 
52 | np.savez('cleaned_words', dict=dict, freq=freq) 53 | return dict, freq 54 | 55 | 56 | phase = 'train' 57 | id2name = unpickle('id2name.pkl') 58 | id2caption = unpickle('id2caption.pkl') 59 | splits = unpickle('splits.pkl') 60 | split = splits[phase + '_id'] 61 | thres = 5 62 | 63 | filename_list = [] 64 | caption_list = [] 65 | img_id_list = [] 66 | for i in split: 67 | for sen in id2caption[i]: 68 | img_id_list.append(i) 69 | filename_list.append(id2name[i]) 70 | caption_list.append(sen) 71 | 72 | # build dictionary 73 | if not os.path.isfile('cub_data/dictionary_'+str(thres)+'.npz'): 74 | pdb.set_trace() 75 | # clean the words through the frequency 76 | words = np.load('K_cleaned_words.npz') 77 | dict = words['dict'].item(0) 78 | freq = words['freq'].item(0) 79 | idx2word = {} 80 | word2idx = {} 81 | idx = 1 82 | for k in tqdm(dict.keys()): 83 | if freq[k] >= thres: 84 | word2idx[k] = idx 85 | idx2word[str(idx)] = k 86 | idx += 1 87 | 88 | word2idx[u''] = len(word2idx.keys())+1 89 | idx2word[str(len(word2idx.keys()))] = u'' 90 | print 'Threshold of word fequency =', thres 91 | print 'Total words in the dictionary =', len(word2idx.keys()) 92 | np.savez('cub_data/dictionary_'+str(thres), word2idx=word2idx, idx2word=idx2word) 93 | else: 94 | tem = np.load('cub_data/dictionary_'+str(thres)+'.npz') 95 | word2idx = tem['word2idx'].item(0) 96 | idx2word = tem['idx2word'].item(0) 97 | 98 | 99 | # generate tokenized data 100 | num_sentence = 0 101 | eliminate = 0 102 | tokenized_caption_list = [] 103 | caption_list_new = [] 104 | filename_list_new = [] 105 | img_id_list_new = [] 106 | caption_length = [] 107 | for k in tqdm(range(len(caption_list))): 108 | sen = caption_list[k] 109 | img_id = img_id_list[k] 110 | filename = filename_list[k] 111 | # skip the no image description 112 | words = re.split(' ', sen) 113 | # pop the last u'.' 
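    # Encode the caption into a fixed 31-slot index vector: every slot is
    # pre-filled with the padding index, each word is mapped through
    # word2idx, an end-of-sentence index is written right after the last
    # word, and any sentence containing an out-of-vocabulary word is
    # dropped entirely (counted in `eliminate`).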
114 | count = 0 115 | valid = True 116 | tokenized_sent = np.ones([31],dtype=int) * word2idx[u''] # initialize as 117 | if len(words) <= 30: 118 | for word in words: 119 | try: 120 | word = word.lower() 121 | for p in string.punctuation: 122 | if p in word: 123 | word = word.replace(p,'') 124 | idx = int(word2idx[word]) 125 | tokenized_sent[count] = idx 126 | count += 1 127 | except KeyError: 128 | # if contain then drop the sentence in train phase 129 | valid = False 130 | break 131 | # add 132 | tokenized_sent[len(words)] = word2idx[u''] 133 | if valid: 134 | tokenized_caption_list.append(tokenized_sent) 135 | filename_list_new.append(filename) 136 | img_id_list_new.append(img_id) 137 | caption_list_new.append(sen) 138 | num_sentence += 1 139 | else: 140 | eliminate += 1 141 | tokenized_caption_info = {} 142 | tokenized_caption_info['tokenized_caption_list'] = np.asarray(tokenized_caption_list) 143 | tokenized_caption_info['filename_list'] = np.asarray(filename_list_new) 144 | tokenized_caption_info['img_id_list'] = np.asarray(img_id_list_new) 145 | tokenized_caption_info['raw_caption_list'] = np.asarray(caption_list_new) 146 | print 'Number of sentence =', num_sentence 147 | print 'eliminate = ', eliminate 148 | with open('./cub_data/tokenized_'+phase+'_caption.pkl', 'w') as outfile: 149 | pickle.dump(tokenized_caption_info, outfile) 150 | 151 | -------------------------------------------------------------------------------- /data-prepro/CUB200_preprocess/dictionary_5.npz: -------------------------------------------------------------------------------- 1 | ../MSCOCO_preprocess/dictionary_5.npz -------------------------------------------------------------------------------- /data-prepro/CUB200_preprocess/download_cub.sh: -------------------------------------------------------------------------------- 1 | mkdir cub_dataset 2 | cd cub_dataset 3 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 4 | tar zxvf CUB_200_2011.tgz 5 | # please download caption data on https://github.com/reedscot/cvpr2016. CUB_CVPR16 will be created after unzipping. 6 | 7 | -------------------------------------------------------------------------------- /data-prepro/CUB200_preprocess/get_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cPickle 4 | 5 | # generate name2id & id2name dictionary 6 | name_id_path = '../images.txt' 7 | name_id = open(name_id_path).read().splitlines() 8 | name2id = {} 9 | id2name = {} 10 | for img in name_id: 11 | name2id[img.split(' ')[1]] = img.split(' ')[0] 12 | id2name[img.split(' ')[0]] = img.split(' ')[1] 13 | 14 | cPickle.dump(name2id, open('name2id.pkl', 'wb')) 15 | cPickle.dump(id2name, open('id2name.pkl', 'wb')) 16 | 17 | # generate id2caption dictionary for all images 18 | # please download caption data on https://github.com/reedscot/cvpr2016. 19 | # CUB_CVPR16 will be created after unzipping. 
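# Each entry of images.txt is "<id> <class_dir>/<image_name>.jpg"; the
# caption file for an image lives at text_c10/<class_dir>/<image_name>.txt
# (same path with the image extension swapped for .txt), one caption per line.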
20 | caption_path = './CUB_CVPR16/text_c10/'
21 | id2caption = {}
22 | for name in name2id:
23 |     txt_name = '.'.join(name.split('.')[0:-1]) + '.txt'
24 |     txt_path = os.path.join(caption_path, txt_name)
25 |     id = name2id[name]
26 |     id2caption[id] = open(txt_path).read().splitlines()
27 | 
28 | cPickle.dump(id2caption, open('id2caption.pkl', 'wb'))
29 | 
30 | # generate split dictionary
31 | train_path = './ECCV16_explanations_splits/train_noCub.txt'
32 | test_path = './ECCV16_explanations_splits/test.txt'
33 | val_path = './ECCV16_explanations_splits/val.txt'
34 | splits = {}
35 | splits['train_name'] = open(train_path).read().splitlines()
36 | splits['test_name'] = open(test_path).read().splitlines()
37 | splits['val_name'] = open(val_path).read().splitlines()
38 | 
39 | splits['train_id'] = [name2id[n] for n in splits['train_name']]
40 | splits['test_id'] = [name2id[n] for n in splits['test_name']]
41 | splits['val_id'] = [name2id[n] for n in splits['val_name']]
42 | 
43 | cPickle.dump(splits, open('splits.pkl', 'wb'))
44 | 
--------------------------------------------------------------------------------
/data-prepro/CUB200_preprocess/prepro_cub_annotation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import string
3 | import scipy.io as sio
4 | import numpy as np
5 | from tqdm import tqdm
6 | from random import shuffle, seed
7 | import pickle as pk
8 | import pdb
9 | input_data = 'splits.pkl'  # split dictionary written by get_split.py
10 | with open(input_data) as data_file:
11 |     dataset = pk.load(data_file)
12 | 
13 | skip_num = 0
14 | val_data = {}
15 | test_data = {}
16 | train_data = []
17 | 
18 | val_dataset = []
19 | test_dataset = []
20 | counter = 0
21 | id2name = pk.load(open('id2name.pkl'))
22 | data = pk.load(open('id2caption.pkl'))
23 | 
24 | for i in dataset['test_id']:
25 |     caps = []
26 |     # For GT
27 |     name = id2name[i]
28 |     count = 0
29 |     for sen in data[i]:
30 |         for punc in string.punctuation:
31 |             if punc in sen:
32 |                 sen = sen.replace(punc, '')
33 | 
34 |         tmp = {}
35 |         tmp['filename'] = name
36 |         tmp['img_id'] = i
37 |         tmp['cap_id'] = count
38 |         tmp['caption'] = sen
39 |         count += 1
40 |         caps.append(tmp)
41 | 
42 |     test_data[i] = caps
43 | print 'number of skip train data: ' + str(skip_num)
44 | # COCO-style annotation keys: [u'info', u'images', u'licenses', u'type', u'annotations']
45 | json.dump(test_data, open('cub_data/K_test_annotation.json', 'w'))
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/K_cleaned_words.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/data-prepro/MSCOCO_preprocess/K_cleaned_words.npz
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/dictionary_5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/data-prepro/MSCOCO_preprocess/dictionary_5.npz
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/download_mscoco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # download mscoco images
3 | mkdir coco
4 | cd coco
5 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
6 | wget 
http://msvocds.blob.core.windows.net/coco2014/val2014.zip 7 | wget http://msvocds.blob.core.windows.net/coco2014/test2014.zip 8 | unzip train2014.zip 9 | unzip val2014.zip 10 | unzip test2014.zip 11 | rm train2014.zip 12 | rm val2014.zip 13 | rm test2014.zip 14 | cd .. 15 | # please download the pretrained ResNet-101 model at https://github.com/KaimingHe/deep-residual-networks 16 | mkdir mscoco_data 17 | # extract resnet feature and pack in pickle format 18 | python extract_resnet_coco.py --def deep-residual-networks/prototxt/ResNet-101-deploy.prototxt --net resnet_model/ResNet-101-model.caffemodel --gpu 0 19 | -------------------------------------------------------------------------------- /data-prepro/MSCOCO_preprocess/extract_resnet_coco.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/PaulChen/deep-residual-networks/caffe/python') 3 | import caffe 4 | import numpy as np 5 | import argparse 6 | import cv2 7 | import os, time 8 | import json 9 | import pdb 10 | import PIL 11 | from tqdm import tqdm 12 | from PIL import Image 13 | import re 14 | import pickle as pk 15 | 16 | def parse_args(): 17 | """ 18 | Parse input arguments 19 | """ 20 | parser = argparse.ArgumentParser(description='Extract a CNN features') 21 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id to use', 22 | default=0, type=int) 23 | parser.add_argument('--def', dest='prototxt', 24 | help='prototxt file defining the network', 25 | default=None, type=str) 26 | parser.add_argument('--net', dest='caffemodel', 27 | help='model to test', 28 | default=None, type=str) 29 | 30 | if len(sys.argv) == 1: 31 | parser.print_help() 32 | sys.exit(1) 33 | 34 | args = parser.parse_args() 35 | return args 36 | 37 | def set_transformer(net): 38 | transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) 39 | transformer.set_transpose('data',(2,0,1)) 40 | transformer.set_mean('data', np.load(\ 41 | os.path.join('resnet_model','ResNet_mean.npy'))) 42 | transformer.set_input_scale('data', 255) 43 | return transformer 44 | 45 | def iter_frames(im): 46 | try: 47 | i= 0 48 | while 1: 49 | im.seek(i) 50 | imframe = im.copy() 51 | if i == 0: 52 | palette = imframe.getpalette() 53 | else: 54 | imframe.putpalette(palette) 55 | yield imframe 56 | i += 1 57 | except EOFError: 58 | pass 59 | 60 | def extract_image(net, image_file): 61 | batch_size = 1 62 | transformer = set_transformer(net) 63 | if image_file.split('.')[-1] == 'gif': 64 | img = Image.open(image_file).convert("P",palette=Image.ADAPTIVE, colors=256) 65 | newfile = ''.join(image_file.split('.')[:-1])+'.png' 66 | for i, frame in enumerate(iter_frames(img)): 67 | frame.save(newfile,**frame.info) 68 | image_file = newfile 69 | 70 | img = cv2.imread(image_file) 71 | img = img.astype('float') / 255 72 | net.blobs['data'].data[:] = transformer.preprocess('data', img) 73 | net.forward() 74 | blobs_out_pool5 = net.blobs['pool5'].data[0,:,0,0] 75 | return blobs_out_pool5 76 | 77 | 78 | def split(split, net, feat_dict): 79 | print 'load ' + split 80 | img_dir = './coco/' 81 | img_path = os.path.join(img_dir, split) 82 | img_list = os.listdir(img_path) 83 | pool5_list = [] 84 | prob_list = [] 85 | for k in tqdm(img_list): 86 | blobs_out_pool5 = extract_image(net, os.path.join(img_path,k)) 87 | feat_dict[k.split('.')[0]] = np.array(blobs_out_pool5) 88 | 89 | return feat_dict 90 | 91 | if __name__ == '__main__': 92 | args = parse_args() 93 | caffe_path = 
os.path.join('/home','PaulChen','caffe','python') 94 | 95 | print 'caffe setting' 96 | caffe.set_mode_gpu() 97 | caffe.set_device(args.gpu_id) 98 | 99 | print 'load caffe' 100 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST) 101 | net.name = os.path.splitext(os.path.basename(args.caffemodel))[0] 102 | 103 | feat_dict = {} 104 | split('train2014', net, feat_dict) 105 | split('val2014', net, feat_dict) 106 | pk.dump(feat_dict, open('./mscoco_data/coco_trainval_feat.pkl','w')) 107 | 108 | -------------------------------------------------------------------------------- /data-prepro/MSCOCO_preprocess/prepro_coco_annotation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import string 3 | import scipy.io as sio 4 | import numpy as np 5 | from tqdm import tqdm 6 | from random import shuffle, seed 7 | 8 | input_json = 'neuraltalk2/coco/coco_raw.json' 9 | with open(input_json) as data_file: 10 | data = json.load(data_file) 11 | 12 | seed(123) 13 | shuffle(data) 14 | 15 | skip_num = 0 16 | val_data = {} 17 | test_data = {} 18 | train_data_ = {} 19 | 20 | train_data = [] 21 | 22 | val_ann = [] 23 | 24 | val_dataset = [] 25 | test_dataset = [] 26 | train_dataset = [] 27 | 28 | counter = 0 29 | 30 | for i in tqdm(range(len(data))): 31 | if i < 5000: 32 | # For GT 33 | idx = data[i]['id'] 34 | caps = [] 35 | for j in range(len(data[i]['captions'])): 36 | sen = data[i]['captions'][j].lower() 37 | for punc in string.punctuation: 38 | if punc in sen: 39 | sen = sen.replace(punc, '') 40 | tmp = {} 41 | tmp['img_id'] = data[i]['id'] 42 | tmp['cap_id'] = j 43 | tmp['caption'] = sen 44 | caps.append(tmp) 45 | 46 | val_data[idx] = caps 47 | 48 | # For load 49 | tmp = {} 50 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0] 51 | tmp['img_id'] = idx 52 | val_dataset.append(tmp) 53 | 54 | elif i < 10000: 55 | idx = data[i]['id'] 56 | caps = [] 57 | for j in range(len(data[i]['captions'])): 58 | sen = data[i]['captions'][j].lower() 59 | for punc in string.punctuation: 60 | if punc in sen: 61 | sen = sen.replace(punc, '') 62 | tmp = {} 63 | tmp['img_id'] = data[i]['id'] 64 | tmp['cap_id'] = j 65 | tmp['caption'] = sen 66 | caps.append(tmp) 67 | 68 | test_data[idx] = caps 69 | 70 | tmp = {} 71 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0] 72 | tmp['img_id'] = idx 73 | test_dataset.append(tmp) 74 | 75 | 76 | else: 77 | idx = data[i]['id'] 78 | caps = [] 79 | for j in range(len(data[i]['captions'])): 80 | sen = data[i]['captions'][j].lower() 81 | for punc in string.punctuation: 82 | if punc in sen: 83 | sen = sen.replace(punc, '') 84 | 85 | 86 | 87 | tmp = {} 88 | tmp['img_id'] = data[i]['id'] 89 | tmp['cap_id'] = j 90 | tmp['caption'] = sen 91 | caps.append(tmp) 92 | 93 | train_data_[idx] = caps 94 | 95 | tmp = {} 96 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0] 97 | tmp['img_id'] = idx 98 | train_dataset.append(tmp) 99 | 100 | 101 | 102 | # FOR TRAINING 103 | for j in range(len(data[i]['captions'])): 104 | sen = data[i]['captions'][j].lower() 105 | 106 | for punc in string.punctuation: 107 | if punc in sen: 108 | sen = sen.replace(punc, '') 109 | 110 | if len(sen.split()) > 30: 111 | skip_num += 1 112 | continue 113 | 114 | tmp = {} 115 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0] 116 | tmp['img_id'] = data[i]['id'] 117 | tmp['caption'] = sen 118 | tmp['length'] = len(sen.split()) 119 | train_data.append(tmp) 120 | 121 | print 'number of skip train data: ' + 
str(skip_num) 122 | 123 | [u'info', u'images', u'licenses', u'type', u'annotations'] 124 | 125 | #json.dump(val_data, open('K_val_train.json', 'w')) 126 | json.dump(val_data, open('./mscoco_data/K_val_annotation.json', 'w')) 127 | json.dump(test_data, open('./mscoco_data/K_test_annotation.json', 'w')) 128 | json.dump(train_data_, open('./mscoco_data/K_train_annotation.json', 'w')) 129 | 130 | #json.dump(train_data, open('K_train_raw.json', 'w')) 131 | 132 | json.dump(val_dataset, open('./mscoco_data/K_val_data.json', 'w')) 133 | json.dump(test_dataset, open('./mscoco_data/K_test_data.json', 'w')) 134 | json.dump(train_dataset, open('./mscoco_data/K_train_data.json', 'w')) 135 | -------------------------------------------------------------------------------- /data-prepro/MSCOCO_preprocess/prepro_mscoco_caption.sh: -------------------------------------------------------------------------------- 1 | # download and preprocess captions 2 | wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip 3 | unzip captions_train-val2014.zip 4 | rm captions_train-val2014.zip 5 | python preprocess_entity.py train 6 | python preprocess_entity.py test 7 | python preprocess_entity.py val 8 | python preprocess_token.py train 9 | python preprocess_token.py val 10 | python preprocess_token.py test 11 | -------------------------------------------------------------------------------- /data-prepro/MSCOCO_preprocess/preprocess_entity.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pickle 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | import pdb 7 | import sys 8 | def load_json(p): 9 | return json.load(open(p,'r')) 10 | 11 | desired_phase = sys.argv[1] 12 | split_path = 'K_split.json' 13 | split = load_json(split_path) 14 | split_id = split[desired_phase] 15 | 16 | phase = ['train', 'val'] 17 | id2name = {} 18 | name2id = {} 19 | id2caption = {} 20 | description_list = [] 21 | img_name = [] 22 | for p in phase: 23 | data_path = './annotations/captions_%s2014.json' % p 24 | data = load_json(data_path) 25 | for img_info in data['images']: 26 | if img_info['id'] in split_id: 27 | id2name[str(img_info['id'])] = img_info['file_name'] 28 | name2id[img_info['file_name']] = str(img_info['id']) 29 | id2caption[str(img_info['id'])] = [] 30 | count = 0 31 | for k in tqdm(range(len(data['annotations']))): 32 | sen = data['annotations'][k]['caption'] 33 | image_id = data['annotations'][k]['image_id'] 34 | if image_id in split_id: 35 | id2caption[str(image_id)].append(sen) 36 | file_name = id2name[str(image_id)] 37 | description_list.append(sen) 38 | img_name.append(file_name) 39 | 40 | out = {} 41 | out['caption_entity'] = description_list 42 | out['file_name'] = img_name 43 | out['id2filename'] = id2name 44 | out['filename2id'] = name2id 45 | out['id2caption'] = id2caption 46 | print 'Saving ...' 
47 | print 'Numer of sentence =', len(description_list) 48 | with open('./mscoco_data/K_annotation_%s2014.pkl'%desired_phase, 'w') as outfile: 49 | pickle.dump(out, outfile) 50 | 51 | -------------------------------------------------------------------------------- /data-prepro/MSCOCO_preprocess/preprocess_token.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import numpy as np 4 | from tqdm import tqdm 5 | import pdb 6 | import os 7 | import pickle 8 | import cPickle 9 | import string 10 | import sys 11 | 12 | def unpickle(p): 13 | return cPickle.load(open(p,'r')) 14 | 15 | def load_json(p): 16 | return json.load(open(p,'r')) 17 | 18 | def clean_words(data): 19 | dict = {} 20 | freq = {} 21 | # start with 1 22 | idx = 1 23 | sentence_count = 0 24 | eliminate = 0 25 | max_w = 30 26 | for k in tqdm(range(len(data['caption_entity']))): 27 | sen = data['caption_entity'][k] 28 | filename = data['file_name'][k] 29 | # skip the no image description 30 | words = re.split(' ', sen) 31 | # pop the last u'.' 32 | n = len(words) 33 | if "" in words: 34 | words.remove("") 35 | if n <= max_w: 36 | sentence_count += 1 37 | for word in words: 38 | if "\n" in word: 39 | word = word.replace("\n", "") 40 | for p in string.punctuation: 41 | if p in word: 42 | word = word.replace(p,'') 43 | word = word.lower() 44 | if word not in dict.keys(): 45 | dict[word] = idx 46 | idx += 1 47 | freq[word] = 1 48 | else: 49 | freq[word] += 1 50 | else: 51 | eliminate += 1 52 | print 'Threshold(max_words) =', max_w 53 | print 'Eliminate =', eliminate 54 | print 'Total sentence_count =', sentence_count 55 | print 'Number of different words =', len(dict.keys()) 56 | print 'Saving....' 57 | np.savez('K_cleaned_words', dict=dict, freq=freq) 58 | return dict, freq 59 | 60 | phase = sys.argv[1] 61 | data_path = './mscoco_data/K_annotation_'+phase+'2014.pkl' 62 | data = unpickle(data_path) 63 | thres = 5 64 | if not os.path.isfile('./mscoco_data/dictionary_'+str(thres)+'.npz'): 65 | # clean the words through the frequency 66 | if not os.path.isfile('K_cleaned_words.npz'): 67 | dict, freq = clean_words(data) 68 | else: 69 | words = np.load('K_cleaned_words.npz') 70 | dict = words['dict'].item(0) 71 | freq = words['freq'].item(0) 72 | idx2word = {} 73 | word2idx = {} 74 | idx = 1 75 | for k in tqdm(dict.keys()): 76 | if freq[k] >= thres and k != "": 77 | word2idx[k] = idx 78 | idx2word[str(idx)] = k 79 | idx += 1 80 | 81 | word2idx[u''] = 0 82 | idx2word["0"] = u'' 83 | word2idx[u''] = len(word2idx.keys()) 84 | idx2word[str(len(idx2word.keys()))] = u'' 85 | word2idx[u''] = len(word2idx.keys()) 86 | idx2word[str(len(idx2word.keys()))] = u'' 87 | word2idx[u''] = len(word2idx.keys()) 88 | idx2word[str(len(idx2word.keys()))] = u'' 89 | print 'Threshold of word fequency =', thres 90 | print 'Total words in the dictionary =', len(word2idx.keys()) 91 | np.savez('./mscoco_data/dictionary_'+str(thres), word2idx=word2idx, idx2word=idx2word) 92 | else: 93 | tem = np.load('./mscoco_data/dictionary_'+str(thres)+'.npz') 94 | word2idx = tem['word2idx'].item(0) 95 | idx2word = tem['idx2word'].item(0) 96 | 97 | num_sentence = 0 98 | eliminate = 0 99 | tokenized_caption_list = [] 100 | caption_list = [] 101 | filename_list = [] 102 | caption_length = [] 103 | for k in tqdm(range(len(data['caption_entity']))): 104 | sen = data['caption_entity'][k] 105 | filename = data['file_name'][k] 106 | # skip the no image description 107 | words = re.split(' ', sen) 108 | # pop the last u'.' 
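    # Same fixed-length scheme as the CUB script: the 31-slot vector is
    # pre-filled with index 0 (padding), in-vocabulary words are written in
    # order, an out-of-vocabulary word aborts the whole sentence for the
    # train split but is kept as the unknown-word index for val/test, and
    # an end-of-sentence index closes the caption.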
109 | tokenized_sent = np.zeros([30+1], dtype=int) 110 | tokenized_sent.fill(int(word2idx[u''])) 111 | #tokenized_sent[0] = int(word2idx[u'']) 112 | valid = True 113 | count = 0 114 | caption = [] 115 | 116 | if len(words) <= 30: 117 | for word in words: 118 | try: 119 | word = word.lower() 120 | for p in string.punctuation: 121 | if p in word: 122 | word = word.replace(p,'') 123 | if word != "": 124 | idx = int(word2idx[word]) 125 | tokenized_sent[count] = idx 126 | caption.append(word) 127 | count += 1 128 | except KeyError: 129 | # if contain then drop the sentence 130 | if phase == 'train': 131 | valid = False 132 | break 133 | else: 134 | tokenized_sent[count] = int(word2idx[u'']) 135 | count += 1 136 | if valid: 137 | tokenized_sent[count] = (word2idx[""]) 138 | caption_list.append(caption) 139 | length = np.sum((tokenized_sent!=0)+0) 140 | tokenized_caption_list.append(tokenized_sent) 141 | filename_list.append(filename) 142 | caption_length.append(length) 143 | num_sentence += 1 144 | else: 145 | if phase == 'val': 146 | pdb.set_trace() 147 | eliminate += 1 148 | tokenized_caption_info = {} 149 | tokenized_caption_info['caption_length'] = np.asarray(caption_length) 150 | tokenized_caption_info['tokenized_caption_list'] = np.asarray(tokenized_caption_list) 151 | tokenized_caption_info['caption_list'] = np.asarray(caption_list) 152 | tokenized_caption_info['filename_list'] = np.asarray(filename_list) 153 | print 'Number of sentence =', num_sentence 154 | with open('./mscoco_data/tokenized_'+phase+'_caption.pkl', 'w') as outfile: 155 | pickle.dump(tokenized_caption_info, outfile) 156 | 157 | -------------------------------------------------------------------------------- /data-prepro/MSCOCO_preprocess/resnet_model/ResNet_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/data-prepro/MSCOCO_preprocess/resnet_model/ResNet_mean.npy -------------------------------------------------------------------------------- /images/im11063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im11063.jpg -------------------------------------------------------------------------------- /images/im22197.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im22197.jpg -------------------------------------------------------------------------------- /images/im270.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im270.jpg -------------------------------------------------------------------------------- /images/im6795.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im6795.jpg -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/teaser.png -------------------------------------------------------------------------------- /show-adapt-tell/cub: -------------------------------------------------------------------------------- 1 | ../data-prepro/CUB200_preprocess/cub_data -------------------------------------------------------------------------------- /show-adapt-tell/data: -------------------------------------------------------------------------------- 1 | ../data-prepro/MSCOCO_preprocess/mscoco_data -------------------------------------------------------------------------------- /show-adapt-tell/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import utils 3 | import os, re, json 4 | import pdb 5 | from tqdm import tqdm 6 | 7 | 8 | def get_key(name): 9 | return re.split('\.', name)[0] 10 | 11 | class mscoco_negative(): 12 | 13 | def __init__(self, dataset, conf): 14 | self.dataset_name = 'mscoco_negative' 15 | self.batch_size = conf.batch_size 16 | data_dir = './negative_samples/mscoco_sample' 17 | npz_paths = os.listdir(data_dir) 18 | print "Load Training data" 19 | count = 0 20 | self.neg_img_filename_train = [] 21 | for npz_path in tqdm(npz_paths): 22 | if int(re.split("\.", re.split("_", npz_path)[1])[0]) <= 30000: 23 | npz = np.load(os.path.join(data_dir, npz_path)) 24 | # tokenize caption 25 | if count == 0: 26 | self.neg_caption_train = npz["index"] 27 | else: 28 | self.neg_caption_train = np.concatenate((self.neg_caption_train, npz["index"]), 0) 29 | # img_idx 30 | for i in npz["img_name"]: 31 | self.neg_img_filename_train.append(i+'.jpg') 32 | count += 1 33 | self.neg_img_filename_train = np.asarray(self.neg_img_filename_train) 34 | 35 | npz_paths = ["mscoco_51000.npz"] 36 | print "Testing data" 37 | self.neg_img_filename_test = [] 38 | count = 0 39 | for npz_path in tqdm(npz_paths): 40 | npz = np.load(os.path.join(data_dir, npz_path)) 41 | if count == 0: 42 | self.neg_caption_test = npz["index"] 43 | else: 44 | self.neg_caption_test = np.concatenate((self.neg_caption_test, npz["index"]), 0) 45 | # img_idx 46 | for i in npz["img_name"]: 47 | self.neg_img_filename_test.append(i+'.jpg') 48 | count += 1 49 | self.neg_img_filename_test = np.asarray(self.neg_img_filename_test) 50 | 51 | self.current = 0 52 | self.num_train = len(self.neg_img_filename_train) 53 | self.num_test = len(self.neg_img_filename_test) 54 | self.random_shuffle() 55 | self.filename2id = dataset.filename2id 56 | self.img_dims = dataset.img_dims 57 | self.img_feat = dataset.img_feat 58 | 59 | def random_shuffle(self): 60 | idx = range(self.num_train) 61 | np.random.shuffle(idx) 62 | self.neg_img_filename_train = self.neg_img_filename_train[idx] 63 | self.neg_caption_train = self.neg_caption_train[idx, :] 64 | 65 | def get_paired_data(self, num_data, phase): 66 | if phase == 'train': 67 | caption = self.neg_caption_train 68 | img_filename = self.neg_img_filename_train 69 | else: 70 | caption = self.neg_caption_test 71 | img_filename = self.neg_img_filename_test 72 | 73 | if num_data > 0: 74 | caption = caption[:num_data, :] 75 | img_filename = img_filename[:num_data] 76 | else: 77 | if phase=='train': 78 | num_data = self.num_train 79 | else: 80 | num_data = self.num_test 81 | 82 | image_feature = np.zeros([num_data, self.img_dims]) 83 | img_idx = [] 84 | for i in range(num_data): 85 | image_feature[i, :] = self.img_feat[get_key(img_filename[i])] 86 | 
img_idx.append(get_key(img_filename[i])) 87 | return image_feature, caption, np.asarray(img_idx) 88 | 89 | def sequential_sample(self, batch_size): 90 | end = (self.current+batch_size) % self.num_train 91 | if self.current + batch_size < self.num_train: 92 | caption = self.neg_caption_train[self.current:end, :] 93 | img_filename = self.neg_img_filename_train[self.current:end] 94 | else: 95 | caption = np.concatenate((self.neg_caption_train[self.current:], self.neg_caption_train[:end]), axis=0) 96 | img_filename = np.concatenate((self.neg_img_filename_train[self.current:], self.neg_img_filename_train[:end]), axis=0) 97 | self.random_shuffle() 98 | 99 | image_feature = np.zeros([batch_size, self.img_dims]) 100 | img_id = [] 101 | for i in range(batch_size): 102 | image_feature[i, :] = self.img_feat[get_key(img_filename[i])] 103 | img_id.append(self.filename2id[img_filename[i]]) 104 | self.current = end 105 | return image_feature, caption, np.asarray(img_id) 106 | 107 | class mscoco(): 108 | 109 | def __init__(self, conf=None): 110 | # train img feature 111 | self.dataset_name = 'cub' 112 | # target data 113 | flickr_img_path = './cub/cub_train_resnet.pkl' 114 | self.train_flickr_img_feat = utils.unpickle(flickr_img_path) 115 | self.num_train_images_filckr = len(self.train_flickr_img_feat.keys()) 116 | self.train_img_idx = self.train_flickr_img_feat.keys() 117 | flickr_caption_train_data_path = './cub/tokenized_train_caption.pkl' 118 | flickr_caption_train_data = utils.unpickle(flickr_caption_train_data_path) 119 | self.flickr_caption_train = flickr_caption_train_data['tokenized_caption_list'] 120 | self.flickr_caption_idx_train = flickr_caption_train_data['filename_list'] 121 | self.num_flickr_train_caption = self.flickr_caption_train.shape[0] 122 | flickr_img_path = './cub/cub_test_resnet.pkl' 123 | self.test_flickr_img_feat = utils.unpickle(flickr_img_path) 124 | self.flickr_random_shuffle() # shuffle the text data 125 | 126 | # MSCOCO data 127 | img_feat_path = './data/coco_trainval_feat.pkl' 128 | self.img_feat = utils.unpickle(img_feat_path) 129 | train_meta_path = './data/K_annotation_train2014.pkl' 130 | train_meta = utils.unpickle(train_meta_path) 131 | self.filename2id = train_meta['filename2id'] 132 | val_meta_path = './data/K_annotation_val2014.pkl' 133 | val_meta = utils.unpickle(val_meta_path) 134 | self.id2filename = val_meta['id2filename'] 135 | # train caption 136 | caption_train_data_path = './data/tokenized_train_caption.pkl' 137 | caption_train_data = utils.unpickle(caption_train_data_path) 138 | self.caption_train = caption_train_data['tokenized_caption_list'] 139 | self.caption_idx_train = caption_train_data['filename_list'] 140 | # val caption 141 | caption_test_data_path = './data/tokenized_test_caption.pkl' 142 | caption_test_data = utils.unpickle(caption_test_data_path) 143 | self.caption_test = caption_test_data['tokenized_caption_list'] 144 | self.caption_idx_test = caption_test_data['filename_list'] 145 | dict_path = './data/dictionary_5.npz' 146 | temp = np.load(dict_path) 147 | self.ix2word = temp['idx2word'].item() 148 | self.word2ix = temp['word2idx'].item() 149 | # add token 150 | if conf != None: 151 | self.batch_size = conf.batch_size 152 | self.dict_size = len(self.ix2word.keys()) 153 | self.test_pointer = 0 154 | self.current_flickr = 0 155 | self.current_flickr_caption = 0 156 | self.current = 0 157 | self.max_words = self.caption_train.shape[1] 158 | tmp = self.img_feat[self.img_feat.keys()[0]] 159 | self.img_dims = tmp.shape[0] 160 | self.num_train = 
self.caption_train.shape[0] 161 | self.num_test = self.caption_test.shape[0] 162 | # Load annotation 163 | self.source_test_annotation = json.load(open('./data/K_val_annotation.json')) 164 | self.source_test_images = self.source_test_annotation.keys() 165 | self.source_num_test_images = len(self.source_test_images) 166 | self.test_annotation = json.load(open('./cub/K_test_annotation.json')) 167 | self.test_images = self.test_annotation.keys() 168 | self.num_test_images = len(self.test_images) 169 | self.random_shuffle() 170 | 171 | def random_shuffle(self): 172 | idx = range(self.num_train) 173 | np.random.shuffle(idx) 174 | self.caption_train = self.caption_train[idx] 175 | self.caption_idx_train = self.caption_idx_train[idx] 176 | 177 | def flickr_random_shuffle(self): 178 | idx = range(self.num_flickr_train_caption) 179 | np.random.shuffle(idx) 180 | self.flickr_caption_train = self.flickr_caption_train[idx] 181 | self.flickr_caption_idx_train = self.flickr_caption_idx_train[idx] 182 | 183 | def get_train_annotation(self): 184 | return self.train_annotation 185 | 186 | def get_train_for_eval(self, num): 187 | image_feature = np.zeros([num, self.img_dims]) 188 | filenames = [] 189 | self.random_shuffle() 190 | for i in range(num): 191 | filename = get_key(self.caption_idx_train[i]) 192 | filenames.append(filename) 193 | image_feature[i, :] = self.img_feat[filename] 194 | 195 | return image_feature, np.asarray(filenames) 196 | 197 | def get_test_for_eval(self): 198 | 199 | image_feature = np.zeros([self.num_test_images, self.img_dims]) 200 | image_id = np.zeros([self.num_test_images]) 201 | for i in range(self.num_test_images): 202 | image_feature[i, :] = self.test_flickr_img_feat[self.test_images[i]] 203 | image_id[i] = int(self.test_images[i]) 204 | 205 | return image_feature, image_id, self.test_annotation 206 | 207 | def get_source_test_for_eval(self): 208 | 209 | image_feature = np.zeros([self.source_num_test_images, self.img_dims]) 210 | image_id = np.zeros([self.source_num_test_images]) 211 | for i in range(self.source_num_test_images): 212 | image_feature[i, :] = self.img_feat[get_key(self.id2filename[self.source_test_images[i]])] 213 | image_id[i] = int(self.source_test_images[i]) 214 | 215 | return image_feature, image_id, self.source_test_annotation 216 | 217 | def get_wrong_text(self, num_data, phase='train'): 218 | assert phase=='train' 219 | idx = range(self.num_train) 220 | np.random.shuffle(idx) 221 | caption_train = self.caption_train[idx, :] 222 | return caption_train[:num_data, :] 223 | 224 | def get_paired_data(self, num_data, phase): 225 | if phase == 'train': 226 | caption = self.caption_train 227 | img_idx = self.caption_idx_train 228 | else: 229 | caption = self.caption_test 230 | img_idx = self.caption_idx_test 231 | 232 | if num_data > 0: 233 | caption = caption[:num_data, :] 234 | img_idx = img_idx[:num_data] 235 | else: 236 | if phase=='train': 237 | num_data = self.num_train 238 | else: 239 | num_data = self.num_test 240 | 241 | image_feature = np.zeros([num_data, self.img_dims]) 242 | for i in range(num_data): 243 | image_feature[i, :] = self.img_feat[get_key(img_idx[i])] 244 | return image_feature, caption, img_idx 245 | 246 | def preprocess(self, caption, lstm_steps): 247 | caption_padding = sequence.pad_sequences(caption, padding='post', maxlen=lstm_steps) 248 | return caption_padding 249 | 250 | def decode(self, sent_idx, type='string', remove_END=False): 251 | if len(sent_idx.shape) == 1: 252 | sent_idx = np.expand_dims(sent_idx, 0) 253 | sentences = [] 
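        # For each sampled index sequence: stop at the end-of-sentence index,
        # map every other index back to its word (substituting the unknown-word
        # token for indices missing from ix2word), and return either the
        # capitalized sentence strings or the index lists.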
254 | indexes = [] 255 | for s in range(sent_idx.shape[0]): 256 | index = [] 257 | sentence = '' 258 | for i in range(sent_idx.shape[1]): 259 | if int(sent_idx[s][i]) == int(self.word2ix[u'']): 260 | if not remove_END: 261 | #sentence = sentence + '' 262 | index.append(int(sent_idx[s][i])) 263 | break 264 | else: 265 | try: 266 | word = self.ix2word[str(int(sent_idx[s][i]))] 267 | sentence = sentence + word + ' ' 268 | index.append(int(sent_idx[s][i])) 269 | except KeyError: 270 | sentence = sentence + "" + ' ' 271 | index.append(int(self.word2ix[u''])) 272 | indexes.append(index) 273 | sentences.append((sentence+'.').capitalize()) 274 | if type=='string': 275 | return sentences 276 | elif type=='index': 277 | return indexes 278 | 279 | def flickr_sequential_sample(self, batch_size): 280 | 281 | end = (self.current_flickr+batch_size) % self.num_train_images_filckr 282 | image_feature = np.zeros([batch_size, self.img_dims]) 283 | if self.current_flickr + batch_size < self.num_train_images_filckr: 284 | key = self.train_img_idx[self.current_flickr:end] 285 | else: 286 | key = np.concatenate((self.train_img_idx[self.current_flickr:], self.train_img_idx[:end]), axis=0) 287 | 288 | count = 0 289 | for k in key: 290 | image_feature[count] = self.train_flickr_img_feat[k] 291 | count += 1 292 | self.current_flickr = end 293 | return image_feature 294 | 295 | def flickr_caption_sequential_sample(self, batch_size): 296 | end = (self.current_flickr_caption+batch_size) % self.num_flickr_train_caption 297 | if self.current_flickr_caption + batch_size < self.num_flickr_train_caption: 298 | caption = self.flickr_caption_train[self.current_flickr_caption:end, :] 299 | else: 300 | caption = np.concatenate((self.flickr_caption_train[self.current_flickr_caption:], self.flickr_caption_train[:end]), axis=0) 301 | self.flickr_random_shuffle() 302 | 303 | self.current_flickr_caption = end 304 | return caption 305 | 306 | def sequential_sample(self, batch_size): 307 | end = (self.current+batch_size) % self.num_train 308 | if self.current + batch_size < self.num_train: 309 | caption = self.caption_train[self.current:end, :] 310 | img_idx = self.caption_idx_train[self.current:end] 311 | else: 312 | caption = np.concatenate((self.caption_train[self.current:], self.caption_train[:end]), axis=0) 313 | img_idx = np.concatenate((self.caption_idx_train[self.current:], self.caption_idx_train[:end]), axis=0) 314 | self.random_shuffle() 315 | 316 | image_feature = np.zeros([batch_size, self.img_dims]) 317 | img_id = [] 318 | for i in range(batch_size): 319 | image_feature[i, :] = self.img_feat[get_key(img_idx[i])] 320 | img_id.append(self.filename2id[img_idx[i]]) 321 | self.current = end 322 | return image_feature, caption, img_id 323 | 324 | -------------------------------------------------------------------------------- /show-adapt-tell/highway.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # highway layer that borrowed from https://github.com/carpedm20/lstm-char-cnn-tensorflow 4 | def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu): 5 | """Highway Network (cf. http://arxiv.org/abs/1505.00387). 6 | 7 | t = sigmoid(Wy + b) 8 | z = t * g(Wy + b) + (1 - t) * y 9 | where g is nonlinearity, t is transform gate, and (1 - t) is carry gate. 
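    With the default bias = -2, the transform gate starts near
    sigmoid(-2) ~= 0.12, so each layer initially behaves close to an
    identity mapping and only gradually learns to transform its input.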
10 | """ 11 | output = input_ 12 | for idx in xrange(layer_size): 13 | output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx)) 14 | transform_gate = tf.sigmoid( 15 | tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias) 16 | carry_gate = 1. - transform_gate 17 | output = transform_gate * output + carry_gate * input_ 18 | return output 19 | 20 | -------------------------------------------------------------------------------- /show-adapt-tell/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.misc 3 | import numpy as np 4 | import tensorflow as tf 5 | from pretrain_G import G_pretrained 6 | from pretrain_CNN_D import D_pretrained 7 | from model import SeqGAN 8 | from data_loader import mscoco, mscoco_negative 9 | import pprint 10 | import pdb 11 | 12 | flags = tf.app.flags 13 | flags.DEFINE_integer("epoch", 100, "Epoch to train [100]") 14 | flags.DEFINE_float("learning_rate", 5e-5, "Learning rate of for adam [0.0003]") 15 | flags.DEFINE_float("drop_out_rate", 0.3, "Drop out rate fro LSTM") 16 | flags.DEFINE_float("discount", 0.95, "discount factor in RL") 17 | flags.DEFINE_string("model_name", "cub_no_scheduled", "") 18 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") # 128:G, 32:D 19 | flags.DEFINE_integer("G_hidden_size", 512, "") # 512:G, 64:D 20 | flags.DEFINE_integer("D_hidden_size", 512, "") 21 | flags.DEFINE_integer("max_iter", 100000, "") 22 | flags.DEFINE_integer('max_to_keep', 40, '') 23 | flags.DEFINE_string("method", "ROUGE_L", "") 24 | flags.DEFINE_string("load_ckpt", './checkpoint/mscoco/G_pretrained/G_Pretrained-39000', "Directory name to loade the checkpoints [checkpoint]") 25 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") 26 | flags.DEFINE_boolean("G_is_pretrain", False, "Do the G pretraining") 27 | flags.DEFINE_boolean("D_is_pretrain", False, "Do the D pretraining") 28 | flags.DEFINE_boolean("load_pretrain", True, "Load the pretraining") 29 | flags.DEFINE_boolean("is_train", True, "True for training, False for testing [False]") 30 | 31 | # Setting from Self-critical Sequence Training for Image Captioning 32 | tf.app.flags.DEFINE_float('init_lr', 5e-4, '') # follow IBM's paper 33 | tf.app.flags.DEFINE_float('lr_decay', 0.8, 'learning rate decay factor') 34 | tf.app.flags.DEFINE_float('lr_decay_every', 6600, 'every 3 epoch 3*2200') 35 | tf.app.flags.DEFINE_float('ss_ascent', 0.05, 'schedule sampling') 36 | tf.app.flags.DEFINE_float('ss_ascent_every', 11000, 'every 5 epoch 5*2200') 37 | tf.app.flags.DEFINE_float('ss_max', 0.25, '0.05*5=0.25') 38 | 39 | FLAGS = flags.FLAGS 40 | pp = pprint.PrettyPrinter() 41 | def main(_): 42 | pp.pprint(flags.FLAGS.__flags) 43 | 44 | if not os.path.exists(FLAGS.checkpoint_dir): 45 | os.makedirs(FLAGS.checkpoint_dir) 46 | 47 | dataset = mscoco(FLAGS) 48 | config = tf.ConfigProto() 49 | config.gpu_options.per_process_gpu_memory_fraction = 1/10 50 | config.gpu_options.allow_growth = True 51 | with tf.Session(config=config) as sess: 52 | filter_sizes = [1,2,3,4,5,6,7,8,9,10,16,24,dataset.max_words] 53 | num_filters = [100,200,200,200,200,100,100,100,100,100,160,160,160] 54 | num_filters_total = sum(num_filters) 55 | info={'num_classes':3, 'filter_sizes':filter_sizes, 'num_filters':num_filters, 56 | 'num_filters_total':num_filters_total, 'l2_reg_lambda':0.2} 57 | if FLAGS.G_is_pretrain: 58 | G_pretrained_model = G_pretrained(sess, dataset, 
conf=FLAGS) 59 | if FLAGS.is_train: 60 | G_pretrained_model.train() 61 | G_pretrained_model.evaluate('test', 0, ) 62 | if FLAGS.D_is_pretrain: 63 | negative_dataset = mscoco_negative(dataset, FLAGS) 64 | D_pretrained_model = D_pretrained(sess, dataset, negative_dataset, info, conf=FLAGS) 65 | D_pretrained_model.train() 66 | if FLAGS.is_train: 67 | model = SeqGAN(sess, dataset, info, conf=FLAGS) 68 | model.train() 69 | 70 | if __name__ == '__main__': 71 | tf.app.run() 72 | -------------------------------------------------------------------------------- /show-adapt-tell/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import tensorflow as tf 5 | import numpy as np 6 | from tqdm import tqdm 7 | from highway import * 8 | import copy 9 | from coco_caption.pycocoevalcap.eval import COCOEvalCap 10 | import pdb 11 | 12 | def calculate_loss_and_acc_with_logits(predictions, logits, label, l2_loss, l2_reg_lambda): 13 | # Calculate Mean cross-entropy loss 14 | with tf.variable_scope("loss"): 15 | losses = tf.nn.softmax_cross_entropy_with_logits(tf.squeeze(logits), label) 16 | D_loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss 17 | with tf.variable_scope("accuracy"): 18 | correct_predictions = tf.equal(predictions, tf.argmax(label, 1)) 19 | accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float")) 20 | return D_loss, accuracy 21 | 22 | 23 | class SeqGAN(): 24 | def __init__(self, sess, dataset, D_info, conf=None): 25 | self.sess = sess 26 | self.model_name = conf.model_name 27 | self.batch_size = conf.batch_size 28 | self.max_iter = conf.max_iter 29 | self.max_to_keep = conf.max_to_keep 30 | self.is_train = conf.is_train 31 | # Testing => dropout rate is 0 32 | if self.is_train: 33 | self.drop_out_rate = conf.drop_out_rate 34 | else: 35 | self.drop_out_rate = 0 36 | 37 | self.num_train = dataset.num_train 38 | self.G_hidden_size = conf.G_hidden_size # 512 39 | self.D_hidden_size = conf.D_hidden_size # 512 40 | self.dict_size = dataset.dict_size 41 | self.max_words = dataset.max_words 42 | self.dataset = dataset 43 | self.img_dims = self.dataset.img_dims 44 | self.checkpoint_dir = conf.checkpoint_dir 45 | self.lstm_steps = self.max_words+1 46 | self.START = self.dataset.word2ix[u''] 47 | self.END = self.dataset.word2ix[u''] 48 | self.UNK = self.dataset.word2ix[u''] 49 | self.NOT = self.dataset.word2ix[u''] 50 | self.method = conf.method 51 | self.discount = conf.discount 52 | self.load_pretrain = conf.load_pretrain 53 | self.filter_sizes = D_info['filter_sizes'] 54 | self.num_filters = D_info['num_filters'] 55 | self.num_filters_total = sum(self.num_filters) 56 | self.num_classes = D_info['num_classes'] 57 | self.num_domains = 3 58 | self.l2_reg_lambda = D_info['l2_reg_lambda'] 59 | 60 | 61 | # D placeholder 62 | self.images = tf.placeholder('float32', [self.batch_size, self.img_dims]) 63 | self.right_text = tf.placeholder('int32', [self.batch_size, self.max_words]) 64 | self.wrong_text = tf.placeholder('int32', [self.batch_size, self.max_words]) 65 | self.wrong_length = tf.placeholder('int32', [self.batch_size], name="wrong_length") 66 | self.right_length = tf.placeholder('int32', [self.batch_size], name="right_length") 67 | 68 | # Domain Classider 69 | self.src_images = tf.placeholder('float32', [self.batch_size, self.img_dims]) 70 | self.tgt_images = tf.placeholder('float32', [self.batch_size, self.img_dims]) 71 | self.src_text = tf.placeholder('int32', [self.batch_size, 
self.max_words]) 72 | self.tgt_text = tf.placeholder('int32', [self.batch_size, self.max_words]) 73 | # Optimizer 74 | self.G_optim = tf.train.AdamOptimizer(conf.learning_rate) 75 | self.D_optim = tf.train.AdamOptimizer(conf.learning_rate) 76 | self.T_optim = tf.train.AdamOptimizer(conf.learning_rate) 77 | self.Domain_image_optim = tf.train.AdamOptimizer(conf.learning_rate) 78 | self.Domain_text_optim = tf.train.AdamOptimizer(conf.learning_rate) 79 | D_info["sentence_length"] = self.max_words 80 | self.D_info = D_info 81 | 82 | ################################################### 83 | # Generator # 84 | ################################################### 85 | # G placeholder 86 | state_list, predict_words_list_sample, log_probs_action_picked_list, self.rollout_mask, self.predict_mask = self.generator(name='G', reuse=False) 87 | predict_words_sample = tf.pack(predict_words_list_sample) 88 | self.predict_words_sample = tf.transpose(predict_words_sample, [1,0]) # B,S 89 | # for testing 90 | # argmax prediction 91 | _, predict_words_list_argmax, log_probs_action_picked_list_argmax, _, self.predict_mask_argmax = self.generator_test(name='G', reuse=True) 92 | predict_words_argmax = tf.pack(predict_words_list_argmax) 93 | self.predict_words_argmax = tf.transpose(predict_words_argmax, [1,0]) # B,S 94 | rollout = [] 95 | rollout_length = [] 96 | rollout_num = 3 97 | for i in range(rollout_num): 98 | rollout_i, rollout_length_i = self.rollout(predict_words_list_sample, state_list, name="G") # S*B, S 99 | rollout.append(rollout_i) # R,B,S 100 | rollout_length.append(rollout_length_i) # R,B, 1 101 | 102 | rollout = tf.pack(rollout) # R,B,S 103 | rollout = tf.reshape(rollout, [-1, self.max_words]) # R*B,S 104 | rollout_length = tf.pack(rollout_length) # R,B,1 105 | rollout_length = tf.reshape(rollout_length, [-1, 1]) # R*B, 1 106 | rollout_length = tf.squeeze(rollout_length) 107 | rollout_size = self.batch_size * self.max_words * rollout_num 108 | images_expand = tf.expand_dims(self.images, 1) # B,1,I 109 | images_tile = tf.tile(images_expand, [1, self.max_words, 1]) # B,S,I 110 | images_tile_transpose = tf.transpose(images_tile, [1,0,2]) # S,B,I 111 | images_tile_transpose = tf.tile(tf.expand_dims(images_tile_transpose, 0), [rollout_num,1,1,1]) #R,S,B,I 112 | images_reshape = tf.reshape(images_tile_transpose, [-1, self.img_dims]) #R*S*B,I 113 | 114 | D_rollout_vqa_softmax, D_rollout_logits_vqa = self.discriminator(rollout_size, images_reshape, rollout, rollout_length, name="D", reuse=False) 115 | D_rollout_text, D_rollout_text_softmax, D_logits_rollout_text, l2_loss_rollout_text = self.text_discriminator(rollout, D_info, name="D_text", reuse=False) 116 | reward = tf.multiply(D_rollout_vqa_softmax[:,0], D_rollout_text_softmax[:,0]) # S*B, 1 117 | 118 | reward = tf.reshape(reward, [rollout_num, -1]) # R, S*B 119 | reward = tf.reduce_mean(reward, 0) # S*B 120 | 121 | self.rollout_reward = tf.reshape(reward, [self.max_words, self.batch_size]) # S,B 122 | D_logits_rollout_reshape = tf.reshape(self.rollout_reward, [-1]) 123 | self.G_loss = (-1)*tf.reduce_sum(log_probs_action_picked_list*tf.stop_gradient(D_logits_rollout_reshape)) / tf.reduce_sum(tf.stop_gradient(self.predict_mask)) 124 | 125 | # Teacher Forcing 126 | self.mask = tf.placeholder('float32', [self.batch_size, self.max_words]) # mask out the loss 127 | self.teacher_loss, self.teacher_loss_sum = self.Teacher_Forcing(self.right_text, self.mask, name="G", reuse=True) 128 | 129 | ################################################### 130 | # 
Discriminator # 131 | ################################################### 132 | # take the sample as fake data 133 | D_info["sentence_length"] = self.max_words 134 | 135 | # take the argmax sample as fake data 136 | self.fake_length = tf.reduce_sum(tf.stop_gradient(self.predict_mask),1) 137 | D_fake_vqa_softmax, D_fake_logits_vqa = self.discriminator(self.batch_size, self.images, tf.to_int32(self.predict_words_sample), tf.to_int32(self.fake_length), name="D", reuse=True) 138 | D_right_vqa_softmax, D_right_logits_vqa = self.discriminator(self.batch_size, self.images, self.right_text, 139 | self.right_length, name="D", reuse=True) 140 | D_wrong_vqa_softmax, D_wrong_logits_vqa = self.discriminator(self.batch_size, self.images, self.wrong_text, 141 | self.wrong_length, name="D", reuse=True) 142 | 143 | D_right_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(D_right_logits_vqa, 144 | tf.concat(1,(tf.ones((self.batch_size,1)), tf.zeros((self.batch_size,1)), tf.zeros((self.batch_size,1)))))) 145 | D_wrong_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(D_wrong_logits_vqa, 146 | tf.concat(1,(tf.zeros((self.batch_size,1)), tf.ones((self.batch_size,1)), tf.zeros((self.batch_size,1)))))) 147 | D_fake_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(D_fake_logits_vqa, 148 | tf.concat(1,(tf.zeros((self.batch_size,1)), tf.zeros((self.batch_size,1)), tf.ones((self.batch_size,1)))))) 149 | 150 | 151 | self.D_loss = D_fake_loss + D_right_loss + D_wrong_loss 152 | ################################################### 153 | # Text Domain Classifier 154 | ################################################### 155 | D_src_text, D_src_text_softmax, D_logits_src_text, l2_loss_src_text = self.text_discriminator(self.src_text, D_info, name="D_text", reuse=True) 156 | D_tgt_text, D_tgt_text_softmax, D_logits_tgt_text, l2_loss_tgt_text = self.text_discriminator(self.tgt_text, D_info, name="D_text", reuse=True) 157 | D_fake_text, D_fake_text_softmax, D_logits_fake_text, l2_loss_fake_text = self.text_discriminator(self.predict_words_sample, D_info, name="D_text", reuse=True) 158 | 159 | 160 | D_src_loss_text, D_src_acc_text = calculate_loss_and_acc_with_logits(D_src_text, 161 | D_logits_src_text, tf.concat(1,(tf.zeros((self.batch_size,1)), tf.zeros((self.batch_size,1)), 162 | tf.ones((self.batch_size,1)))), l2_loss_src_text, D_info["l2_reg_lambda"]) 163 | D_fake_loss_text, D_fake_acc_text = calculate_loss_and_acc_with_logits(D_fake_text, 164 | D_logits_fake_text, tf.concat(1,(tf.zeros((self.batch_size,1)), tf.ones((self.batch_size,1)), 165 | tf.zeros((self.batch_size,1)))), l2_loss_fake_text, D_info["l2_reg_lambda"]) 166 | D_tgt_loss_text, D_tgt_acc_text = calculate_loss_and_acc_with_logits(D_tgt_text, 167 | D_logits_tgt_text, tf.concat(1,(tf.ones((self.batch_size,1)), tf.zeros((self.batch_size,1)), 168 | tf.zeros((self.batch_size,1)))), l2_loss_tgt_text, D_info["l2_reg_lambda"]) 169 | self.D_text_loss = D_src_loss_text + D_tgt_loss_text + D_fake_loss_text 170 | 171 | 172 | ########################## tensorboard summary:######################## 173 | # D_real_sum, D_fake_sum = the sigmoid output 174 | # D_real_loss_sum, D_fake_loss_sum = the loss for different kinds input 175 | # D_loss_sum, G_loss_sum = loss of the G&D 176 | ####################################################################### 177 | self.start_reward_sum = tf.scalar_summary("start_reward", tf.reduce_mean(self.rollout_reward[0,:])) 178 | self.total_reward_sum = tf.scalar_summary("total_mean_reward", 
tf.reduce_mean(self.rollout_reward)) 179 | self.logprobs_mean_sum = tf.scalar_summary("logprobs_mean", tf.reduce_sum(log_probs_action_picked_list)/tf.reduce_sum(self.predict_mask)) 180 | self.logprobs_dist_sum = tf.histogram_summary("log_probs", log_probs_action_picked_list) 181 | self.D_fake_loss_sum = tf.scalar_summary("D_fake_loss", D_fake_loss) 182 | self.D_wrong_loss_sum = tf.scalar_summary("D_wrong_loss", D_wrong_loss) 183 | self.D_right_loss_sum = tf.scalar_summary("D_right_loss", D_right_loss) 184 | self.D_loss_sum = tf.scalar_summary("D_loss", self.D_loss) 185 | self.G_loss_sum = tf.scalar_summary("G_loss", self.G_loss) 186 | ################################################### 187 | # Record the paramters # 188 | ################################################### 189 | params = tf.trainable_variables() 190 | self.R_params = [] 191 | self.G_params = [] 192 | self.D_params = [] 193 | self.G_params_dict = {} 194 | self.D_params_dict = {} 195 | for param in params: 196 | if "R" in param.name: 197 | self.R_params.append(param) 198 | elif "G" in param.name: 199 | self.G_params.append(param) 200 | self.G_params_dict.update({param.name:param}) 201 | elif "D" in param.name: 202 | self.D_params.append(param) 203 | self.D_params_dict.update({param.name:param}) 204 | print "Build graph complete" 205 | 206 | def rollout_update(self): 207 | for r, g in zip(self.R_params, self.G_params): 208 | assign_op = r.assign(g) 209 | self.sess.run(assign_op) 210 | def discriminator(self, batch_size, images, text, length, name="discriminator", reuse=False): 211 | 212 | ### sentence: B, S 213 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 214 | with tf.variable_scope(name): 215 | if reuse: 216 | tf.get_variable_scope().reuse_variables() 217 | with tf.variable_scope("lstm"): 218 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.D_hidden_size, state_is_tuple=True) 219 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate) 220 | with tf.device('/cpu:0'), tf.variable_scope("embedding"): 221 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.D_hidden_size], "float32", random_uniform_init) 222 | with tf.variable_scope("text_emb"): 223 | text_W = tf.get_variable("text_W", [2*self.D_hidden_size, self.D_hidden_size],"float32", random_uniform_init) 224 | text_b = tf.get_variable("text_b", [self.D_hidden_size], "float32", random_uniform_init) 225 | with tf.variable_scope("images_emb"): 226 | images_W = tf.get_variable("images_W", [self.img_dims, self.D_hidden_size],"float32", random_uniform_init) 227 | images_b = tf.get_variable("images_b", [self.D_hidden_size], "float32", random_uniform_init) 228 | with tf.variable_scope("scores_emb"): 229 | # "generator/scores" 230 | scores_W = tf.get_variable("scores_W", [self.D_hidden_size, 3], "float32", random_uniform_init) 231 | scores_b = tf.get_variable("scores_b", [3], "float32", random_uniform_init) 232 | 233 | state = lstm1.zero_state(batch_size, 'float32') 234 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[batch_size]) 235 | # VQA use states 236 | state_list = [] 237 | for j in range(self.max_words+1): 238 | if j > 0: 239 | tf.get_variable_scope().reuse_variables() 240 | with tf.device('/cpu:0'): 241 | if j ==0: 242 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 243 | else: 244 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, text[:,j-1]) 245 | with tf.variable_scope("lstm"): 246 | # "generator/lstm" 247 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # 
output: B,H 248 | # apppend state from index 1 (the start of the word) 249 | if j > 0: 250 | state_list.append(tf.concat(1,[state[0], state[1]])) 251 | 252 | state_list = tf.pack(state_list) # S,B,2H 253 | state_list = tf.transpose(state_list, [1,0,2]) # B,S,2H 254 | state_flatten = tf.reshape(state_list, [-1, 2*self.D_hidden_size]) # B*S, 2H 255 | # length-1 => index start from 0 256 | # need to prevent length = 0 257 | length_index = length-1 258 | condition = tf.greater_equal(length_index, 0) # B 259 | length_index = tf.select(condition, length_index, tf.constant(0, dtype=tf.int32, shape=[batch_size])) 260 | idx = tf.range(batch_size)*self.max_words + length_index # B 261 | state_gather = tf.gather(state_flatten, idx) # B, 2H 262 | # text embedding 263 | text_emb = tf.matmul(state_gather, text_W) + text_b # B,H 264 | text_emb = tf.nn.tanh(text_emb) 265 | # images embedding 266 | images_emb = tf.matmul(images, images_W) + images_b # B,H 267 | images_emb = tf.nn.tanh(images_emb) 268 | # embed to score 269 | logits = tf.mul(text_emb, images_emb) # B,H 270 | score = tf.matmul(logits, scores_W) + scores_b 271 | 272 | #return tf.nn.sigmoid(score), score 273 | return tf.nn.softmax(score), score 274 | 275 | 276 | def text_discriminator(self, sentence, info, name="text_discriminator", reuse=False): 277 | ### sentence: B, S 278 | hidden_size = self.D_hidden_size 279 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 280 | with tf.variable_scope(name): 281 | if reuse: 282 | tf.get_variable_scope().reuse_variables() 283 | with tf.device('/cpu:0'), tf.variable_scope("embedding"): 284 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, hidden_size], "float32", random_uniform_init) 285 | embedded_chars = tf.nn.embedding_lookup(word_emb_W, sentence) # B,S,H 286 | embedded_chars_expanded = tf.expand_dims(embedded_chars, -1) # B,S,H,1 287 | with tf.variable_scope("output"): 288 | output_W = tf.get_variable("output_W", [info["num_filters_total"], self.num_domains], 289 | "float32", random_uniform_init) 290 | output_b = tf.get_variable("output_b", [self.num_domains], "float32", random_uniform_init) 291 | # Create a convolution + maxpool layer for each filter size 292 | pooled_outputs = [] 293 | # Keeping track of l2 regularization loss (optional) 294 | l2_loss = tf.constant(0.0) 295 | for filter_size, num_filter in zip(info["filter_sizes"], info["num_filters"]): 296 | with tf.variable_scope("conv-maxpool-%s" % filter_size): 297 | # Convolution Layer 298 | filter_shape = [filter_size, hidden_size, 1, num_filter] 299 | W = tf.get_variable("W", filter_shape, "float32", random_uniform_init) 300 | b = tf.get_variable("b", [num_filter], "float32", random_uniform_init) 301 | conv = tf.nn.conv2d( 302 | embedded_chars_expanded, 303 | W, 304 | strides=[1, 1, 1, 1], 305 | padding="VALID", 306 | name="conv") 307 | # Apply nonlinearity 308 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 309 | # Maxpooling over the outputs 310 | pooled = tf.nn.max_pool( 311 | h, 312 | ksize=[1, info["sentence_length"] - filter_size + 1, 1, 1], 313 | strides=[1, 1, 1, 1], 314 | padding='VALID', 315 | name="pool") 316 | pooled_outputs.append(pooled) 317 | h_pool = tf.concat(3, pooled_outputs) # B,1,1,total filters 318 | h_pool_flat = tf.reshape(h_pool, [-1, info["num_filters_total"]]) # b, total filters 319 | 320 | # Add highway 321 | with tf.variable_scope("highway"): 322 | h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0) 323 | with tf.variable_scope("output"): 324 | l2_loss 
+= tf.nn.l2_loss(output_W) 325 | l2_loss += tf.nn.l2_loss(output_b) 326 | logits = tf.nn.xw_plus_b(h_highway, output_W, output_b, name="logits") 327 | logits_softmax = tf.nn.softmax(logits) 328 | predictions = tf.argmax(logits_softmax, 1, name="predictions") 329 | return predictions, logits_softmax, logits, l2_loss 330 | 331 | def domain_classifier(self, images, name="G", reuse=False): 332 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 333 | with tf.variable_scope(name): 334 | tf.get_variable_scope().reuse_variables() 335 | with tf.variable_scope("images"): 336 | # "generator/images" 337 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init) 338 | images_emb = tf.matmul(images, images_W) # B,H 339 | 340 | l2_loss = tf.constant(0.0) 341 | with tf.variable_scope("domain"): 342 | if reuse: 343 | tf.get_variable_scope().reuse_variables() 344 | with tf.variable_scope("output"): 345 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.num_domains], 346 | "float32", random_uniform_init) 347 | output_b = tf.get_variable("output_b", [self.num_domains], "float32", random_uniform_init) 348 | l2_loss += tf.nn.l2_loss(output_W) 349 | l2_loss += tf.nn.l2_loss(output_b) 350 | logits = tf.nn.xw_plus_b(images_emb, output_W, output_b, name="logits") 351 | predictions = tf.argmax(logits, 1, name="predictions") 352 | 353 | return predictions, logits, l2_loss 354 | 355 | 356 | def rollout(self, predict_words, state_list, name="R"): 357 | 358 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 359 | with tf.variable_scope(name): 360 | tf.get_variable_scope().reuse_variables() 361 | with tf.variable_scope("images"): 362 | # "generator/images" 363 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init) 364 | with tf.variable_scope("lstm"): 365 | # WONT BE CREATED HERE 366 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True) 367 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate) 368 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 369 | # "R/embedding" 370 | word_emb_W = tf.get_variable("word_emb_W",[self.dict_size, self.G_hidden_size], "float32", random_uniform_init) 371 | with tf.variable_scope("output"): 372 | # "R/output" 373 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init) 374 | rollout_list = [] 375 | length_mask_list = [] 376 | # rollout for the first time step 377 | for step in range(self.max_words): 378 | sample_words = predict_words[step] 379 | state = state_list[step] 380 | rollout_step_list = [] 381 | mask = tf.constant(True, "bool", [self.batch_size]) 382 | # used to calcualte the length of the rollout sentence 383 | length_mask_step = [] 384 | for j in range(step+1): 385 | mask_out_word = tf.select(mask, predict_words[j], 386 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size])) 387 | rollout_step_list.append(mask_out_word) 388 | length_mask_step.append(mask) 389 | prev_mask = mask 390 | mask_step = tf.not_equal(predict_words[j], self.END) # B 391 | mask = tf.logical_and(prev_mask, mask_step) 392 | for j in range(self.max_words-step-1): 393 | if step != 0 or j != 0: 394 | tf.get_variable_scope().reuse_variables() 395 | with tf.device("/cpu:0"): 396 | sample_words_emb = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_words)) 397 | with tf.variable_scope("lstm"): 398 | 
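                        # Rollout continuation (SeqGAN-style Monte-Carlo search): from the
                        # saved prefix state, the generator LSTM (reused here via variable
                        # scope, see name="G" at the call site) samples the remaining words
                        # of the sentence; the completed sentences are scored by both
                        # discriminators and averaged over rollout_num runs to form the
                        # reward for the word emitted at `step`.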
output, state = lstm1(sample_words_emb, state, scope=tf.get_variable_scope()) # output: B,H 399 | logits = tf.matmul(output, output_W) 400 | # add 1e-8 to prevent log(0) 401 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D 402 | sample_words = tf.squeeze(tf.multinomial(log_probs,1)) 403 | mask_out_word = tf.select(mask, sample_words, 404 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size])) 405 | rollout_step_list.append(mask_out_word) 406 | length_mask_step.append(mask) 407 | prev_mask = mask 408 | mask_step = tf.not_equal(sample_words, self.END) # B 409 | mask = tf.logical_and(prev_mask, mask_step) 410 | 411 | length_mask_step = tf.pack(length_mask_step) # S,B 412 | length_mask_step = tf.transpose(length_mask_step, [1,0]) # B,S 413 | length_mask_list.append(length_mask_step) 414 | rollout_step_list = tf.pack(rollout_step_list) # S,B 415 | rollout_step_list = tf.transpose(rollout_step_list, [1,0]) # B,S 416 | rollout_list.append(rollout_step_list) 417 | 418 | length_mask_list = tf.pack(length_mask_list) # S,B,S 419 | length_mask_list = tf.reshape(length_mask_list, [-1, self.max_words]) # S*B,S 420 | rollout_list = tf.pack(rollout_list) # S,B,S 421 | rollout_list = tf.reshape(rollout_list, [-1, self.max_words]) # S*B, S 422 | rollout_length = tf.to_int32(tf.reduce_sum(tf.to_float(length_mask_list),1)) 423 | return rollout_list, rollout_length 424 | 425 | def Teacher_Forcing(self, target_sentence, mask, name='generator', reuse=False): 426 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 427 | with tf.variable_scope(name): 428 | if reuse: 429 | tf.get_variable_scope().reuse_variables() 430 | with tf.variable_scope("images"): 431 | # "generator/images" 432 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init) 433 | with tf.variable_scope("lstm"): 434 | # "generator/lstm" 435 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True) 436 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate) 437 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 438 | # "generator/embedding" 439 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.G_hidden_size], "float32", random_uniform_init) 440 | with tf.variable_scope("output"): 441 | # "generator/output" 442 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init) 443 | 444 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size]) 445 | state = lstm1.zero_state(self.batch_size, 'float32') 446 | teacher_loss = 0. 
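            # Teacher forcing: after the start token (j == 1), the ground-truth word
            # target_sentence[:, j-2] is always fed as the next LSTM input, and the
            # loss accumulated below is the masked negative log-likelihood of the
            # gold words, normalized by the number of unmasked positions:
            #     teacher_loss = -sum_j mask[:, j-1] * log p(t_{j-1}) / sum(mask)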
447 | for j in range(self.lstm_steps): 448 | if j == 0: 449 | images_emb = tf.matmul(self.images, images_W) # B,H 450 | lstm1_in = images_emb 451 | else: 452 | tf.get_variable_scope().reuse_variables() 453 | with tf.device("/cpu:0"): 454 | if j == 1: 455 | # 456 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 457 | else: 458 | # schedule sampling 459 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, target_sentence[:,j-2]) 460 | 461 | with tf.variable_scope("lstm"): 462 | # "generator/lstm" 463 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H 464 | 465 | if j > 0: 466 | logits = tf.matmul(output, output_W) # B,D 467 | # calculate loss 468 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D 469 | action_picked = tf.range(self.batch_size)*(self.dict_size) + target_sentence[:,j-1] 470 | log_probs_action_picked = tf.mul(tf.gather(tf.reshape(log_probs, [-1]), action_picked), mask[:,j-1]) 471 | loss_t = (-1)*tf.reduce_sum(log_probs_action_picked*tf.ones(self.batch_size)) 472 | teacher_loss += loss_t 473 | 474 | teacher_loss /= tf.reduce_sum(mask) 475 | teacher_loss_sum = tf.scalar_summary("teacher_loss", teacher_loss) 476 | 477 | return teacher_loss, teacher_loss_sum 478 | 479 | def generator(self, name='generator', reuse=False): 480 | 481 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 482 | with tf.variable_scope(name): 483 | if reuse: 484 | tf.get_variable_scope().reuse_variables() 485 | with tf.variable_scope("images"): 486 | # "generator/images" 487 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init) 488 | #images_b = tf.get_variable("images_b", [self.G_hidden_size], "float32", random_uniform_init) 489 | with tf.variable_scope("lstm"): 490 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True) 491 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate) 492 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 493 | # "generator/embedding" 494 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.G_hidden_size], "float32", random_uniform_init) 495 | with tf.variable_scope("output"): 496 | # "generator/output" 497 | # dict size minus 1 => remove 498 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init) 499 | 500 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size]) 501 | state = lstm1.zero_state(self.batch_size, 'float32') 502 | mask = tf.constant(True, "bool", [self.batch_size]) 503 | log_probs_action_picked_list = [] 504 | predict_words = [] 505 | state_list = [] 506 | predict_mask_list = [] 507 | for j in range(self.max_words+1): 508 | if j == 0: 509 | #images_emb = tf.matmul(self.images, images_W) + images_b # B,H 510 | images_emb = tf.matmul(self.images, images_W) 511 | lstm1_in = images_emb 512 | else: 513 | tf.get_variable_scope().reuse_variables() 514 | with tf.device("/cpu:0"): 515 | if j == 1: 516 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 517 | else: 518 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_words)) 519 | with tf.variable_scope("lstm"): 520 | # "generator/lstm" 521 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H 522 | if j > 0: 523 | logits = tf.matmul(output, output_W) 524 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D 525 | # word drawn from the multinomial distribution 526 | sample_words = 
tf.reshape(tf.multinomial(log_probs,1), [self.batch_size]) 527 | mask_out_word = tf.select(mask, sample_words, 528 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size])) 529 | predict_words.append(mask_out_word) 530 | #predict_words.append(sample_words) 531 | # the mask should be dynamic 532 | # if the sentence is: This is a dog 533 | # the predict_mask_list is: 1,1,1,1,1,0,0,..... 534 | predict_mask_list.append(tf.to_float(mask)) 535 | action_picked = tf.range(self.batch_size)*(self.dict_size) + tf.to_int32(sample_words) # B 536 | # mask out the word beyond the 537 | log_probs_action_picked = tf.mul(tf.gather(tf.reshape(log_probs, [-1]), action_picked), tf.to_float(mask)) 538 | log_probs_action_picked_list.append(log_probs_action_picked) 539 | prev_mask = mask 540 | mask_step = tf.not_equal(sample_words, self.END) # B 541 | mask = tf.logical_and(prev_mask, mask_step) 542 | state_list.append(state) 543 | 544 | predict_mask_list = tf.pack(predict_mask_list) # S,B 545 | predict_mask_list = tf.transpose(predict_mask_list, [1,0]) # B,S 546 | log_probs_action_picked_list = tf.pack(log_probs_action_picked_list) # S,B 547 | log_probs_action_picked_list = tf.reshape(log_probs_action_picked_list, [-1]) # S*B 548 | return state_list, predict_words, log_probs_action_picked_list, None, predict_mask_list 549 | 550 | def generator_test(self, name='generator', reuse=False): 551 | 552 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 553 | with tf.variable_scope(name): 554 | if reuse: 555 | tf.get_variable_scope().reuse_variables() 556 | with tf.variable_scope("images"): 557 | # "generator/images" 558 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init) 559 | with tf.variable_scope("lstm"): 560 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True) 561 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 562 | # "generator/embedding" 563 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.G_hidden_size], "float32", random_uniform_init) 564 | with tf.variable_scope("output"): 565 | # "generator/output" 566 | # dict size minus 1 => remove 567 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init) 568 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size]) 569 | state = lstm1.zero_state(self.batch_size, 'float32') 570 | mask = tf.constant(True, "bool", [self.batch_size]) 571 | log_probs_action_picked_list = [] 572 | predict_words = [] 573 | state_list = [] 574 | predict_mask_list = [] 575 | for j in range(self.max_words+1): 576 | if j == 0: 577 | images_emb = tf.matmul(self.images, images_W) 578 | lstm1_in = images_emb 579 | else: 580 | tf.get_variable_scope().reuse_variables() 581 | with tf.device("/cpu:0"): 582 | if j == 1: 583 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 584 | else: 585 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_words)) 586 | with tf.variable_scope("lstm"): 587 | # "generator/lstm" 588 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H 589 | if j > 0: 590 | #logits = tf.matmul(output, output_W) + output_b # B,D 591 | logits = tf.matmul(output, output_W) 592 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D 593 | # word drawn from the multinomial distribution 594 | sample_words = tf.argmax(log_probs, 1) # B 595 | mask_out_word = tf.select(mask, sample_words, 596 | tf.constant(self.NOT, 
dtype=tf.int64, shape=[self.batch_size])) 597 | predict_words.append(mask_out_word) 598 | # the mask should be dynamic 599 | # if the sentence is: This is a dog 600 | # the predict_mask_list is: 1,1,1,1,1,0,0,..... 601 | predict_mask_list.append(tf.to_float(mask)) 602 | action_picked = tf.range(self.batch_size)*(self.dict_size) + tf.to_int32(sample_words) # B 603 | # mask out the word beyond the 604 | log_probs_action_picked = tf.mul(tf.gather(tf.reshape(log_probs, [-1]), action_picked), tf.to_float(mask)) 605 | log_probs_action_picked_list.append(log_probs_action_picked) 606 | prev_mask = mask 607 | mask_step = tf.not_equal(sample_words, self.END) # B 608 | mask = tf.logical_and(prev_mask, mask_step) 609 | state_list.append(state) 610 | 611 | predict_mask_list = tf.pack(predict_mask_list) # S,B 612 | predict_mask_list = tf.transpose(predict_mask_list, [1,0]) # B,S 613 | log_probs_action_picked_list = tf.pack(log_probs_action_picked_list) # S,B 614 | log_probs_action_picked_list = tf.reshape(log_probs_action_picked_list, [-1]) # S*B 615 | return state_list, predict_words, log_probs_action_picked_list, None, predict_mask_list 616 | 617 | 618 | def train(self): 619 | 620 | self.G_train_op = self.G_optim.minimize(self.G_loss, var_list=self.G_params) 621 | self.G_hat_train_op = self.T_optim.minimize(self.teacher_loss, var_list=self.G_params) 622 | self.D_train_op = self.D_optim.minimize(self.D_loss, var_list=self.D_params) 623 | self.Domain_text_train_op = self.Domain_text_optim.minimize(self.D_text_loss) 624 | log_dir = os.path.join('.', 'logs', self.model_name) 625 | if not os.path.exists(log_dir): 626 | os.makedirs(log_dir) 627 | #### Old version 628 | self.writer = tf.train.SummaryWriter(os.path.join(log_dir, "SeqGAN_sample"), self.sess.graph) 629 | self.summary_op = tf.merge_all_summaries() 630 | tf.initialize_all_variables().run() 631 | if self.load_pretrain: 632 | print "[@] Load the pretrained model" 633 | self.G_saver = tf.train.Saver(self.G_params_dict) 634 | self.G_saver.restore(self.sess, "./checkpoint/mscoco/G_pretrained/G_Pretrained-39000") 635 | 636 | self.saver = tf.train.Saver(max_to_keep=self.max_to_keep) 637 | count = 0 638 | D_count = 0 639 | G_count = 0 640 | for idx in range(self.max_iter//250): 641 | self.save(self.checkpoint_dir, count) 642 | self.evaluate(count) 643 | for _ in tqdm(range(250)): 644 | tgt_image_feature = self.dataset.flickr_sequential_sample(self.batch_size) 645 | tgt_text = self.dataset.flickr_caption_sequential_sample(self.batch_size) 646 | image_feature, right_text, _ = self.dataset.sequential_sample(self.batch_size) 647 | nonENDs = np.array(map(lambda x: (x != self.NOT).sum(), right_text)) 648 | mask_t = np.zeros([self.batch_size, self.max_words]) 649 | for ind, row in enumerate(mask_t): 650 | # mask out the 651 | row[0:nonENDs[ind]] = 1 652 | 653 | wrong_text = self.dataset.get_wrong_text(self.batch_size) 654 | right_length = np.sum((right_text!=self.NOT)+0, 1) 655 | wrong_length = np.sum((wrong_text!=self.NOT)+0, 1) 656 | for _ in range(1): # g_step 657 | # update G 658 | feed_dict = {self.images: tgt_image_feature} 659 | _, G_loss = self.sess.run([self.G_train_op, self.G_loss], feed_dict) 660 | G_count += 1 661 | for _ in range(20): # d_step 662 | # update D 663 | feed_dict = {self.images: image_feature, 664 | self.right_text:right_text, 665 | self.wrong_text:wrong_text, 666 | self.right_length:right_length, 667 | self.wrong_length:wrong_length, 668 | self.mask: mask_t, 669 | self.src_images: image_feature, 670 | self.tgt_images: 
tgt_image_feature, 671 | self.src_text: right_text, 672 | self.tgt_text: tgt_text} 673 | 674 | _, D_loss = self.sess.run([self.D_train_op, self.D_loss], feed_dict) 675 | D_count += 1 676 | _, D_text_loss = self.sess.run([self.Domain_text_train_op, self.D_text_loss], \ 677 | {self.src_text: right_text, 678 | self.tgt_text: tgt_text, 679 | self.images: tgt_image_feature 680 | }) 681 | 682 | count += 1 683 | 684 | def evaluate(self, count): 685 | 686 | samples = [] 687 | samples_index = [] 688 | image_feature, image_id, test_annotation = self.dataset.get_test_for_eval() 689 | num_samples = self.dataset.num_test_images 690 | samples_index = np.full([self.batch_size*(num_samples//self.batch_size), self.max_words], self.NOT) 691 | for i in range(num_samples//self.batch_size): 692 | image_feature_test = image_feature[i*self.batch_size:(i+1)*self.batch_size] 693 | feed_dict = {self.images: image_feature_test} 694 | predict_words = self.sess.run(self.predict_words_argmax, feed_dict) 695 | for j in range(self.batch_size): 696 | samples.append([self.dataset.decode(predict_words[j, :], type='string', remove_END=True)[0]]) 697 | sample_index = self.dataset.decode(predict_words[j, :], type='index', remove_END=False)[0] 698 | samples_index[i*self.batch_size+j][:len(sample_index)] = sample_index 699 | # predict from samples 700 | samples = np.asarray(samples) 701 | samples_index = np.asarray(samples_index) 702 | print '[%] Sentence:', samples[0] 703 | meteor_pd = {} 704 | meteor_id = [] 705 | for j in range(len(samples)): 706 | if image_id[j] == 0: 707 | break 708 | meteor_pd[str(int(image_id[j]))] = [{'image_id':str(int(image_id[j])), 'caption':samples[j][0]}] 709 | meteor_id.append(str(int(image_id[j]))) 710 | scorer = COCOEvalCap(test_annotation, meteor_pd, meteor_id) 711 | scorer.evaluate(verbose=True) 712 | sample_dir = os.path.join("./SeqGAN_samples_sample", self.model_name) 713 | if not os.path.exists(sample_dir): 714 | os.makedirs(sample_dir) 715 | file_name = "%s_%s" % (self.dataset.dataset_name, str(count)) 716 | np.savez(os.path.join(sample_dir, file_name), string=samples, index=samples_index, id=meteor_id) 717 | 718 | def save(self, checkpoint_dir, step): 719 | model_name = "SeqGAN_sample" 720 | model_dir = "%s" % (self.dataset.dataset_name) 721 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir, self.model_name) 722 | if not os.path.exists(checkpoint_dir): 723 | os.makedirs(checkpoint_dir) 724 | self.saver.save(self.sess, 725 | os.path.join(checkpoint_dir, model_name), 726 | global_step=step) 727 | 728 | def load(self, checkpoint_dir): 729 | print(" [*] Reading checkpoints...") 730 | 731 | model_dir = "%s" % (self.dataset.dataset_name) 732 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 733 | 734 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 735 | if ckpt and ckpt.model_checkpoint_path: 736 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 737 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 738 | return True 739 | else: 740 | return False 741 | -------------------------------------------------------------------------------- /show-adapt-tell/pretrain_CNN_D.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import tensorflow as tf 5 | import numpy as np 6 | from tqdm import tqdm 7 | from highway import * 8 | import pdb 9 | 10 | class D_pretrained(): 11 | def __init__(self, sess, dataset, negative_dataset, D_info, conf=None, 
l2_reg_lambda=0.2): 12 | 13 | self.sess = sess 14 | self.batch_size = conf.batch_size 15 | self.max_iter = conf.max_iter 16 | self.num_train = dataset.num_train 17 | self.hidden_size = conf.D_hidden_size # 512 18 | self.dict_size = dataset.dict_size 19 | self.max_words = dataset.max_words 20 | self.dataset = dataset 21 | self.negative_dataset = negative_dataset 22 | self.checkpoint_dir = conf.checkpoint_dir 23 | self.global_step = tf.get_variable('global_step', [],initializer=tf.constant_initializer(0), trainable=False) 24 | self.optim = tf.train.AdamOptimizer(conf.learning_rate) 25 | self.filter_sizes = D_info['filter_sizes'] 26 | self.num_filters = D_info['num_filters'] 27 | self.num_filters_total = sum(self.num_filters) 28 | self.num_classes = D_info['num_classes'] 29 | self.l2_reg_lambda = l2_reg_lambda 30 | self.START = self.dataset.word2ix[u''] 31 | self.END = self.dataset.word2ix[u''] 32 | self.UNK = self.dataset.word2ix[u''] 33 | self.NOT = self.dataset.word2ix[u''] 34 | # placeholder 35 | self.text = tf.placeholder(tf.int32, [None, self.max_words], name="text") 36 | self.label = tf.placeholder(tf.float32, [None, self.num_classes], name="label") 37 | self.images = tf.placeholder(tf.float32, [None, self.dataset.img_dims], name="images") 38 | 39 | self.loss, self.pred = self.build_Discriminator(self.images, self.text, self.label, name='D') 40 | self.loss_sum = tf.scalar_summary("loss", self.loss) 41 | 42 | params = tf.trainable_variables() 43 | self.D_params_dict = {} 44 | self.D_params_train = [] 45 | for param in params: 46 | self.D_params_dict.update({param.name:param}) 47 | if "embedding" in param.name: 48 | embedding_matrix = np.load("embedding-42000.npy") 49 | self.embedding_assign_op = param.assign(tf.Variable(embedding_matrix, trainable=False)) 50 | else: 51 | self.D_params_train.append(param) 52 | 53 | def build_Discriminator(self, images, text, label, name="discriminator", reuse=False): 54 | 55 | ### sentence: B, S 56 | hidden_size = self.hidden_size 57 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 58 | with tf.variable_scope(name): 59 | if reuse: 60 | tf.get_variable_scope().reuse_variables() 61 | with tf.device('/cpu:0'), tf.variable_scope("embedding"): 62 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, hidden_size], "float32", random_uniform_init) 63 | embedded_chars = tf.nn.embedding_lookup(word_emb_W, text) # B,S,H 64 | embedded_chars_expanded = tf.expand_dims(embedded_chars, -1) # B,S,H,1 65 | with tf.variable_scope("output"): 66 | output_W = tf.get_variable("output_W", [hidden_size, self.num_classes], 67 | "float32", random_uniform_init) 68 | output_b = tf.get_variable("output_b", [self.num_classes], "float32", random_uniform_init) 69 | with tf.variable_scope("images"): 70 | images_W = tf.get_variable("images_W", [self.dataset.img_dims, hidden_size], 71 | "float32", random_uniform_init) 72 | images_b = tf.get_variable("images_b", [hidden_size], "float32", random_uniform_init) 73 | with tf.variable_scope("text"): 74 | text_W = tf.get_variable("text_W", [self.num_filters_total, hidden_size], 75 | "float32", random_uniform_init) 76 | text_b = tf.get_variable("text_b", [hidden_size], "float32", random_uniform_init) 77 | 78 | # Create a convolution + maxpool layer for each filter size 79 | pooled_outputs = [] 80 | # Keeping track of l2 regularization loss (optional) 81 | l2_loss = tf.constant(0.0) 82 | for filter_size, num_filter in zip(self.filter_sizes, self.num_filters): 83 | with tf.variable_scope("conv-maxpool-%s" % 
filter_size): 84 | # Convolution Layer 85 | filter_shape = [filter_size, hidden_size, 1, num_filter] 86 | W = tf.get_variable("W", filter_shape, "float32", random_uniform_init) 87 | b = tf.get_variable("b", [num_filter], "float32", random_uniform_init) 88 | conv = tf.nn.conv2d( 89 | embedded_chars_expanded, 90 | W, 91 | strides=[1, 1, 1, 1], 92 | padding="VALID", 93 | name="conv") 94 | # Apply nonlinearity 95 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 96 | # Maxpooling over the outputs 97 | pooled = tf.nn.max_pool( 98 | h, 99 | ksize=[1, self.max_words - filter_size + 1, 1, 1], 100 | strides=[1, 1, 1, 1], 101 | padding='VALID', 102 | name="pool") 103 | pooled_outputs.append(pooled) 104 | h_pool = tf.concat(3, pooled_outputs) # B,1,1,total filters 105 | h_pool_flat = tf.reshape(h_pool, [-1, self.num_filters_total]) # b, total filters 106 | # Add highway 107 | with tf.variable_scope("highway"): 108 | h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0) 109 | with tf.variable_scope("text"): 110 | text_emb = tf.nn.xw_plus_b(h_highway, text_W, text_b, name="text_emb") 111 | with tf.variable_scope("images"): 112 | images_emb = tf.nn.xw_plus_b(images, images_W, images_b, name="images_emb") 113 | with tf.variable_scope("output"): 114 | fusing_vec = tf.mul(text_emb, images_emb) 115 | l2_loss += tf.nn.l2_loss(output_W) 116 | l2_loss += tf.nn.l2_loss(output_b) 117 | logits = tf.nn.xw_plus_b(fusing_vec, output_W, output_b, name="logits") 118 | ypred_for_auc = tf.nn.softmax(logits) 119 | predictions = tf.argmax(logits, 1, name="predictions") 120 | #predictions = tf.nn.sigmoid(logits, name="predictions") 121 | # Calculate Mean cross-entropy loss 122 | with tf.variable_scope("loss"): 123 | losses = tf.nn.softmax_cross_entropy_with_logits(logits, label) 124 | #losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.squeeze(logits), self.input_y) 125 | loss = tf.reduce_mean(losses) + self.l2_reg_lambda * l2_loss 126 | 127 | return loss, predictions 128 | 129 | def train(self): 130 | 131 | self.train_op = self.optim.minimize(self.loss, global_step=self.global_step, var_list=self.D_params_train) 132 | #self.train_op = self.optim.minimize(self.loss, global_step=self.global_step) 133 | self.writer = tf.train.SummaryWriter("./logs/D_CNN_pretrained_sample", self.sess.graph) 134 | tf.initialize_all_variables().run() 135 | self.saver = tf.train.Saver(var_list=self.D_params_dict, max_to_keep=30) 136 | # assign the G matrix to D pretrain 137 | self.sess.run(self.embedding_assign_op) 138 | count = 0 139 | for idx in range(self.max_iter//3000): 140 | self.save(self.checkpoint_dir, count) 141 | self.evaluate('test', count) 142 | self.evaluate('train', count) 143 | for k in tqdm(range(3000)): 144 | right_images, right_text, _ = self.dataset.sequential_sample(self.batch_size) 145 | fake_images, fake_text, _ = self.negative_dataset.sequential_sample(self.batch_size) 146 | wrong_text = self.dataset.get_wrong_text(self.batch_size) 147 | 148 | images = np.concatenate((right_images, right_images, fake_images), axis=0) 149 | text = np.concatenate((right_text, wrong_text, fake_text.astype('int32')), axis=0) 150 | label = np.zeros((text.shape[0], self.num_classes)) 151 | # right -> first entry 152 | # wrong -> second entry 153 | # fake -> third entry 154 | label[:self.batch_size, 0] = 1 155 | label[self.batch_size:2*self.batch_size, 1] = 1 156 | label[2*self.batch_size:, 2] = 1 157 | _, loss, summary_str = self.sess.run([self.train_op, self.loss, self.loss_sum],{ 158 | self.text: text.astype('int32'), 159 | 
self.images: images, 160 | self.label: label 161 | }) 162 | self.writer.add_summary(summary_str, count) 163 | count += 1 164 | 165 | def evaluate(self, split, count): 166 | 167 | if split == 'test': 168 | num_test_pair = -1 169 | elif split == 'train': 170 | num_test_pair = 5000 171 | right_images, right_text, _ = self.dataset.get_paired_data(num_test_pair, phase=split) 172 | # the true paired data we get 173 | num_test_pair = len(right_images) 174 | fake_images, fake_text, _ = self.negative_dataset.get_paired_data(num_test_pair, phase=split) 175 | random_idx = range(num_test_pair) 176 | np.random.shuffle(random_idx) 177 | wrong_text = np.squeeze(right_text[random_idx, :]) 178 | count = 0. 179 | loss_t = [] 180 | right_acc_t = [] 181 | wrong_acc_t = [] 182 | fake_acc_t = [] 183 | for i in range(num_test_pair//self.batch_size): 184 | right_images_batch = right_images[i*self.batch_size:(i+1)*self.batch_size,:] 185 | fake_images_batch = fake_images[i*self.batch_size:(i+1)*self.batch_size,:] 186 | right_text_batch = right_text[i*self.batch_size:(i+1)*self.batch_size,:] 187 | fake_text_batch = fake_text[i*self.batch_size:(i+1)*self.batch_size,:] 188 | wrong_text_batch = wrong_text[i*self.batch_size:(i+1)*self.batch_size,:] 189 | text_batch = np.concatenate((right_text_batch, wrong_text_batch, fake_text_batch.astype('int32')), axis=0) 190 | images_batch = np.concatenate((right_images_batch, right_images_batch, fake_images_batch), axis=0) 191 | label = np.zeros((text_batch.shape[0], self.num_classes)) 192 | # right -> first entry 193 | # wrong -> second entry 194 | # fake -> third entry 195 | label[:self.batch_size, 0] = 1 196 | label[self.batch_size:2*self.batch_size, 1] = 1 197 | label[2*self.batch_size:, 2] = 1 198 | feed_dict = {self.images:images_batch, self.text:text_batch, self.label:label} 199 | loss, pred, loss_str = self.sess.run([self.loss, self.pred, self.loss_sum], feed_dict) 200 | loss_t.append(loss) 201 | right_acc_t.append(np.sum((np.argmax(label[:self.batch_size],1)==pred[:self.batch_size])+0)) 202 | wrong_acc_t.append(np.sum((np.argmax(label[self.batch_size:2*self.batch_size],1)==pred[self.batch_size:2*self.batch_size])+0)) 203 | fake_acc_t.append(np.sum((np.argmax(label[2*self.batch_size:],1)==pred[2*self.batch_size:])+0)) 204 | count += self.batch_size 205 | print "Phase =", split.capitalize() 206 | print "======================= Loss =====================" 207 | print '[$] Loss =', np.mean(loss_t) 208 | print "======================= Acc ======================" 209 | print '[$] Right Pair Acc. =', sum(right_acc_t)/count 210 | print '[$] Wrong Pair Acc. =', sum(wrong_acc_t)/count 211 | print '[$] Fake Pair Acc. 
=', sum(fake_acc_t)/count 212 | 213 | def save(self, checkpoint_dir, step): 214 | model_name = "D_Pretrained" 215 | model_dir = "%s" % (self.dataset.dataset_name) 216 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir, "D_CNN_pretrained_sample") 217 | if not os.path.exists(checkpoint_dir): 218 | os.makedirs(checkpoint_dir) 219 | self.saver.save(self.sess, 220 | os.path.join(checkpoint_dir, model_name), 221 | global_step=step) 222 | 223 | def load(self, checkpoint_dir): 224 | print(" [*] Reading checkpoints...") 225 | 226 | model_dir = "%s" % (self.dataset.dataset_name) 227 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 228 | 229 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 230 | if ckpt and ckpt.model_checkpoint_path: 231 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 232 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 233 | return True 234 | else: 235 | return False 236 | 237 | -------------------------------------------------------------------------------- /show-adapt-tell/pretrain_G.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import tensorflow as tf 5 | import numpy as np 6 | from tqdm import tqdm 7 | from coco_spice.pycocoevalcap.eval import COCOEvalCap 8 | import pdb 9 | 10 | class G_pretrained(): 11 | def __init__(self, sess, dataset, conf=None): 12 | self.sess = sess 13 | self.batch_size = conf.batch_size 14 | self.max_iter = conf.max_iter 15 | self.num_train = dataset.num_train 16 | self.hidden_size = conf.G_hidden_size # 512 17 | self.dict_size = dataset.dict_size 18 | self.max_words = dataset.max_words 19 | self.dataset = dataset 20 | self.load_ckpt = conf.load_ckpt 21 | self.is_train = conf.is_train 22 | if self.is_train: 23 | self.drop_out_rate = conf.drop_out_rate 24 | else: 25 | self.drop_out_rate = 0 26 | 27 | self.init_lr = conf.init_lr 28 | self.lr_decay = conf.lr_decay 29 | self.lr_decay_every = conf.lr_decay_every 30 | self.ss_ascent = conf.ss_ascent 31 | self.ss_ascent_every = conf.ss_ascent_every 32 | self.ss_max = conf.ss_max 33 | # train pretrained model -> no need to add START_TOKEN 34 | # -> need to add END_TOKEN 35 | self.img_dims = self.dataset.img_dims 36 | self.lstm_steps = self.max_words+1 37 | self.global_step = tf.get_variable('global_step', [],initializer=tf.constant_initializer(0), trainable=False) 38 | #self.optim = tf.train.AdamOptimizer(conf.learning_rate) 39 | self.checkpoint_dir = conf.checkpoint_dir 40 | self.START = self.dataset.word2ix[u''] 41 | self.END = self.dataset.word2ix[u''] 42 | self.UNK = self.dataset.word2ix[u''] 43 | self.NOT = self.dataset.word2ix[u''] 44 | 45 | self.coins = tf.placeholder('bool', [self.batch_size, self.max_words-1]) 46 | self.images_one = tf.placeholder('float32', [100, self.img_dims]) 47 | self.images = tf.placeholder('float32', [self.batch_size, self.img_dims]) 48 | self.target_sentence = tf.placeholder('int32', [self.batch_size, self.max_words]) 49 | self.mask = tf.placeholder('float32', [self.batch_size, self.max_words]) # mask out the loss 50 | self.build_Generator(name='G') 51 | self._predict_words_argmax = [] 52 | self._predict_words_sample = [] 53 | self._predict_words_argmax = self.build_Generator_test(100, self._predict_words_argmax, type='max', name='G') 54 | self._predict_words_sample = self.build_Generator_test(100, self._predict_words_sample, type='sample', name='G') 55 | 56 | self.lr = tf.Variable(self.init_lr, trainable=False) 57 
| self.optim = tf.train.AdamOptimizer(self.lr) 58 | 59 | params = tf.trainable_variables() 60 | self.G_params_dict = {} 61 | for param in params: 62 | self.G_params_dict.update({param.name:param}) 63 | 64 | def build_Generator_test(self, batch_size=100, predict_words=None, type='max', name='generator'): 65 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 66 | with tf.variable_scope(name): 67 | tf.get_variable_scope().reuse_variables() 68 | with tf.variable_scope("images"): 69 | # "generator/images" 70 | images_W = tf.get_variable("images_W", [self.img_dims, self.hidden_size], "float32", random_uniform_init) 71 | with tf.variable_scope("lstm"): 72 | # WONT BE CREATED HERE 73 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_size, state_is_tuple=True) 74 | # lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate) 75 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 76 | # "generator/embedding" 77 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.hidden_size], "float32", random_uniform_init) 78 | with tf.variable_scope("output"): 79 | # "generator/output" 80 | output_W = tf.get_variable("output_W", [self.hidden_size, self.dict_size], "float32", random_uniform_init) 81 | 82 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[batch_size]) 83 | state = lstm1.zero_state(batch_size, 'float32') 84 | for j in range(self.lstm_steps): 85 | tf.get_variable_scope().reuse_variables() 86 | if j == 0: 87 | images_emb = tf.matmul(self.images_one, images_W) # B,H 88 | lstm1_in = images_emb 89 | elif j == 1: 90 | with tf.device("/cpu:0"): 91 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 92 | else: 93 | with tf.device("/cpu:0"): 94 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, sample_words) 95 | with tf.variable_scope("lstm"): 96 | # "generator/lstm" 97 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H 98 | if j > 0: 99 | logits = tf.matmul(output, output_W) # B,D 100 | #log_probs = tf.log(tf.nn.softmax(logits)) # B,D 101 | # word drawn from the multinomial distribution 102 | #sample_words = tf.reshape(tf.multinomial(log_probs,1), [batch_size]) 103 | sample_words = tf.argmax(logits, 1) 104 | predict_words.append(sample_words) 105 | 106 | predict_words = tf.pack(predict_words) 107 | predict_words = tf.transpose(predict_words, [1,0]) 108 | return predict_words 109 | 110 | def build_Generator(self, name='generator'): 111 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) 112 | with tf.variable_scope(name): 113 | with tf.variable_scope("images"): 114 | # "generator/images" 115 | images_W = tf.get_variable("images_W", [self.img_dims, self.hidden_size], "float32", random_uniform_init) 116 | with tf.variable_scope("lstm"): 117 | # "generator/lstm" 118 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_size, state_is_tuple=True) 119 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate) 120 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 121 | # "generator/embedding" 122 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.hidden_size], "float32", random_uniform_init) 123 | with tf.variable_scope("output"): 124 | # "generator/output" 125 | output_W = tf.get_variable("output_W", [self.hidden_size, self.dict_size], "float32", random_uniform_init) 126 | 127 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size]) 128 | state = lstm1.zero_state(self.batch_size, 'float32') 129 | 
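            # Scheduled sampling (as in Bengio et al., 2015): during XE pretraining,
            # self.coins[:, j-2] decides per example whether step j is fed the gold
            # word or the model's own argmax prediction from the previous step (see
            # the tf.select call below).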
self.pretrained_loss = 0. 130 | for j in range(self.lstm_steps): 131 | if j == 0: 132 | images_emb = tf.matmul(self.images, images_W) # B,H 133 | lstm1_in = images_emb 134 | else: 135 | tf.get_variable_scope().reuse_variables() 136 | with tf.device("/cpu:0"): 137 | if j == 1: 138 | # 139 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 140 | else: 141 | # schedule sampling 142 | word = tf.select(self.coins[:,j-2], self.target_sentence[:,j-2], tf.stop_gradient(word_predict)) 143 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, word) 144 | 145 | with tf.variable_scope("lstm"): 146 | # "generator/lstm" 147 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H 148 | 149 | if j > 0: 150 | logits = tf.matmul(output, output_W) # B,D 151 | # calculate loss 152 | pretrained_loss_t = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, self.target_sentence[:,j-1]) 153 | pretrained_loss_t = tf.reduce_sum(tf.mul(pretrained_loss_t, self.mask[:,j-1])) 154 | self.pretrained_loss += pretrained_loss_t 155 | word_predict = tf.to_int32(tf.argmax(logits, 1)) # B 156 | 157 | 158 | self.pretrained_loss /= tf.reduce_sum(self.mask) 159 | self.pretrained_loss_sum = tf.scalar_summary("pretrained_loss", self.pretrained_loss) 160 | 161 | def train(self): 162 | ''' 163 | Train a caption generator with XE 164 | with learning rate decay and schedule sampling 165 | ''' 166 | 167 | self.train_op = self.optim.minimize(self.pretrained_loss, global_step=self.global_step) 168 | self.writer = tf.train.SummaryWriter("./logs/G_pretrained", self.sess.graph) 169 | tf.initialize_all_variables().run() 170 | self.saver = tf.train.Saver(var_list=self.G_params_dict, max_to_keep=30) 171 | try: 172 | self.saver.restore(self.sess, self.load_ckpt) 173 | print "[#] Restore", self.load_ckpt 174 | except: 175 | print "[#] Fail to restore" 176 | 177 | self.current_lr = self.init_lr 178 | self.current_ss = 0. 179 | self.tr_count = 0 180 | for idx in range(self.max_iter//3000): 181 | print "Evaluate source test set..." 182 | self.evaluate('test', self.tr_count) 183 | print "Evaluate target test set..." 
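            # Each outer iteration runs 3000 XE updates; `mask` zeros out the loss
            # beyond each caption's real words (nonENDs), and `coins` carries the
            # scheduled-sampling coin flips. A vectorized way to draw the same coins
            # (an illustrative sketch, not part of the original file):
            #     coins = np.random.rand(self.batch_size, self.max_words-1) >= self.current_ss
            #     coins[:, 0] = True  # the gold first word is always fed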
184 |             self.evaluate('target_test', self.tr_count)
185 |             self.evaluate('train', self.tr_count, eval_algo='max')
186 |             self.evaluate('train', self.tr_count, eval_algo='sample')
187 |             self.save(self.checkpoint_dir, self.tr_count)
188 |             for k in tqdm(range(3000)):
189 |                 tgt_text = self.dataset.flickr_caption_sequential_sample(self.batch_size)
190 |                 image_feature, target, img_idx = self.dataset.sequential_sample(self.batch_size)
191 |                 # dummy_feature = np.zeros(image_feature.shape)
192 |                 nonENDs = np.array(map(lambda x: (x != self.NOT).sum(), target))
193 |                 mask = np.zeros([self.batch_size, self.max_words])
194 |                 tgt_mask = np.zeros([self.batch_size, self.max_words])
195 |                 for ind, row in enumerate(mask):
196 |                     # mask out the
197 |                     row[0:nonENDs[ind]] = 1
198 | 
199 |                 for ind, row in enumerate(tgt_mask):
200 |                     row[0:nonENDs[ind]] = 1
201 |                 # schedule sampling condition
202 |                 coins = np.zeros([self.batch_size, self.max_words-1])
203 |                 for (x,y), value in np.ndenumerate(coins):
204 |                     if y==0:
205 |                         coins[x][y] = True
206 |                     elif np.random.rand() < self.current_ss:
207 |                         coins[x][y] = False
208 |                     else:
209 |                         coins[x][y] = True
210 | 
211 | 
212 |                 _, loss, summary_str = self.sess.run([self.train_op, self.pretrained_loss, self.pretrained_loss_sum],{
213 |                     self.images: image_feature,
214 |                     self.target_sentence: target,
215 |                     self.mask: mask,
216 |                     self.coins: coins
217 |                 })
218 |                 # _, dummy_loss, _ = self.sess.run([self.train_op, self.pretrained_loss, self.pretrained_loss_sum],{
219 |                 #     self.images: dummy_feature,
220 |                 #     self.target_sentence: tgt_text,
221 |                 #     self.mask: tgt_mask,
222 |                 #     self.coins: coins
223 |                 # })
224 | 
225 |                 self.writer.add_summary(summary_str, self.tr_count)
226 |                 self.tr_count += 1
227 | 
228 |                 #if k%1000 == 0:
229 |                 #    print " [*] Iter {}, lr={}, ss={}, loss={}".format(self.tr_count, self.current_lr, self.current_ss, loss)
230 | 
231 |                 if idx == 0 and k != 0 and k%1000 == 0:
232 |                     self.evaluate('train', self.tr_count, eval_algo='max')
233 |                     self.evaluate('train', self.tr_count, eval_algo='sample')
234 |                     self.evaluate('test', self.tr_count)
235 |                     self.evaluate('target_test', self.tr_count)
236 |                 # schedule sampling
237 |                 if (self.tr_count+1)%self.ss_ascent_every == 0 and self.current_ss < self.ss_max:
238 |                     self.current_ss += self.ss_ascent
--------------------------------------------------------------------------------
/show-adapt-tell/pretrain_LSTM_D.py:
--------------------------------------------------------------------------------
27 |         self.START = self.dataset.word2ix[u'']
28 |         self.END = self.dataset.word2ix[u'']
29 |         self.UNK = self.dataset.word2ix[u'']
30 |         self.NOT = self.dataset.word2ix[u'']
31 | 
32 |         self.global_step = tf.get_variable('global_step', [],initializer=tf.constant_initializer(0), trainable=False)
33 |         self.optim = tf.train.AdamOptimizer(conf.learning_rate)
34 | 
35 |         # placeholder
36 |         self.fake_images = tf.placeholder(tf.float32, [self.batch_size, self.img_dims], name="fake_images")
37 |         self.wrong_images = tf.placeholder(tf.float32, [self.batch_size, self.img_dims], name="wrong_images")
38 |         self.right_images = tf.placeholder(tf.float32, [self.batch_size, self.img_dims], name="right_images")
39 | 
40 |         self.fake_text = tf.placeholder(tf.int32, [self.batch_size, self.max_words], name="fake_text")
41 |         self.wrong_text = tf.placeholder(tf.int32, [self.batch_size, self.max_words], name="wrong_text")
42 |         self.right_text = tf.placeholder(tf.int32, [self.batch_size, self.max_words], name="right_text")
43 | 
44 |         self.fake_length = tf.placeholder(tf.int32, [self.batch_size], name="fake_length")
45 |         self.wrong_length = tf.placeholder(tf.int32, [self.batch_size], name="wrong_length")
46 |         self.right_length = tf.placeholder(tf.int32, [self.batch_size], name="right_length")
47 | 
48 |         # build graph
49 |         self.D_fake, D_fake_logits = 
48 |         # build graph
49 |         self.D_fake, D_fake_logits = self.build_Discriminator(self.fake_images, self.fake_text, self.fake_length,
50 |                                                                name="D", reuse=False)
51 |         self.D_wrong, D_wrong_logits = self.build_Discriminator(self.wrong_images, self.wrong_text, self.wrong_length,
52 |                                                                 name="D", reuse=True)
53 |         self.D_right, D_right_logits = self.build_Discriminator(self.right_images, self.right_text, self.right_length,
54 |                                                                 name="D", reuse=True)
55 |         # loss: real pairs are pushed toward 1; mismatched and generated pairs toward 0
56 |         self.D_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logits, tf.zeros_like(self.D_fake)))
57 |         self.D_wrong_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_wrong_logits, tf.zeros_like(self.D_wrong)))
58 |         self.D_right_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_right_logits, tf.ones_like(self.D_right)))
59 |         self.loss = self.D_fake_loss + self.D_wrong_loss + self.D_right_loss
60 |         # Summary
61 |         self.D_fake_loss_sum = tf.scalar_summary("fake_loss", self.D_fake_loss)
62 |         self.D_wrong_loss_sum = tf.scalar_summary("wrong_loss", self.D_wrong_loss)
63 |         self.D_right_loss_sum = tf.scalar_summary("right_loss", self.D_right_loss)
64 |         self.loss_sum = tf.scalar_summary("train_loss", self.loss)
65 | 
66 |         self.D_params_dict = {}
67 |         params = tf.trainable_variables()
68 |         for param in params:
69 |             self.D_params_dict.update({param.name: param})
70 | 
71 |     def build_Discriminator(self, images, text, length, name="discriminator", reuse=False):
72 | 
73 |         ### sentence: B, S
74 |         random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
75 |         with tf.variable_scope(name):
76 |             if reuse:
77 |                 tf.get_variable_scope().reuse_variables()
78 |             with tf.variable_scope("lstm"):
79 |                 lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_size, state_is_tuple=True)
80 |             with tf.device('/cpu:0'), tf.variable_scope("embedding"):
81 |                 word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.hidden_size], "float32", random_uniform_init)
82 |             with tf.variable_scope("text_emb"):
83 |                 text_W = tf.get_variable("text_W", [2*self.hidden_size, self.hidden_size], "float32", random_uniform_init)
84 |                 text_b = tf.get_variable("text_b", [self.hidden_size], "float32", random_uniform_init)
85 |             with tf.variable_scope("images_emb"):
86 |                 images_W = tf.get_variable("images_W", [self.img_dims, self.hidden_size], "float32", random_uniform_init)
87 |                 images_b = tf.get_variable("images_b", [self.hidden_size], "float32", random_uniform_init)
88 |             with tf.variable_scope("scores_emb"):
89 |                 # variable scope: "D/scores_emb"
90 |                 scores_W = tf.get_variable("scores_W", [self.hidden_size, 1], "float32", random_uniform_init)
91 |                 scores_b = tf.get_variable("scores_b", [1], "float32", random_uniform_init)
92 | 
93 |             state = lstm1.zero_state(self.batch_size, 'float32')
94 |             start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size])
95 |             # as in VQA models, collect the LSTM states as the sentence representation
96 |             state_list = []
97 |             for j in range(self.lstm_steps):
98 |                 if j > 0:
99 |                     tf.get_variable_scope().reuse_variables()
100 |                 with tf.device('/cpu:0'):
101 |                     if j == 0:
102 |                         lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
103 |                     else:
104 |                         lstm1_in = tf.nn.embedding_lookup(word_emb_W, text[:,j-1])
105 |                 with tf.variable_scope("lstm"):
106 |                     # variable scope: "D/lstm"
107 |                     output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope())    # output: B,H
108 |                 # append states from step 1 on (the steps at which real words are fed)
109 |                 if j > 0:
110 |                     state_list.append(tf.concat(1, [state[0], state[1]]))
111 | 
112 |             state_list = tf.pack(state_list)                  # S,B,2H
113 |             state_list = tf.transpose(state_list, [1,0,2])    # B,S,2H
114 |             state_flatten = tf.reshape(state_list, [-1, 2*self.hidden_size])    # B*S, 2H
115 |             # length-1: convert 1-based lengths into 0-based step indices
116 |             idx = tf.range(self.batch_size)*self.max_words + (length-1)    # B
117 |             state_gather = tf.gather(state_flatten, idx)      # B, 2H
118 | 
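A standalone NumPy sketch (illustration only) of the flatten-and-gather trick used above to pick each sequence's last valid LSTM state: flatten the (B, S, 2H) tensor to (B*S, 2H), then take row b*S + (length_b - 1) for each batch element b. The shapes and lengths are made up.

```python
import numpy as np

B, S, H2 = 3, 5, 4                      # batch size, max steps, state size (2H)
states = np.arange(B * S * H2).reshape(B, S, H2).astype(np.float32)
lengths = np.array([2, 5, 3])           # valid lengths per example

flat = states.reshape(B * S, H2)        # B*S, 2H
idx = np.arange(B) * S + (lengths - 1)  # row index of each last valid step
last_states = flat[idx]                 # B, 2H

# the same result, gathered directly for comparison
assert np.array_equal(last_states, states[np.arange(B), lengths - 1])
```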
119 |             # text embedding
120 |             text_emb = tf.matmul(state_gather, text_W) + text_b    # B,H
121 |             text_emb = tf.nn.tanh(text_emb)
122 |             # images embedding
123 |             images_emb = tf.matmul(images, images_W) + images_b    # B,H
124 |             images_emb = tf.nn.tanh(images_emb)
125 |             # fuse the two embeddings elementwise and map them to a scalar score
126 |             logits = tf.mul(text_emb, images_emb)    # B,H
127 |             score = tf.matmul(logits, scores_W) + scores_b
128 | 
129 |             return tf.nn.sigmoid(score), score
130 | 
131 |     def train(self):
132 | 
133 |         self.train_op = self.optim.minimize(self.loss, global_step=self.global_step)
134 |         self.writer = tf.train.SummaryWriter("./logs/D_pretrained", self.sess.graph)
135 |         self.summary_op = tf.merge_all_summaries()
136 |         tf.initialize_all_variables().run()
137 |         self.saver = tf.train.Saver(var_list=self.D_params_dict, max_to_keep=self.max_to_keep)
138 |         count = 0
139 |         for idx in range(self.max_iter//3000):
140 |             self.save(self.checkpoint_dir, count)
141 |             self.evaluate('test', count)
142 |             self.evaluate('train', count)
143 |             for k in tqdm(range(3000)):
144 |                 right_images, right_text, _ = self.dataset.sequential_sample(self.batch_size)
145 |                 right_length = np.sum((right_text != self.NOT) + 0, 1)    # caption length = number of non-padding tokens
146 |                 fake_images, fake_text, _ = self.negative_dataset.sequential_sample(self.batch_size)
147 |                 fake_length = np.sum((fake_text != self.NOT) + 0, 1)
148 |                 wrong_text = self.dataset.get_wrong_text(self.batch_size)
149 |                 wrong_length = np.sum((wrong_text != self.NOT) + 0, 1)
150 |                 feed_dict = {self.right_images: right_images, self.right_text: right_text, self.right_length: right_length,
151 |                              self.fake_images: fake_images, self.fake_text: fake_text, self.fake_length: fake_length,
152 |                              self.wrong_images: right_images, self.wrong_text: wrong_text, self.wrong_length: wrong_length}    # "wrong" pairs reuse the real images with mismatched text
153 |                 _, loss, summary_str = self.sess.run([self.train_op, self.loss, self.summary_op], feed_dict)
154 |                 self.writer.add_summary(summary_str, count)
155 |                 count += 1
156 | 
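A small NumPy sketch (not repository code) of how caption lengths are derived in the loop above: count the tokens that differ from the padding id. The padding id 0 and the token matrix are hypothetical.

```python
import numpy as np

NOT = 0                                  # hypothetical <NOT> padding id
text = np.array([[5, 9, 2, 0, 0],
                 [7, 1, 4, 3, 0]])
lengths = np.sum((text != NOT) + 0, 1)   # -> [3, 4]
print lengths
```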
157 |     def evaluate(self, split, count):
158 | 
159 |         if split == 'test':
160 |             num_test_pair = -1    # -1: use every available test pair
161 |         elif split == 'train':
162 |             num_test_pair = 5000
163 |         right_images, right_text, _ = self.dataset.get_paired_data(num_test_pair, phase=split)
164 |         # the number of paired examples actually returned
165 |         num_test_pair = len(right_images)
166 |         fake_images, fake_text, _ = self.negative_dataset.get_paired_data(num_test_pair, phase=split)
167 |         random_idx = range(num_test_pair)
168 |         np.random.shuffle(random_idx)
169 |         wrong_text = np.squeeze(right_text[random_idx, :])    # shuffled captions -> mismatched pairs
170 |         D_right_loss_t = []
171 |         D_fake_loss_t = []
172 |         D_wrong_loss_t = []
173 |         D_right_acc_t = []
174 |         D_fake_acc_t = []
175 |         D_wrong_acc_t = []
176 |         count = 0.
177 |         for i in range(num_test_pair//self.batch_size):
178 |             right_images_batch = right_images[i*self.batch_size:(i+1)*self.batch_size, :]
179 |             fake_images_batch = fake_images[i*self.batch_size:(i+1)*self.batch_size, :]
180 |             right_text_batch = right_text[i*self.batch_size:(i+1)*self.batch_size, :]
181 |             fake_text_batch = fake_text[i*self.batch_size:(i+1)*self.batch_size, :]
182 |             wrong_text_batch = wrong_text[i*self.batch_size:(i+1)*self.batch_size, :]
183 |             right_length_batch = np.sum((right_text_batch != self.NOT) + 0, 1)
184 |             fake_length_batch = np.sum((fake_text_batch != self.NOT) + 0, 1)
185 |             wrong_length_batch = np.sum((wrong_text_batch != self.NOT) + 0, 1)
186 |             feed_dict = {self.right_images: right_images_batch, self.right_text: right_text_batch,
187 |                          self.right_length: right_length_batch, self.fake_images: fake_images_batch,
188 |                          self.fake_text: fake_text_batch, self.fake_length: fake_length_batch,
189 |                          self.wrong_images: right_images_batch, self.wrong_text: wrong_text_batch,
190 |                          self.wrong_length: wrong_length_batch}
191 |             D_right, D_fake, D_wrong, D_right_loss, D_fake_loss, D_wrong_loss = self.sess.run([self.D_right, self.D_fake,
192 |                 self.D_wrong, self.D_right_loss, self.D_fake_loss, self.D_wrong_loss], feed_dict)
193 |             D_right_loss_t.append(D_right_loss)
194 |             D_fake_loss_t.append(D_fake_loss)
195 |             D_wrong_loss_t.append(D_wrong_loss)
196 |             D_right_acc_t.append(np.sum((D_right > 0.5) + 0))
197 |             D_fake_acc_t.append(np.sum((D_fake < 0.5) + 0))
198 |             D_wrong_acc_t.append(np.sum((D_wrong < 0.5) + 0))
199 |             count += self.batch_size
200 | 
201 |         print "Phase =", split.capitalize()
202 |         print "======================= Loss ====================="
203 |         print '[$] Right Pair Loss =', sum(D_right_loss_t)/count
204 |         print '[$] Wrong Pair Loss =', sum(D_wrong_loss_t)/count
205 |         print '[$] Fake Pair Loss =', sum(D_fake_loss_t)/count
206 |         print "======================= Acc ======================"
207 |         print '[$] Right Pair Acc. =', sum(D_right_acc_t)/count
208 |         print '[$] Wrong Pair Acc. =', sum(D_wrong_acc_t)/count
209 |         print '[$] Fake Pair Acc.  =', sum(D_fake_acc_t)/count
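A tiny NumPy sketch (illustration only) of the 0.5-threshold accuracies printed above: a "right" pair counts as correct when its score exceeds 0.5, while "wrong" and "fake" pairs count as correct when it falls below. The score vectors are made up.

```python
import numpy as np

D_right = np.array([0.9, 0.4, 0.7])   # hypothetical scores for real pairs
D_fake = np.array([0.2, 0.6, 0.1])    # hypothetical scores for generated pairs

right_acc = np.sum((D_right > 0.5) + 0) / float(D_right.size)   # 2/3
fake_acc = np.sum((D_fake < 0.5) + 0) / float(D_fake.size)      # 2/3
print right_acc, fake_acc
```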
210 | 
211 |     def save(self, checkpoint_dir, step):
212 |         model_name = "D_Pretrained"
213 |         model_dir = "%s" % (self.dataset.dataset_name)
214 |         checkpoint_dir = os.path.join(checkpoint_dir, model_dir, "D_pretrained")
215 |         if not os.path.exists(checkpoint_dir):
216 |             os.makedirs(checkpoint_dir)
217 |         self.saver.save(self.sess,
218 |                         os.path.join(checkpoint_dir, model_name),
219 |                         global_step=step)
220 | 
221 |     def load(self, checkpoint_dir):
222 |         print(" [*] Reading checkpoints...")
223 | 
224 |         model_dir = "%s" % (self.dataset.dataset_name)
225 |         checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
226 | 
227 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
228 |         if ckpt and ckpt.model_checkpoint_path:
229 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
230 |             self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
231 |             return True
232 |         else:
233 |             return False
234 | 
--------------------------------------------------------------------------------
/show-adapt-tell/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import time
3 | import json
4 | import h5py
5 | from functools import reduce
6 | from tensorflow.contrib.layers.python.layers import initializers
7 | import cPickle
8 | import numpy as np
9 | 
10 | 
11 | def load_h5(file):
12 |     # load every dataset in an HDF5 file into a dict of numpy arrays
13 |     train_data = {}
14 |     with h5py.File(file, 'r') as hf:
15 |         for k in hf.keys():
16 |             tem = hf.get(k)
17 |             train_data[k] = np.array(tem)
18 |     return train_data
19 | 
20 | def load_json(file):
21 |     fo = open(file, 'rb')
22 |     dict = json.load(fo)
23 |     fo.close()
24 |     return dict
25 | 
26 | def unpickle(file):
27 |     fo = open(file, 'rb')
28 |     dict = cPickle.load(fo)
29 |     fo.close()
30 |     return dict
31 | 
32 | def load_h5py(file, key=None):
33 |     if key is not None:
34 |         with h5py.File(file, 'r') as hf:
35 |             data = hf.get(key)
36 |             return np.asarray(data)
37 |     else:
38 |         print '[-] Cannot load file: no key was given'
--------------------------------------------------------------------------------
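For reference, a hypothetical usage sketch of the helpers in utils.py; the file paths below are placeholders, not files shipped with the repository.

```python
from utils import load_h5, load_json, load_h5py

train_data = load_h5('data/train_data.h5')          # dict of numpy arrays, one per HDF5 dataset
split = load_json('data/K_split.json')              # parsed JSON dict
features = load_h5py('data/feats.h5', key='feats')  # a single HDF5 dataset as an ndarray
```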