├── .gitmodules
├── LICENSE
├── README.md
├── data-prepro
│   ├── CUB200_preprocess
│   │   ├── CUB_preprocess_token.py
│   │   ├── ECCV16_explanations_splits
│   │   │   ├── test.txt
│   │   │   ├── train_noCub.txt
│   │   │   └── val.txt
│   │   ├── dictionary_5.npz
│   │   ├── download_cub.sh
│   │   ├── get_split.py
│   │   └── prepro_cub_annotation.py
│   └── MSCOCO_preprocess
│       ├── K_cleaned_words.npz
│       ├── K_split.json
│       ├── dictionary_5.npz
│       ├── download_mscoco.sh
│       ├── extract_resnet_coco.py
│       ├── prepro_coco_annotation.py
│       ├── prepro_mscoco_caption.sh
│       ├── preprocess_entity.py
│       ├── preprocess_token.py
│       └── resnet_model
│           └── ResNet_mean.npy
├── images
│   ├── im11063.jpg
│   ├── im22197.jpg
│   ├── im270.jpg
│   ├── im6795.jpg
│   └── teaser.png
└── show-adapt-tell
    ├── cub
    ├── data
    ├── data_loader.py
    ├── highway.py
    ├── main.py
    ├── model.py
    ├── pretrain_CNN_D.py
    ├── pretrain_G.py
    ├── pretrain_LSTM_D.py
    └── utils.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data-prepro/MSCOCO_preprocess/neuraltalk2"]
2 | path = data-prepro/MSCOCO_preprocess/neuraltalk2
3 | url = git@github.com:karpathy/neuraltalk2.git
4 | [submodule "data-prepro/MSCOCO_preprocess/deep-residual-networks"]
5 | path = data-prepro/MSCOCO_preprocess/deep-residual-networks
6 | url = git@github.com:KaimingHe/deep-residual-networks.git
7 | [submodule "show-adapt-tell/coco-caption"]
8 | path = show-adapt-tell/coco-caption
9 | url = git@github.com:peteanderson80/coco-caption.git
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Paul Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # show-adapt-and-tell
2 |
3 | This is the official code for the paper
4 |
5 | **[Show, Adapt and Tell: Adversarial Training of Cross-domain Image Captioner](https://arxiv.org/pdf/1705.00930.pdf)**
6 |
7 | [Tseng-Hung Chen](https://tsenghungchen.github.io/),
8 | [Yuan-Hong Liao](https://andrewliao11.github.io/),
9 | [Ching-Yao Chuang](http://jameschuanggg.github.io/),
10 | [Wan-Ting Hsu](https://hsuwanting.github.io/),
11 | [Jianlong Fu](https://www.microsoft.com/en-us/research/people/jianf/),
12 | [Min Sun](http://aliensunmin.github.io/)
13 |
14 | To appear in [ICCV 2017](http://iccv2017.thecvf.com/)
15 |
16 |
17 | ![teaser](images/teaser.png)
18 |
19 |
20 |
21 | In this repository we provide:
22 |
23 | - The cross-domain captioning models [used in the paper](#models-from-the-paper)
24 | - Script for [preprocessing MSCOCO data](#mscoco-captioning-dataset)
25 | - Script for [preprocessing CUB-200-2011 captions](#cub-200-2011-with-descriptions)
26 | - Code for [training the cross-domain captioning models](#training)
27 |
28 |
29 | If you find this code useful for your research, please cite
30 |
31 | ```
32 | @article{chen2017show,
33 |   title={Show, Adapt and Tell: Adversarial Training of Cross-domain Image Captioner},
34 |   author={Chen, Tseng-Hung and Liao, Yuan-Hong and Chuang, Ching-Yao and Hsu, Wan-Ting and Fu, Jianlong and Sun, Min},
35 |   journal={arXiv preprint arXiv:1705.00930},
36 |   year={2017}
37 | }
38 | ```
39 |
40 | ## Requirements
41 |
42 | - Python 2.7
43 | - [TensorFlow 0.12.1](https://www.tensorflow.org/versions/r0.12/get_started/os_setup)
44 | - [Caffe](https://github.com/BVLC/caffe)
45 | - OpenCV 2.4.9
46 |
47 | P.S. Please clone the repository with the `--recursive` flag:
48 |
49 | ```Shell
50 | # Make sure to clone with --recursive
51 | git clone --recursive https://github.com/tsenghungchen/show-adapt-and-tell.git
52 | ```
53 |
54 | ## Data Preprocessing
55 |
56 | ### MSCOCO Captioning dataset
57 |
58 | #### Feature Extraction
59 | 1. Download the pretrained [ResNet-101 model](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777) and place it under `data-prepro/MSCOCO_preprocess/resnet_model/`.
60 | 2. Please modify the Caffe path in `data-prepro/MSCOCO_preprocess/extract_resnet_coco.py`.
61 | 3. Go to `data-prepro/MSCOCO_preprocess` and run the following script:
62 | `./download_mscoco.sh` for downloading images and extracting features.
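
For reference, the feature-extraction step inside `download_mscoco.sh` boils down to this call (paths as laid out in this repo):

```Shell
python extract_resnet_coco.py \
    --def deep-residual-networks/prototxt/ResNet-101-deploy.prototxt \
    --net resnet_model/ResNet-101-model.caffemodel --gpu 0
```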
63 |
64 | #### Captions Tokenization
65 | 1. Clone the [NeuralTalk2](https://github.com/karpathy/neuraltalk2/tree/bd8c9d879f957e1218a8f9e1f9b663ac70375866) repository, head over to the `coco/` folder, and run the IPython notebook to generate `coco_raw.json`, the JSON file for the Karpathy split.
66 | 2. Run the following script:
67 | `./prepro_mscoco_caption.sh` for downloading and tokenizing captions.
68 | 3. Run `python prepro_coco_annotation.py` to generate the annotation JSON file for testing (the whole pipeline is sketched below).
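
A minimal sketch of the full caption pipeline (run from `data-prepro/MSCOCO_preprocess`, assuming the NeuralTalk2 notebook has already produced `coco_raw.json`):

```Shell
./prepro_mscoco_caption.sh         # download annotations, build the dictionary, tokenize captions
python prepro_coco_annotation.py   # write the K_*_annotation.json files for evaluation
```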
69 |
70 | ### CUB-200-2011 with Descriptions
71 | #### Feature Extraction
72 | 1. Run the script `./download_cub.sh` to download the images in CUB-200-2011.
73 | 2. Please modify the input/output paths in `data-prepro/MSCOCO_preprocess/extract_resnet_coco.py` to extract and pack the features for CUB-200-2011.
74 |
75 | #### Captions Tokenization
76 | 1. Download the [description data](https://drive.google.com/open?id=0B0ywwgffWnLLZW9uVHNjb2JmNlE).
77 | 2. Run `python get_split.py` to generate the dataset split following the ECCV16 paper "Generating Visual Explanations".
78 | 3. Run `python prepro_cub_annotation.py` to generate the annotation JSON file for testing.
79 | 4. Run `python CUB_preprocess_token.py` for tokenization (the whole pipeline is sketched below).
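
A minimal sketch of the CUB preprocessing pipeline (run from `data-prepro/CUB200_preprocess`; assumes the description data has been unzipped into `CUB_CVPR16/`):

```Shell
./download_cub.sh                 # download the CUB-200-2011 images
python get_split.py               # write splits.pkl, id2name.pkl, id2caption.pkl
python prepro_cub_annotation.py   # write cub_data/K_test_annotation.json
python CUB_preprocess_token.py    # write cub_data/tokenized_train_caption.pkl
```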
80 |
81 |
82 | ## Models from the paper
83 |
84 | ### Pretrained Models
85 | Download all pretrained and adaptation models:
86 |
87 | - [MSCOCO pretrained model](https://drive.google.com/drive/folders/0B340bHpZlbZzYW91R0UtNDRXUDA?usp=sharing)
88 | - [CUB-200-2011 adaptation model](https://drive.google.com/drive/folders/0B340bHpZlbZzNUZybXNzWVR2VWM?usp=sharing)
89 | - [TGIF adaptation model](https://drive.google.com/drive/folders/0B340bHpZlbZzX0ZWcFZ1YzdrSTg?usp=sharing)
90 | - [Flickr30k adaptation model](https://drive.google.com/drive/folders/0B340bHpZlbZzNldjRmZVX3JXdVk?usp=sharing)
91 |
92 | ### Example Results
93 | Here are some example results where the captions are generated from these models (the corresponding images are in the `images/` folder):
94 |
95 | **Example 1**
96 | - MSCOCO: A large air plane on a run way.
97 | - CUB-200-2011: A large white and black airplane with a large beak.
98 | - TGIF: A plane is flying over a field.
99 | - Flickr30k: A large airplane is sitting on a runway.
100 |
101 | **Example 2**
102 | - MSCOCO: A traffic light is seen in front of a large building.
103 | - CUB-200-2011: A yellow traffic light with a yellow light.
104 | - TGIF: A traffic light is hanging on a pole.
105 | - Flickr30k: A street sign is lit up in the dark.
106 |
107 | **Example 3**
108 | - MSCOCO: A black dog sitting on the ground next to a window.
109 | - CUB-200-2011: A black and white dog with a black head.
110 | - TGIF: A dog is looking at something in the mirror.
111 | - Flickr30k: A black dog is looking out of the window.
112 |
113 | **Example 4**
114 | - MSCOCO: A man riding a skateboard up the side of a ramp.
115 | - CUB-200-2011: A man riding a skateboard on a white ramp.
116 | - TGIF: A man is doing a trick on a skateboard.
117 | - Flickr30k: A man in a blue shirt is doing a trick on a skateboard.
118 |
119 |
152 | ## Training
153 | The training code is under the `show-adapt-tell/` folder.
154 |
155 | Simply run `python main.py` for two steps of training:
156 |
157 | ### Training the source model with paired image-caption data
158 | Please set the Boolean value of `"G_is_pretrain"` to True in `main.py` to start pretraining the generator.
159 | ### Training the cross-domain captioner with unpaired data
160 | After pretraining, set `"G_is_pretrain"` to False to start training the cross-domain model.
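
Since all options are defined with `tf.app.flags`, `G_is_pretrain` can also be toggled from the command line instead of editing `main.py`; a minimal sketch:

```Shell
# Step 1: pretrain the generator on the paired source (MSCOCO) data
python main.py --G_is_pretrain=True
# Step 2: adversarial training of the cross-domain captioner on unpaired target data
python main.py --G_is_pretrain=False
```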
161 |
162 | ## License
163 |
164 | Free for personal or research use; for commercial use please contact me.
165 |
166 |
--------------------------------------------------------------------------------
/data-prepro/CUB200_preprocess/CUB_preprocess_token.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import numpy as np
4 | from tqdm import tqdm
5 | import pdb
6 | import os
7 | import pickle
8 | import cPickle
9 | import string
10 |
11 | def unpickle(p):
12 | return cPickle.load(open(p,'r'))
13 |
14 | def load_json(p):
15 | return json.load(open(p,'r'))
16 |
17 | def clean_words(data):
18 | dict = {}
19 | freq = {}
20 | # start with 1
21 | idx = 1
22 | sentence_count = 0
23 | eliminate = 0
24 | max_w = 30
25 | for k in tqdm(range(len(data['caption']))):
26 | sen = data['caption'][k]
27 | filename = data['file_name'][k]
28 | # skip the no image description
29 | words = re.split(' ', sen)
30 | # pop the last u'.'
31 | n = len(words)
32 | if n <= max_w:
33 | sentence_count += 1
34 | for word in words:
35 | for p in string.punctuation:
36 | if p in word:
37 | word = word.replace(p,'')
38 | word = word.lower()
39 | if word not in dict.keys():
40 | dict[word] = idx
41 | idx += 1
42 | freq[word] = 1
43 | else:
44 | freq[word] += 1
45 | else:
46 | eliminate += 1
47 | print 'Threshold(max_words) =', max_w
48 | print 'Eliminate =', eliminate
49 | print 'Total sentence_count =', sentence_count
50 | print 'Number of different words =', len(dict.keys())
51 | print 'Saving....'
52 | np.savez('cleaned_words', dict=dict, freq=freq)
53 | return dict, freq
54 |
55 |
56 | phase = 'train'
57 | id2name = unpickle('id2name.pkl')
58 | id2caption = unpickle('id2caption.pkl')
59 | splits = unpickle('splits.pkl')
60 | split = splits[phase + '_id']
61 | thres = 5
62 |
63 | filename_list = []
64 | caption_list = []
65 | img_id_list = []
66 | for i in split:
67 | for sen in id2caption[i]:
68 | img_id_list.append(i)
69 | filename_list.append(id2name[i])
70 | caption_list.append(sen)
71 |
72 | # build dictionary
73 | if not os.path.isfile('cub_data/dictionary_'+str(thres)+'.npz'):
74 | # clean the words through the frequency
76 | words = np.load('K_cleaned_words.npz')
77 | dict = words['dict'].item(0)
78 | freq = words['freq'].item(0)
79 | idx2word = {}
80 | word2idx = {}
81 | idx = 1
82 | for k in tqdm(dict.keys()):
83 | if freq[k] >= thres:
84 | word2idx[k] = idx
85 | idx2word[str(idx)] = k
86 | idx += 1
87 |
88 | word2idx[u'<EOS>'] = len(word2idx.keys())+1  # end token; the angle-bracket string was stripped in rendering, name assumed
89 | idx2word[str(len(word2idx.keys()))] = u'<EOS>'
90 | print 'Threshold of word frequency =', thres
91 | print 'Total words in the dictionary =', len(word2idx.keys())
92 | np.savez('cub_data/dictionary_'+str(thres), word2idx=word2idx, idx2word=idx2word)
93 | else:
94 | tem = np.load('cub_data/dictionary_'+str(thres)+'.npz')
95 | word2idx = tem['word2idx'].item(0)
96 | idx2word = tem['idx2word'].item(0)
97 |
98 |
99 | # generate tokenized data
100 | num_sentence = 0
101 | eliminate = 0
102 | tokenized_caption_list = []
103 | caption_list_new = []
104 | filename_list_new = []
105 | img_id_list_new = []
106 | caption_length = []
107 | for k in tqdm(range(len(caption_list))):
108 | sen = caption_list[k]
109 | img_id = img_id_list[k]
110 | filename = filename_list[k]
111 | # skip the no image description
112 | words = re.split(' ', sen)
113 | # pop the last u'.'
114 | count = 0
115 | valid = True
116 | tokenized_sent = np.ones([31],dtype=int) * word2idx[u'<EOS>'] # initialize as <EOS> (token name assumed; stripped in rendering)
117 | if len(words) <= 30:
118 | for word in words:
119 | try:
120 | word = word.lower()
121 | for p in string.punctuation:
122 | if p in word:
123 | word = word.replace(p,'')
124 | idx = int(word2idx[word])
125 | tokenized_sent[count] = idx
126 | count += 1
127 | except KeyError:
128 | # if the sentence contains an out-of-vocabulary word, drop it in the train phase
129 | valid = False
130 | break
131 | # add <EOS> after the last word (token name assumed)
132 | tokenized_sent[len(words)] = word2idx[u'<EOS>']
133 | if valid:
134 | tokenized_caption_list.append(tokenized_sent)
135 | filename_list_new.append(filename)
136 | img_id_list_new.append(img_id)
137 | caption_list_new.append(sen)
138 | num_sentence += 1
139 | else:
140 | eliminate += 1
141 | tokenized_caption_info = {}
142 | tokenized_caption_info['tokenized_caption_list'] = np.asarray(tokenized_caption_list)
143 | tokenized_caption_info['filename_list'] = np.asarray(filename_list_new)
144 | tokenized_caption_info['img_id_list'] = np.asarray(img_id_list_new)
145 | tokenized_caption_info['raw_caption_list'] = np.asarray(caption_list_new)
146 | print 'Number of sentence =', num_sentence
147 | print 'eliminate = ', eliminate
148 | with open('./cub_data/tokenized_'+phase+'_caption.pkl', 'w') as outfile:
149 | pickle.dump(tokenized_caption_info, outfile)
150 |
151 |
--------------------------------------------------------------------------------
/data-prepro/CUB200_preprocess/dictionary_5.npz:
--------------------------------------------------------------------------------
1 | ../MSCOCO_preprocess/dictionary_5.npz
--------------------------------------------------------------------------------
/data-prepro/CUB200_preprocess/download_cub.sh:
--------------------------------------------------------------------------------
1 | mkdir cub_dataset
2 | cd cub_dataset
3 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
4 | tar zxvf CUB_200_2011.tgz
5 | # please download caption data on https://github.com/reedscot/cvpr2016. CUB_CVPR16 will be created after unzipping.
6 |
7 |
--------------------------------------------------------------------------------
/data-prepro/CUB200_preprocess/get_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cPickle
4 |
5 | # generate name2id & id2name dictionary
6 | name_id_path = '../images.txt'
7 | name_id = open(name_id_path).read().splitlines()
8 | name2id = {}
9 | id2name = {}
10 | for img in name_id:
11 | name2id[img.split(' ')[1]] = img.split(' ')[0]
12 | id2name[img.split(' ')[0]] = img.split(' ')[1]
13 |
14 | cPickle.dump(name2id, open('name2id.pkl', 'wb'))
15 | cPickle.dump(id2name, open('id2name.pkl', 'wb'))
16 |
17 | # generate id2caption dictionary for all images
18 | # please download caption data on https://github.com/reedscot/cvpr2016.
19 | # CUB_CVPR16 will be created after unzipping.
20 | caption_path = './CUB_CVPR16/text_c10/'
21 | id2caption = {}
22 | for name in name2id:
23 | txt_name = '.'.join(name.split('.')[0:-1]) + '.txt'
24 | txt_path = os.path.join(caption_path, txt_name)
25 | id = name2id[name]
26 | id2caption[id] = open(txt_path).read().splitlines()
27 |
28 | cPickle.dump(id2caption, open('id2caption.pkl', 'wb'))
29 |
30 | # generate split dictionary
31 | train_path = './ECCV16_explanations_splits/train_noCub.txt'
32 | test_path = './ECCV16_explanations_splits/test.txt'
33 | val_path = './ECCV16_explanations_splits/val.txt'
34 | splits = {}
35 | splits['train_name'] = open(train_path).read().splitlines()
36 | splits['test_name'] = open(test_path).read().splitlines()
37 | splits['val_name'] = open(val_path).read().splitlines()
38 |
39 | splits['train_id'] = [name2id[n] for n in splits['train_name']]
40 | splits['test_id'] = [name2id[n] for n in splits['test_name']]
41 | splits['val_id'] = [name2id[n] for n in splits['val_name']]
42 |
43 | cPickle.dump(splits, open('splits.pkl', 'wb'))
44 |
45 |
--------------------------------------------------------------------------------
/data-prepro/CUB200_preprocess/prepro_cub_annotation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import string
3 | import scipy.io as sio
4 | import numpy as np
5 | from tqdm import tqdm
6 | from random import shuffle, seed
7 | import pickle as pk
8 | import pdb
9 | input_data = 'splits.pkl'  # get_split.py saves the split dictionary as splits.pkl
10 | with open(input_data) as data_file:
11 | dataset = pk.load(data_file)
12 |
13 | skip_num = 0
14 | val_data = {}
15 | test_data = {}
16 | train_data = []
17 |
18 | val_dataset = []
19 | test_dataset = []
20 | counter = 0
21 | id2name = pk.load(open('id2name.pkl'))
22 | data = pk.load(open('id2caption.pkl'))
23 |
24 | for i in dataset['test_id']:
25 | caps = []
26 | # For GT
27 | name = id2name[i]
28 | count = 0
29 | for sen in data[i]:
30 | for punc in string.punctuation:
31 | if punc in sen:
32 | sen = sen.replace(punc, '')
33 |
34 | tmp = {}
35 | tmp['filename'] = name
36 | tmp['img_id'] = i
37 | tmp['cap_id'] = count
38 | tmp['caption'] = sen
39 | count += 1
40 | caps.append(tmp)
41 |
42 | test_data[i] = caps
43 | print 'number of skip train data: ' + str(skip_num)
44 | # COCO-style annotation keys: [u'info', u'images', u'licenses', u'type', u'annotations']
45 | json.dump(test_data, open('cub_data/K_test_annotation.json', 'w'))
46 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/K_cleaned_words.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/data-prepro/MSCOCO_preprocess/K_cleaned_words.npz
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/dictionary_5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/data-prepro/MSCOCO_preprocess/dictionary_5.npz
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/download_mscoco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # download mscoco images
3 | mkdir coco
4 | cd coco
5 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
6 | wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip
7 | wget http://msvocds.blob.core.windows.net/coco2014/test2014.zip
8 | unzip train2014.zip
9 | unzip val2014.zip
10 | unzip test2014.zip
11 | rm train2014.zip
12 | rm val2014.zip
13 | rm test2014.zip
14 | cd ..
15 | # please download the pretrained ResNet-101 model at https://github.com/KaimingHe/deep-residual-networks
16 | mkdir mscoco_data
17 | # extract resnet feature and pack in pickle format
18 | python extract_resnet_coco.py --def deep-residual-networks/prototxt/ResNet-101-deploy.prototxt --net resnet_model/ResNet-101-model.caffemodel --gpu 0
19 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/extract_resnet_coco.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/PaulChen/deep-residual-networks/caffe/python')
3 | import caffe
4 | import numpy as np
5 | import argparse
6 | import cv2
7 | import os, time
8 | import json
9 | import pdb
10 | import PIL
11 | from tqdm import tqdm
12 | from PIL import Image
13 | import re
14 | import pickle as pk
15 |
16 | def parse_args():
17 | """
18 | Parse input arguments
19 | """
20 | parser = argparse.ArgumentParser(description='Extract CNN features')
21 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id to use',
22 | default=0, type=int)
23 | parser.add_argument('--def', dest='prototxt',
24 | help='prototxt file defining the network',
25 | default=None, type=str)
26 | parser.add_argument('--net', dest='caffemodel',
27 | help='model to test',
28 | default=None, type=str)
29 |
30 | if len(sys.argv) == 1:
31 | parser.print_help()
32 | sys.exit(1)
33 |
34 | args = parser.parse_args()
35 | return args
36 |
37 | def set_transformer(net):
38 | transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
39 | transformer.set_transpose('data',(2,0,1))
40 | transformer.set_mean('data', np.load(\
41 | os.path.join('resnet_model','ResNet_mean.npy')))
42 | transformer.set_input_scale('data', 255)
43 | return transformer
44 |
45 | def iter_frames(im):
46 | try:
47 | i= 0
48 | while 1:
49 | im.seek(i)
50 | imframe = im.copy()
51 | if i == 0:
52 | palette = imframe.getpalette()
53 | else:
54 | imframe.putpalette(palette)
55 | yield imframe
56 | i += 1
57 | except EOFError:
58 | pass
59 |
60 | def extract_image(net, image_file):
61 | batch_size = 1
62 | transformer = set_transformer(net)
63 | if image_file.split('.')[-1] == 'gif':
64 | img = Image.open(image_file).convert("P",palette=Image.ADAPTIVE, colors=256)
65 | newfile = ''.join(image_file.split('.')[:-1])+'.png'
66 | for i, frame in enumerate(iter_frames(img)):
67 | frame.save(newfile,**frame.info)
68 | image_file = newfile
69 |
70 | img = cv2.imread(image_file)
71 | img = img.astype('float') / 255
72 | net.blobs['data'].data[:] = transformer.preprocess('data', img)
73 | net.forward()
74 | blobs_out_pool5 = net.blobs['pool5'].data[0,:,0,0]
75 | return blobs_out_pool5
76 |
77 |
78 | def split(split, net, feat_dict):
79 | print 'load ' + split
80 | img_dir = './coco/'
81 | img_path = os.path.join(img_dir, split)
82 | img_list = os.listdir(img_path)
83 | pool5_list = []
84 | prob_list = []
85 | for k in tqdm(img_list):
86 | blobs_out_pool5 = extract_image(net, os.path.join(img_path,k))
87 | feat_dict[k.split('.')[0]] = np.array(blobs_out_pool5)
88 |
89 | return feat_dict
90 |
91 | if __name__ == '__main__':
92 | args = parse_args()
93 | caffe_path = os.path.join('/home','PaulChen','caffe','python')
94 |
95 | print 'caffe setting'
96 | caffe.set_mode_gpu()
97 | caffe.set_device(args.gpu_id)
98 |
99 | print 'load caffe'
100 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)
101 | net.name = os.path.splitext(os.path.basename(args.caffemodel))[0]
102 |
103 | feat_dict = {}
104 | split('train2014', net, feat_dict)
105 | split('val2014', net, feat_dict)
106 | pk.dump(feat_dict, open('./mscoco_data/coco_trainval_feat.pkl','w'))
107 |
108 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/prepro_coco_annotation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import string
3 | import scipy.io as sio
4 | import numpy as np
5 | from tqdm import tqdm
6 | from random import shuffle, seed
7 |
8 | input_json = 'neuraltalk2/coco/coco_raw.json'
9 | with open(input_json) as data_file:
10 | data = json.load(data_file)
11 |
12 | seed(123)
13 | shuffle(data)
14 |
15 | skip_num = 0
16 | val_data = {}
17 | test_data = {}
18 | train_data_ = {}
19 |
20 | train_data = []
21 |
22 | val_ann = []
23 |
24 | val_dataset = []
25 | test_dataset = []
26 | train_dataset = []
27 |
28 | counter = 0
29 |
30 | for i in tqdm(range(len(data))):
31 | if i < 5000:
32 | # For GT
33 | idx = data[i]['id']
34 | caps = []
35 | for j in range(len(data[i]['captions'])):
36 | sen = data[i]['captions'][j].lower()
37 | for punc in string.punctuation:
38 | if punc in sen:
39 | sen = sen.replace(punc, '')
40 | tmp = {}
41 | tmp['img_id'] = data[i]['id']
42 | tmp['cap_id'] = j
43 | tmp['caption'] = sen
44 | caps.append(tmp)
45 |
46 | val_data[idx] = caps
47 |
48 | # For load
49 | tmp = {}
50 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0]
51 | tmp['img_id'] = idx
52 | val_dataset.append(tmp)
53 |
54 | elif i < 10000:
55 | idx = data[i]['id']
56 | caps = []
57 | for j in range(len(data[i]['captions'])):
58 | sen = data[i]['captions'][j].lower()
59 | for punc in string.punctuation:
60 | if punc in sen:
61 | sen = sen.replace(punc, '')
62 | tmp = {}
63 | tmp['img_id'] = data[i]['id']
64 | tmp['cap_id'] = j
65 | tmp['caption'] = sen
66 | caps.append(tmp)
67 |
68 | test_data[idx] = caps
69 |
70 | tmp = {}
71 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0]
72 | tmp['img_id'] = idx
73 | test_dataset.append(tmp)
74 |
75 |
76 | else:
77 | idx = data[i]['id']
78 | caps = []
79 | for j in range(len(data[i]['captions'])):
80 | sen = data[i]['captions'][j].lower()
81 | for punc in string.punctuation:
82 | if punc in sen:
83 | sen = sen.replace(punc, '')
84 |
85 |
86 |
87 | tmp = {}
88 | tmp['img_id'] = data[i]['id']
89 | tmp['cap_id'] = j
90 | tmp['caption'] = sen
91 | caps.append(tmp)
92 |
93 | train_data_[idx] = caps
94 |
95 | tmp = {}
96 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0]
97 | tmp['img_id'] = idx
98 | train_dataset.append(tmp)
99 |
100 |
101 |
102 | # FOR TRAINING
103 | for j in range(len(data[i]['captions'])):
104 | sen = data[i]['captions'][j].lower()
105 |
106 | for punc in string.punctuation:
107 | if punc in sen:
108 | sen = sen.replace(punc, '')
109 |
110 | if len(sen.split()) > 30:
111 | skip_num += 1
112 | continue
113 |
114 | tmp = {}
115 | tmp['file_id'] = data[i]['file_path'].split('/')[1].split('.')[0]
116 | tmp['img_id'] = data[i]['id']
117 | tmp['caption'] = sen
118 | tmp['length'] = len(sen.split())
119 | train_data.append(tmp)
120 |
121 | print 'number of skip train data: ' + str(skip_num)
122 |
123 | # COCO-style annotation keys: [u'info', u'images', u'licenses', u'type', u'annotations']
124 |
125 | #json.dump(val_data, open('K_val_train.json', 'w'))
126 | json.dump(val_data, open('./mscoco_data/K_val_annotation.json', 'w'))
127 | json.dump(test_data, open('./mscoco_data/K_test_annotation.json', 'w'))
128 | json.dump(train_data_, open('./mscoco_data/K_train_annotation.json', 'w'))
129 |
130 | #json.dump(train_data, open('K_train_raw.json', 'w'))
131 |
132 | json.dump(val_dataset, open('./mscoco_data/K_val_data.json', 'w'))
133 | json.dump(test_dataset, open('./mscoco_data/K_test_data.json', 'w'))
134 | json.dump(train_dataset, open('./mscoco_data/K_train_data.json', 'w'))
135 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/prepro_mscoco_caption.sh:
--------------------------------------------------------------------------------
1 | # download and preprocess captions
2 | wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip
3 | unzip captions_train-val2014.zip
4 | rm captions_train-val2014.zip
5 | python preprocess_entity.py train
6 | python preprocess_entity.py test
7 | python preprocess_entity.py val
8 | python preprocess_token.py train
9 | python preprocess_token.py val
10 | python preprocess_token.py test
11 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/preprocess_entity.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pickle
3 | import json
4 | import numpy as np
5 | from tqdm import tqdm
6 | import pdb
7 | import sys
8 | def load_json(p):
9 | return json.load(open(p,'r'))
10 |
11 | desired_phase = sys.argv[1]
12 | split_path = 'K_split.json'
13 | split = load_json(split_path)
14 | split_id = split[desired_phase]
15 |
16 | phase = ['train', 'val']
17 | id2name = {}
18 | name2id = {}
19 | id2caption = {}
20 | description_list = []
21 | img_name = []
22 | for p in phase:
23 | data_path = './annotations/captions_%s2014.json' % p
24 | data = load_json(data_path)
25 | for img_info in data['images']:
26 | if img_info['id'] in split_id:
27 | id2name[str(img_info['id'])] = img_info['file_name']
28 | name2id[img_info['file_name']] = str(img_info['id'])
29 | id2caption[str(img_info['id'])] = []
30 | count = 0
31 | for k in tqdm(range(len(data['annotations']))):
32 | sen = data['annotations'][k]['caption']
33 | image_id = data['annotations'][k]['image_id']
34 | if image_id in split_id:
35 | id2caption[str(image_id)].append(sen)
36 | file_name = id2name[str(image_id)]
37 | description_list.append(sen)
38 | img_name.append(file_name)
39 |
40 | out = {}
41 | out['caption_entity'] = description_list
42 | out['file_name'] = img_name
43 | out['id2filename'] = id2name
44 | out['filename2id'] = name2id
45 | out['id2caption'] = id2caption
46 | print 'Saving ...'
47 | print 'Number of sentences =', len(description_list)
48 | with open('./mscoco_data/K_annotation_%s2014.pkl'%desired_phase, 'w') as outfile:
49 | pickle.dump(out, outfile)
50 |
51 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/preprocess_token.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import numpy as np
4 | from tqdm import tqdm
5 | import pdb
6 | import os
7 | import pickle
8 | import cPickle
9 | import string
10 | import sys
11 |
12 | def unpickle(p):
13 | return cPickle.load(open(p,'r'))
14 |
15 | def load_json(p):
16 | return json.load(open(p,'r'))
17 |
18 | def clean_words(data):
19 | dict = {}
20 | freq = {}
21 | # start with 1
22 | idx = 1
23 | sentence_count = 0
24 | eliminate = 0
25 | max_w = 30
26 | for k in tqdm(range(len(data['caption_entity']))):
27 | sen = data['caption_entity'][k]
28 | filename = data['file_name'][k]
29 | # skip the no image description
30 | words = re.split(' ', sen)
31 | # pop the last u'.'
32 | n = len(words)
33 | if "" in words:
34 | words.remove("")
35 | if n <= max_w:
36 | sentence_count += 1
37 | for word in words:
38 | if "\n" in word:
39 | word = word.replace("\n", "")
40 | for p in string.punctuation:
41 | if p in word:
42 | word = word.replace(p,'')
43 | word = word.lower()
44 | if word not in dict.keys():
45 | dict[word] = idx
46 | idx += 1
47 | freq[word] = 1
48 | else:
49 | freq[word] += 1
50 | else:
51 | eliminate += 1
52 | print 'Threshold(max_words) =', max_w
53 | print 'Eliminate =', eliminate
54 | print 'Total sentence_count =', sentence_count
55 | print 'Number of different words =', len(dict.keys())
56 | print 'Saving....'
57 | np.savez('K_cleaned_words', dict=dict, freq=freq)
58 | return dict, freq
59 |
60 | phase = sys.argv[1]
61 | data_path = './mscoco_data/K_annotation_'+phase+'2014.pkl'
62 | data = unpickle(data_path)
63 | thres = 5
64 | if not os.path.isfile('./mscoco_data/dictionary_'+str(thres)+'.npz'):
65 | # clean the words through the frequency
66 | if not os.path.isfile('K_cleaned_words.npz'):
67 | dict, freq = clean_words(data)
68 | else:
69 | words = np.load('K_cleaned_words.npz')
70 | dict = words['dict'].item(0)
71 | freq = words['freq'].item(0)
72 | idx2word = {}
73 | word2idx = {}
74 | idx = 1
75 | for k in tqdm(dict.keys()):
76 | if freq[k] >= thres and k != "":
77 | word2idx[k] = idx
78 | idx2word[str(idx)] = k
79 | idx += 1
80 |
81 | word2idx[u'<NOT>'] = 0  # special tokens: angle-bracket strings were stripped in rendering; names and order assumed
82 | idx2word["0"] = u'<NOT>'
83 | word2idx[u'<EOS>'] = len(word2idx.keys())
84 | idx2word[str(len(idx2word.keys()))] = u'<EOS>'
85 | word2idx[u'<BOS>'] = len(word2idx.keys())
86 | idx2word[str(len(idx2word.keys()))] = u'<BOS>'
87 | word2idx[u'<UNK>'] = len(word2idx.keys())
88 | idx2word[str(len(idx2word.keys()))] = u'<UNK>'
89 | print 'Threshold of word frequency =', thres
90 | print 'Total words in the dictionary =', len(word2idx.keys())
91 | np.savez('./mscoco_data/dictionary_'+str(thres), word2idx=word2idx, idx2word=idx2word)
92 | else:
93 | tem = np.load('./mscoco_data/dictionary_'+str(thres)+'.npz')
94 | word2idx = tem['word2idx'].item(0)
95 | idx2word = tem['idx2word'].item(0)
96 |
97 | num_sentence = 0
98 | eliminate = 0
99 | tokenized_caption_list = []
100 | caption_list = []
101 | filename_list = []
102 | caption_length = []
103 | for k in tqdm(range(len(data['caption_entity']))):
104 | sen = data['caption_entity'][k]
105 | filename = data['file_name'][k]
106 | # skip the no image description
107 | words = re.split(' ', sen)
108 | # pop the last u'.'
109 | tokenized_sent = np.zeros([30+1], dtype=int)
110 | tokenized_sent.fill(int(word2idx[u'<NOT>']))  # pad with <NOT> (index 0); token name assumed
111 | #tokenized_sent[0] = int(word2idx[u'<BOS>'])
112 | valid = True
113 | count = 0
114 | caption = []
115 |
116 | if len(words) <= 30:
117 | for word in words:
118 | try:
119 | word = word.lower()
120 | for p in string.punctuation:
121 | if p in word:
122 | word = word.replace(p,'')
123 | if word != "":
124 | idx = int(word2idx[word])
125 | tokenized_sent[count] = idx
126 | caption.append(word)
127 | count += 1
128 | except KeyError:
129 | # if the sentence contains an out-of-vocabulary word, drop it in the train phase
130 | if phase == 'train':
131 | valid = False
132 | break
133 | else:
134 | tokenized_sent[count] = int(word2idx[u'<UNK>'])  # outside the train phase, map unknown words to <UNK> (token name assumed)
135 | count += 1
136 | if valid:
137 | tokenized_sent[count] = word2idx["<EOS>"]  # append the end token (name assumed)
138 | caption_list.append(caption)
139 | length = np.sum((tokenized_sent!=0)+0)
140 | tokenized_caption_list.append(tokenized_sent)
141 | filename_list.append(filename)
142 | caption_length.append(length)
143 | num_sentence += 1
144 | else:
147 | eliminate += 1
148 | tokenized_caption_info = {}
149 | tokenized_caption_info['caption_length'] = np.asarray(caption_length)
150 | tokenized_caption_info['tokenized_caption_list'] = np.asarray(tokenized_caption_list)
151 | tokenized_caption_info['caption_list'] = np.asarray(caption_list)
152 | tokenized_caption_info['filename_list'] = np.asarray(filename_list)
153 | print 'Number of sentence =', num_sentence
154 | with open('./mscoco_data/tokenized_'+phase+'_caption.pkl', 'w') as outfile:
155 | pickle.dump(tokenized_caption_info, outfile)
156 |
157 |
--------------------------------------------------------------------------------
/data-prepro/MSCOCO_preprocess/resnet_model/ResNet_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/data-prepro/MSCOCO_preprocess/resnet_model/ResNet_mean.npy
--------------------------------------------------------------------------------
/images/im11063.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im11063.jpg
--------------------------------------------------------------------------------
/images/im22197.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im22197.jpg
--------------------------------------------------------------------------------
/images/im270.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im270.jpg
--------------------------------------------------------------------------------
/images/im6795.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/im6795.jpg
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsenghungchen/show-adapt-and-tell/6daf82e519bbbf70fdd54d659ba17efb665395b5/images/teaser.png
--------------------------------------------------------------------------------
/show-adapt-tell/cub:
--------------------------------------------------------------------------------
1 | ../data-prepro/CUB200_preprocess/cub_data
--------------------------------------------------------------------------------
/show-adapt-tell/data:
--------------------------------------------------------------------------------
1 | ../data-prepro/MSCOCO_preprocess/mscoco_data
--------------------------------------------------------------------------------
/show-adapt-tell/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import utils
3 | import os, re, json
4 | import pdb
5 | from tqdm import tqdm
6 | # `sequence` is used by mscoco.preprocess(); the call matches the Keras pad_sequences API, so this import is assumed
7 | from keras.preprocessing import sequence
8 | def get_key(name):
9 | return re.split('\.', name)[0]
10 |
11 | class mscoco_negative():
12 |
13 | def __init__(self, dataset, conf):
14 | self.dataset_name = 'mscoco_negative'
15 | self.batch_size = conf.batch_size
16 | data_dir = './negative_samples/mscoco_sample'
17 | npz_paths = os.listdir(data_dir)
18 | print "Load Training data"
19 | count = 0
20 | self.neg_img_filename_train = []
21 | for npz_path in tqdm(npz_paths):
22 | if int(re.split("\.", re.split("_", npz_path)[1])[0]) <= 30000:
23 | npz = np.load(os.path.join(data_dir, npz_path))
24 | # tokenize caption
25 | if count == 0:
26 | self.neg_caption_train = npz["index"]
27 | else:
28 | self.neg_caption_train = np.concatenate((self.neg_caption_train, npz["index"]), 0)
29 | # img_idx
30 | for i in npz["img_name"]:
31 | self.neg_img_filename_train.append(i+'.jpg')
32 | count += 1
33 | self.neg_img_filename_train = np.asarray(self.neg_img_filename_train)
34 |
35 | npz_paths = ["mscoco_51000.npz"]
36 | print "Testing data"
37 | self.neg_img_filename_test = []
38 | count = 0
39 | for npz_path in tqdm(npz_paths):
40 | npz = np.load(os.path.join(data_dir, npz_path))
41 | if count == 0:
42 | self.neg_caption_test = npz["index"]
43 | else:
44 | self.neg_caption_test = np.concatenate((self.neg_caption_test, npz["index"]), 0)
45 | # img_idx
46 | for i in npz["img_name"]:
47 | self.neg_img_filename_test.append(i+'.jpg')
48 | count += 1
49 | self.neg_img_filename_test = np.asarray(self.neg_img_filename_test)
50 |
51 | self.current = 0
52 | self.num_train = len(self.neg_img_filename_train)
53 | self.num_test = len(self.neg_img_filename_test)
54 | self.random_shuffle()
55 | self.filename2id = dataset.filename2id
56 | self.img_dims = dataset.img_dims
57 | self.img_feat = dataset.img_feat
58 |
59 | def random_shuffle(self):
60 | idx = range(self.num_train)
61 | np.random.shuffle(idx)
62 | self.neg_img_filename_train = self.neg_img_filename_train[idx]
63 | self.neg_caption_train = self.neg_caption_train[idx, :]
64 |
65 | def get_paired_data(self, num_data, phase):
66 | if phase == 'train':
67 | caption = self.neg_caption_train
68 | img_filename = self.neg_img_filename_train
69 | else:
70 | caption = self.neg_caption_test
71 | img_filename = self.neg_img_filename_test
72 |
73 | if num_data > 0:
74 | caption = caption[:num_data, :]
75 | img_filename = img_filename[:num_data]
76 | else:
77 | if phase=='train':
78 | num_data = self.num_train
79 | else:
80 | num_data = self.num_test
81 |
82 | image_feature = np.zeros([num_data, self.img_dims])
83 | img_idx = []
84 | for i in range(num_data):
85 | image_feature[i, :] = self.img_feat[get_key(img_filename[i])]
86 | img_idx.append(get_key(img_filename[i]))
87 | return image_feature, caption, np.asarray(img_idx)
88 |
89 | def sequential_sample(self, batch_size):
90 | end = (self.current+batch_size) % self.num_train
91 | if self.current + batch_size < self.num_train:
92 | caption = self.neg_caption_train[self.current:end, :]
93 | img_filename = self.neg_img_filename_train[self.current:end]
94 | else:
95 | caption = np.concatenate((self.neg_caption_train[self.current:], self.neg_caption_train[:end]), axis=0)
96 | img_filename = np.concatenate((self.neg_img_filename_train[self.current:], self.neg_img_filename_train[:end]), axis=0)
97 | self.random_shuffle()
98 |
99 | image_feature = np.zeros([batch_size, self.img_dims])
100 | img_id = []
101 | for i in range(batch_size):
102 | image_feature[i, :] = self.img_feat[get_key(img_filename[i])]
103 | img_id.append(self.filename2id[img_filename[i]])
104 | self.current = end
105 | return image_feature, caption, np.asarray(img_id)
106 |
107 | class mscoco():
108 |
109 | def __init__(self, conf=None):
110 | # train img feature
111 | self.dataset_name = 'cub'
112 | # target data
113 | flickr_img_path = './cub/cub_train_resnet.pkl'
114 | self.train_flickr_img_feat = utils.unpickle(flickr_img_path)
115 | self.num_train_images_filckr = len(self.train_flickr_img_feat.keys())
116 | self.train_img_idx = self.train_flickr_img_feat.keys()
117 | flickr_caption_train_data_path = './cub/tokenized_train_caption.pkl'
118 | flickr_caption_train_data = utils.unpickle(flickr_caption_train_data_path)
119 | self.flickr_caption_train = flickr_caption_train_data['tokenized_caption_list']
120 | self.flickr_caption_idx_train = flickr_caption_train_data['filename_list']
121 | self.num_flickr_train_caption = self.flickr_caption_train.shape[0]
122 | flickr_img_path = './cub/cub_test_resnet.pkl'
123 | self.test_flickr_img_feat = utils.unpickle(flickr_img_path)
124 | self.flickr_random_shuffle() # shuffle the text data
125 |
126 | # MSCOCO data
127 | img_feat_path = './data/coco_trainval_feat.pkl'
128 | self.img_feat = utils.unpickle(img_feat_path)
129 | train_meta_path = './data/K_annotation_train2014.pkl'
130 | train_meta = utils.unpickle(train_meta_path)
131 | self.filename2id = train_meta['filename2id']
132 | val_meta_path = './data/K_annotation_val2014.pkl'
133 | val_meta = utils.unpickle(val_meta_path)
134 | self.id2filename = val_meta['id2filename']
135 | # train caption
136 | caption_train_data_path = './data/tokenized_train_caption.pkl'
137 | caption_train_data = utils.unpickle(caption_train_data_path)
138 | self.caption_train = caption_train_data['tokenized_caption_list']
139 | self.caption_idx_train = caption_train_data['filename_list']
140 | # val caption
141 | caption_test_data_path = './data/tokenized_test_caption.pkl'
142 | caption_test_data = utils.unpickle(caption_test_data_path)
143 | self.caption_test = caption_test_data['tokenized_caption_list']
144 | self.caption_idx_test = caption_test_data['filename_list']
145 | dict_path = './data/dictionary_5.npz'
146 | temp = np.load(dict_path)
147 | self.ix2word = temp['idx2word'].item()
148 | self.word2ix = temp['word2idx'].item()
149 | # add token
150 | if conf != None:
151 | self.batch_size = conf.batch_size
152 | self.dict_size = len(self.ix2word.keys())
153 | self.test_pointer = 0
154 | self.current_flickr = 0
155 | self.current_flickr_caption = 0
156 | self.current = 0
157 | self.max_words = self.caption_train.shape[1]
158 | tmp = self.img_feat[self.img_feat.keys()[0]]
159 | self.img_dims = tmp.shape[0]
160 | self.num_train = self.caption_train.shape[0]
161 | self.num_test = self.caption_test.shape[0]
162 | # Load annotation
163 | self.source_test_annotation = json.load(open('./data/K_val_annotation.json'))
164 | self.source_test_images = self.source_test_annotation.keys()
165 | self.source_num_test_images = len(self.source_test_images)
166 | self.test_annotation = json.load(open('./cub/K_test_annotation.json'))
167 | self.test_images = self.test_annotation.keys()
168 | self.num_test_images = len(self.test_images)
169 | self.random_shuffle()
170 |
171 | def random_shuffle(self):
172 | idx = range(self.num_train)
173 | np.random.shuffle(idx)
174 | self.caption_train = self.caption_train[idx]
175 | self.caption_idx_train = self.caption_idx_train[idx]
176 |
177 | def flickr_random_shuffle(self):
178 | idx = range(self.num_flickr_train_caption)
179 | np.random.shuffle(idx)
180 | self.flickr_caption_train = self.flickr_caption_train[idx]
181 | self.flickr_caption_idx_train = self.flickr_caption_idx_train[idx]
182 |
183 | def get_train_annotation(self):
184 | return self.train_annotation
185 |
186 | def get_train_for_eval(self, num):
187 | image_feature = np.zeros([num, self.img_dims])
188 | filenames = []
189 | self.random_shuffle()
190 | for i in range(num):
191 | filename = get_key(self.caption_idx_train[i])
192 | filenames.append(filename)
193 | image_feature[i, :] = self.img_feat[filename]
194 |
195 | return image_feature, np.asarray(filenames)
196 |
197 | def get_test_for_eval(self):
198 |
199 | image_feature = np.zeros([self.num_test_images, self.img_dims])
200 | image_id = np.zeros([self.num_test_images])
201 | for i in range(self.num_test_images):
202 | image_feature[i, :] = self.test_flickr_img_feat[self.test_images[i]]
203 | image_id[i] = int(self.test_images[i])
204 |
205 | return image_feature, image_id, self.test_annotation
206 |
207 | def get_source_test_for_eval(self):
208 |
209 | image_feature = np.zeros([self.source_num_test_images, self.img_dims])
210 | image_id = np.zeros([self.source_num_test_images])
211 | for i in range(self.source_num_test_images):
212 | image_feature[i, :] = self.img_feat[get_key(self.id2filename[self.source_test_images[i]])]
213 | image_id[i] = int(self.source_test_images[i])
214 |
215 | return image_feature, image_id, self.source_test_annotation
216 |
217 | def get_wrong_text(self, num_data, phase='train'):
218 | assert phase=='train'
219 | idx = range(self.num_train)
220 | np.random.shuffle(idx)
221 | caption_train = self.caption_train[idx, :]
222 | return caption_train[:num_data, :]
223 |
224 | def get_paired_data(self, num_data, phase):
225 | if phase == 'train':
226 | caption = self.caption_train
227 | img_idx = self.caption_idx_train
228 | else:
229 | caption = self.caption_test
230 | img_idx = self.caption_idx_test
231 |
232 | if num_data > 0:
233 | caption = caption[:num_data, :]
234 | img_idx = img_idx[:num_data]
235 | else:
236 | if phase=='train':
237 | num_data = self.num_train
238 | else:
239 | num_data = self.num_test
240 |
241 | image_feature = np.zeros([num_data, self.img_dims])
242 | for i in range(num_data):
243 | image_feature[i, :] = self.img_feat[get_key(img_idx[i])]
244 | return image_feature, caption, img_idx
245 |
246 | def preprocess(self, caption, lstm_steps):
247 | caption_padding = sequence.pad_sequences(caption, padding='post', maxlen=lstm_steps)
248 | return caption_padding
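# e.g. preprocess(np.array([[5, 9, 2]]), lstm_steps=6) -> [[5, 9, 2, 0, 0, 0]] (post-padded with 0)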
249 |
250 | def decode(self, sent_idx, type='string', remove_END=False):
251 | if len(sent_idx.shape) == 1:
252 | sent_idx = np.expand_dims(sent_idx, 0)
253 | sentences = []
254 | indexes = []
255 | for s in range(sent_idx.shape[0]):
256 | index = []
257 | sentence = ''
258 | for i in range(sent_idx.shape[1]):
259 | if int(sent_idx[s][i]) == int(self.word2ix[u'<EOS>']):  # end token; name assumed (stripped in rendering)
260 | if not remove_END:
261 | #sentence = sentence + '<EOS>'
262 | index.append(int(sent_idx[s][i]))
263 | break
264 | else:
265 | try:
266 | word = self.ix2word[str(int(sent_idx[s][i]))]
267 | sentence = sentence + word + ' '
268 | index.append(int(sent_idx[s][i]))
269 | except KeyError:
270 | sentence = sentence + "<UNK>" + ' '  # unknown-token name assumed
271 | index.append(int(self.word2ix[u'<UNK>']))
272 | indexes.append(index)
273 | sentences.append((sentence+'.').capitalize())
274 | if type=='string':
275 | return sentences
276 | elif type=='index':
277 | return indexes
278 |
279 | def flickr_sequential_sample(self, batch_size):
280 |
281 | end = (self.current_flickr+batch_size) % self.num_train_images_filckr
282 | image_feature = np.zeros([batch_size, self.img_dims])
283 | if self.current_flickr + batch_size < self.num_train_images_filckr:
284 | key = self.train_img_idx[self.current_flickr:end]
285 | else:
286 | key = np.concatenate((self.train_img_idx[self.current_flickr:], self.train_img_idx[:end]), axis=0)
287 |
288 | count = 0
289 | for k in key:
290 | image_feature[count] = self.train_flickr_img_feat[k]
291 | count += 1
292 | self.current_flickr = end
293 | return image_feature
294 |
295 | def flickr_caption_sequential_sample(self, batch_size):
296 | end = (self.current_flickr_caption+batch_size) % self.num_flickr_train_caption
297 | if self.current_flickr_caption + batch_size < self.num_flickr_train_caption:
298 | caption = self.flickr_caption_train[self.current_flickr_caption:end, :]
299 | else:
300 | caption = np.concatenate((self.flickr_caption_train[self.current_flickr_caption:], self.flickr_caption_train[:end]), axis=0)
301 | self.flickr_random_shuffle()
302 |
303 | self.current_flickr_caption = end
304 | return caption
305 |
306 | def sequential_sample(self, batch_size):
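# cyclic batching: serve examples in order and wrap around (with a reshuffle) when the end of the training set is reached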
307 | end = (self.current+batch_size) % self.num_train
308 | if self.current + batch_size < self.num_train:
309 | caption = self.caption_train[self.current:end, :]
310 | img_idx = self.caption_idx_train[self.current:end]
311 | else:
312 | caption = np.concatenate((self.caption_train[self.current:], self.caption_train[:end]), axis=0)
313 | img_idx = np.concatenate((self.caption_idx_train[self.current:], self.caption_idx_train[:end]), axis=0)
314 | self.random_shuffle()
315 |
316 | image_feature = np.zeros([batch_size, self.img_dims])
317 | img_id = []
318 | for i in range(batch_size):
319 | image_feature[i, :] = self.img_feat[get_key(img_idx[i])]
320 | img_id.append(self.filename2id[img_idx[i]])
321 | self.current = end
322 | return image_feature, caption, img_id
323 |
324 |
--------------------------------------------------------------------------------
/show-adapt-tell/highway.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # highway layer borrowed from https://github.com/carpedm20/lstm-char-cnn-tensorflow
4 | def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu):
5 | """Highway Network (cf. http://arxiv.org/abs/1505.00387).
6 |
7 | t = sigmoid(Wy + b)
8 | z = t * g(Wy + b) + (1 - t) * y
9 | where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.
10 | """
11 | output = input_
12 | for idx in xrange(layer_size):
13 | output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx))
14 | transform_gate = tf.sigmoid(
15 | tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias)
16 | carry_gate = 1. - transform_gate
17 | output = transform_gate * output + carry_gate * input_
18 | return output
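# Usage sketch (assumes the TF 0.12-era API used above; input is a [batch, size] tensor):
#   x = tf.placeholder(tf.float32, [64, 512])
#   h = highway(x, 512, layer_size=1)  # output has the same shape as x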
19 |
20 |
--------------------------------------------------------------------------------
/show-adapt-tell/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.misc
3 | import numpy as np
4 | import tensorflow as tf
5 | from pretrain_G import G_pretrained
6 | from pretrain_CNN_D import D_pretrained
7 | from model import SeqGAN
8 | from data_loader import mscoco, mscoco_negative
9 | import pprint
10 | import pdb
11 |
12 | flags = tf.app.flags
13 | flags.DEFINE_integer("epoch", 100, "Epoch to train [100]")
14 | flags.DEFINE_float("learning_rate", 5e-5, "Learning rate of for adam [0.0003]")
15 | flags.DEFINE_float("drop_out_rate", 0.3, "Drop out rate fro LSTM")
16 | flags.DEFINE_float("discount", 0.95, "discount factor in RL")
17 | flags.DEFINE_string("model_name", "cub_no_scheduled", "")
18 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") # 128:G, 32:D
19 | flags.DEFINE_integer("G_hidden_size", 512, "") # 512:G, 64:D
20 | flags.DEFINE_integer("D_hidden_size", 512, "")
21 | flags.DEFINE_integer("max_iter", 100000, "")
22 | flags.DEFINE_integer('max_to_keep', 40, '')
23 | flags.DEFINE_string("method", "ROUGE_L", "")
24 | flags.DEFINE_string("load_ckpt", './checkpoint/mscoco/G_pretrained/G_Pretrained-39000', "Directory name to loade the checkpoints [checkpoint]")
25 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
26 | flags.DEFINE_boolean("G_is_pretrain", False, "Do the G pretraining")
27 | flags.DEFINE_boolean("D_is_pretrain", False, "Do the D pretraining")
28 | flags.DEFINE_boolean("load_pretrain", True, "Load the pretraining")
29 | flags.DEFINE_boolean("is_train", True, "True for training, False for testing [False]")
30 |
31 | # Setting from Self-critical Sequence Training for Image Captioning
32 | tf.app.flags.DEFINE_float('init_lr', 5e-4, '') # follows the IBM paper
33 | tf.app.flags.DEFINE_float('lr_decay', 0.8, 'learning rate decay factor')
34 | tf.app.flags.DEFINE_float('lr_decay_every', 6600, 'every 3 epochs (3*2200 iterations)')
35 | tf.app.flags.DEFINE_float('ss_ascent', 0.05, 'scheduled sampling increment')
36 | tf.app.flags.DEFINE_float('ss_ascent_every', 11000, 'every 5 epochs (5*2200 iterations)')
37 | tf.app.flags.DEFINE_float('ss_max', 0.25, '0.05*5=0.25')
38 |
39 | FLAGS = flags.FLAGS
40 | pp = pprint.PrettyPrinter()
41 | def main(_):
42 | pp.pprint(flags.FLAGS.__flags)
43 |
44 | if not os.path.exists(FLAGS.checkpoint_dir):
45 | os.makedirs(FLAGS.checkpoint_dir)
46 |
47 | dataset = mscoco(FLAGS)
48 | config = tf.ConfigProto()
49 | config.gpu_options.per_process_gpu_memory_fraction = 0.1  # was 1/10, which evaluates to 0 under Python 2 integer division
50 | config.gpu_options.allow_growth = True
51 | with tf.Session(config=config) as sess:
52 | filter_sizes = [1,2,3,4,5,6,7,8,9,10,16,24,dataset.max_words]
53 | num_filters = [100,200,200,200,200,100,100,100,100,100,160,160,160]
54 | num_filters_total = sum(num_filters)
55 | info={'num_classes':3, 'filter_sizes':filter_sizes, 'num_filters':num_filters,
56 | 'num_filters_total':num_filters_total, 'l2_reg_lambda':0.2}
57 | if FLAGS.G_is_pretrain:
58 | G_pretrained_model = G_pretrained(sess, dataset, conf=FLAGS)
59 | if FLAGS.is_train:
60 | G_pretrained_model.train()
61 | G_pretrained_model.evaluate('test', 0)
62 | if FLAGS.D_is_pretrain:
63 | negative_dataset = mscoco_negative(dataset, FLAGS)
64 | D_pretrained_model = D_pretrained(sess, dataset, negative_dataset, info, conf=FLAGS)
65 | D_pretrained_model.train()
66 | if FLAGS.is_train:
67 | model = SeqGAN(sess, dataset, info, conf=FLAGS)
68 | model.train()
69 |
70 | if __name__ == '__main__':
71 | tf.app.run()
72 |
--------------------------------------------------------------------------------
/show-adapt-tell/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import time
4 | import tensorflow as tf
5 | import numpy as np
6 | from tqdm import tqdm
7 | from highway import *
8 | import copy
9 | from coco_caption.pycocoevalcap.eval import COCOEvalCap
10 | import pdb
11 |
12 | def calculate_loss_and_acc_with_logits(predictions, logits, label, l2_loss, l2_reg_lambda):
13 | # Calculate Mean cross-entropy loss
14 | with tf.variable_scope("loss"):
15 | losses = tf.nn.softmax_cross_entropy_with_logits(tf.squeeze(logits), label)
16 | D_loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss
17 | with tf.variable_scope("accuracy"):
18 | correct_predictions = tf.equal(predictions, tf.argmax(label, 1))
19 | accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"))
20 | return D_loss, accuracy
21 |
22 |
23 | class SeqGAN():
24 | def __init__(self, sess, dataset, D_info, conf=None):
25 | self.sess = sess
26 | self.model_name = conf.model_name
27 | self.batch_size = conf.batch_size
28 | self.max_iter = conf.max_iter
29 | self.max_to_keep = conf.max_to_keep
30 | self.is_train = conf.is_train
31 | # Testing => dropout rate is 0
32 | if self.is_train:
33 | self.drop_out_rate = conf.drop_out_rate
34 | else:
35 | self.drop_out_rate = 0
36 |
37 | self.num_train = dataset.num_train
38 | self.G_hidden_size = conf.G_hidden_size # 512
39 | self.D_hidden_size = conf.D_hidden_size # 512
40 | self.dict_size = dataset.dict_size
41 | self.max_words = dataset.max_words
42 | self.dataset = dataset
43 | self.img_dims = self.dataset.img_dims
44 | self.checkpoint_dir = conf.checkpoint_dir
45 | self.lstm_steps = self.max_words+1
46 | self.START = self.dataset.word2ix[u'<BOS>']  # special-token strings assumed; angle brackets were stripped in rendering
47 | self.END = self.dataset.word2ix[u'<EOS>']
48 | self.UNK = self.dataset.word2ix[u'<UNK>']
49 | self.NOT = self.dataset.word2ix[u'<NOT>']
50 | self.method = conf.method
51 | self.discount = conf.discount
52 | self.load_pretrain = conf.load_pretrain
53 | self.filter_sizes = D_info['filter_sizes']
54 | self.num_filters = D_info['num_filters']
55 | self.num_filters_total = sum(self.num_filters)
56 | self.num_classes = D_info['num_classes']
57 | self.num_domains = 3
58 | self.l2_reg_lambda = D_info['l2_reg_lambda']
59 |
60 |
61 | # D placeholder
62 | self.images = tf.placeholder('float32', [self.batch_size, self.img_dims])
63 | self.right_text = tf.placeholder('int32', [self.batch_size, self.max_words])
64 | self.wrong_text = tf.placeholder('int32', [self.batch_size, self.max_words])
65 | self.wrong_length = tf.placeholder('int32', [self.batch_size], name="wrong_length")
66 | self.right_length = tf.placeholder('int32', [self.batch_size], name="right_length")
67 |
68 | # Domain Classifier
69 | self.src_images = tf.placeholder('float32', [self.batch_size, self.img_dims])
70 | self.tgt_images = tf.placeholder('float32', [self.batch_size, self.img_dims])
71 | self.src_text = tf.placeholder('int32', [self.batch_size, self.max_words])
72 | self.tgt_text = tf.placeholder('int32', [self.batch_size, self.max_words])
73 | # Optimizer
74 | self.G_optim = tf.train.AdamOptimizer(conf.learning_rate)
75 | self.D_optim = tf.train.AdamOptimizer(conf.learning_rate)
76 | self.T_optim = tf.train.AdamOptimizer(conf.learning_rate)
77 | self.Domain_image_optim = tf.train.AdamOptimizer(conf.learning_rate)
78 | self.Domain_text_optim = tf.train.AdamOptimizer(conf.learning_rate)
79 | D_info["sentence_length"] = self.max_words
80 | self.D_info = D_info
81 |
82 | ###################################################
83 | # Generator #
84 | ###################################################
85 | # G placeholder
86 | state_list, predict_words_list_sample, log_probs_action_picked_list, self.rollout_mask, self.predict_mask = self.generator(name='G', reuse=False)
87 | predict_words_sample = tf.pack(predict_words_list_sample)
88 | self.predict_words_sample = tf.transpose(predict_words_sample, [1,0]) # B,S
89 | # for testing
90 | # argmax prediction
91 | _, predict_words_list_argmax, log_probs_action_picked_list_argmax, _, self.predict_mask_argmax = self.generator_test(name='G', reuse=True)
92 | predict_words_argmax = tf.pack(predict_words_list_argmax)
93 | self.predict_words_argmax = tf.transpose(predict_words_argmax, [1,0]) # B,S
94 | rollout = []
95 | rollout_length = []
96 | rollout_num = 3
97 | for i in range(rollout_num):
98 | rollout_i, rollout_length_i = self.rollout(predict_words_list_sample, state_list, name="G") # S*B, S
99 | rollout.append(rollout_i) # R,B,S
100 | rollout_length.append(rollout_length_i) # R,B, 1
101 |
102 | rollout = tf.pack(rollout) # R,B,S
103 | rollout = tf.reshape(rollout, [-1, self.max_words]) # R*B,S
104 | rollout_length = tf.pack(rollout_length) # R,B,1
105 | rollout_length = tf.reshape(rollout_length, [-1, 1]) # R*B, 1
106 | rollout_length = tf.squeeze(rollout_length)
107 | rollout_size = self.batch_size * self.max_words * rollout_num
108 | images_expand = tf.expand_dims(self.images, 1) # B,1,I
109 | images_tile = tf.tile(images_expand, [1, self.max_words, 1]) # B,S,I
110 | images_tile_transpose = tf.transpose(images_tile, [1,0,2]) # S,B,I
111 | images_tile_transpose = tf.tile(tf.expand_dims(images_tile_transpose, 0), [rollout_num,1,1,1]) #R,S,B,I
112 | images_reshape = tf.reshape(images_tile_transpose, [-1, self.img_dims]) #R*S*B,I
113 |
114 | D_rollout_vqa_softmax, D_rollout_logits_vqa = self.discriminator(rollout_size, images_reshape, rollout, rollout_length, name="D", reuse=False)
115 | D_rollout_text, D_rollout_text_softmax, D_logits_rollout_text, l2_loss_rollout_text = self.text_discriminator(rollout, D_info, name="D_text", reuse=False)
116 | reward = tf.multiply(D_rollout_vqa_softmax[:,0], D_rollout_text_softmax[:,0]) # S*B, 1
117 |
118 | reward = tf.reshape(reward, [rollout_num, -1]) # R, S*B
119 | reward = tf.reduce_mean(reward, 0) # S*B
120 |
121 | self.rollout_reward = tf.reshape(reward, [self.max_words, self.batch_size]) # S,B
122 | D_logits_rollout_reshape = tf.reshape(self.rollout_reward, [-1])
123 | self.G_loss = (-1)*tf.reduce_sum(log_probs_action_picked_list*tf.stop_gradient(D_logits_rollout_reshape)) / tf.reduce_sum(tf.stop_gradient(self.predict_mask))
124 |
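The loss above is a REINFORCE-style policy gradient: the reward for each generated timestep is the product of the pair discriminator's and the text-domain classifier's probabilities for the "right" class, averaged over the rollout_num Monte-Carlo rollouts, and it weights the log-probabilities of the sampled words. A NumPy sketch of the same arithmetic (shapes taken from the comments above; `policy_gradient_loss` is a hypothetical name):

    import numpy as np

    def policy_gradient_loss(p_pair, p_text, log_probs_picked, mask,
                             rollout_num):
        # p_pair, p_text: (R*S*B,) "right"-class probabilities on rollouts
        reward = (p_pair * p_text).reshape(rollout_num, -1).mean(axis=0)  # S*B
        # minimize -sum(log pi(w_t) * reward_t), normalized by kept words
        return -np.sum(log_probs_picked * reward) / np.sum(mask)
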
125 | # Teacher Forcing
126 | self.mask = tf.placeholder('float32', [self.batch_size, self.max_words]) # mask out the loss
127 | self.teacher_loss, self.teacher_loss_sum = self.Teacher_Forcing(self.right_text, self.mask, name="G", reuse=True)
128 |
129 | ###################################################
130 | # Discriminator #
131 | ###################################################
132 | # take the sample as fake data
133 | D_info["sentence_length"] = self.max_words
134 |
135 | # take the sampled caption as fake data
136 | self.fake_length = tf.reduce_sum(tf.stop_gradient(self.predict_mask),1)
137 | D_fake_vqa_softmax, D_fake_logits_vqa = self.discriminator(self.batch_size, self.images, tf.to_int32(self.predict_words_sample), tf.to_int32(self.fake_length), name="D", reuse=True)
138 | D_right_vqa_softmax, D_right_logits_vqa = self.discriminator(self.batch_size, self.images, self.right_text,
139 | self.right_length, name="D", reuse=True)
140 | D_wrong_vqa_softmax, D_wrong_logits_vqa = self.discriminator(self.batch_size, self.images, self.wrong_text,
141 | self.wrong_length, name="D", reuse=True)
142 |
143 | D_right_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(D_right_logits_vqa,
144 | tf.concat(1,(tf.ones((self.batch_size,1)), tf.zeros((self.batch_size,1)), tf.zeros((self.batch_size,1))))))
145 | D_wrong_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(D_wrong_logits_vqa,
146 | tf.concat(1,(tf.zeros((self.batch_size,1)), tf.ones((self.batch_size,1)), tf.zeros((self.batch_size,1))))))
147 | D_fake_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(D_fake_logits_vqa,
148 | tf.concat(1,(tf.zeros((self.batch_size,1)), tf.zeros((self.batch_size,1)), tf.ones((self.batch_size,1))))))
149 |
150 |
151 | self.D_loss = D_fake_loss + D_right_loss + D_wrong_loss
152 | ###################################################
153 | # Text Domain Classifier
154 | ###################################################
155 | D_src_text, D_src_text_softmax, D_logits_src_text, l2_loss_src_text = self.text_discriminator(self.src_text, D_info, name="D_text", reuse=True)
156 | D_tgt_text, D_tgt_text_softmax, D_logits_tgt_text, l2_loss_tgt_text = self.text_discriminator(self.tgt_text, D_info, name="D_text", reuse=True)
157 | D_fake_text, D_fake_text_softmax, D_logits_fake_text, l2_loss_fake_text = self.text_discriminator(self.predict_words_sample, D_info, name="D_text", reuse=True)
158 |
159 |
160 | D_src_loss_text, D_src_acc_text = calculate_loss_and_acc_with_logits(D_src_text,
161 | D_logits_src_text, tf.concat(1,(tf.zeros((self.batch_size,1)), tf.zeros((self.batch_size,1)),
162 | tf.ones((self.batch_size,1)))), l2_loss_src_text, D_info["l2_reg_lambda"])
163 | D_fake_loss_text, D_fake_acc_text = calculate_loss_and_acc_with_logits(D_fake_text,
164 | D_logits_fake_text, tf.concat(1,(tf.zeros((self.batch_size,1)), tf.ones((self.batch_size,1)),
165 | tf.zeros((self.batch_size,1)))), l2_loss_fake_text, D_info["l2_reg_lambda"])
166 | D_tgt_loss_text, D_tgt_acc_text = calculate_loss_and_acc_with_logits(D_tgt_text,
167 | D_logits_tgt_text, tf.concat(1,(tf.ones((self.batch_size,1)), tf.zeros((self.batch_size,1)),
168 | tf.zeros((self.batch_size,1)))), l2_loss_tgt_text, D_info["l2_reg_lambda"])
169 | self.D_text_loss = D_src_loss_text + D_tgt_loss_text + D_fake_loss_text
170 |
171 |
172 | ########################## tensorboard summary:########################
173 | # D_real_sum, D_fake_sum = the sigmoid output
174 | # D_real_loss_sum, D_fake_loss_sum = the loss for different kinds of input
175 | # D_loss_sum, G_loss_sum = loss of the G&D
176 | #######################################################################
177 | self.start_reward_sum = tf.scalar_summary("start_reward", tf.reduce_mean(self.rollout_reward[0,:]))
178 | self.total_reward_sum = tf.scalar_summary("total_mean_reward", tf.reduce_mean(self.rollout_reward))
179 | self.logprobs_mean_sum = tf.scalar_summary("logprobs_mean", tf.reduce_sum(log_probs_action_picked_list)/tf.reduce_sum(self.predict_mask))
180 | self.logprobs_dist_sum = tf.histogram_summary("log_probs", log_probs_action_picked_list)
181 | self.D_fake_loss_sum = tf.scalar_summary("D_fake_loss", D_fake_loss)
182 | self.D_wrong_loss_sum = tf.scalar_summary("D_wrong_loss", D_wrong_loss)
183 | self.D_right_loss_sum = tf.scalar_summary("D_right_loss", D_right_loss)
184 | self.D_loss_sum = tf.scalar_summary("D_loss", self.D_loss)
185 | self.G_loss_sum = tf.scalar_summary("G_loss", self.G_loss)
186 | ###################################################
187 | # Record the paramters #
188 | ###################################################
189 | params = tf.trainable_variables()
190 | self.R_params = []
191 | self.G_params = []
192 | self.D_params = []
193 | self.G_params_dict = {}
194 | self.D_params_dict = {}
195 | for param in params:
196 | if "R" in param.name:
197 | self.R_params.append(param)
198 | elif "G" in param.name:
199 | self.G_params.append(param)
200 | self.G_params_dict.update({param.name:param})
201 | elif "D" in param.name:
202 | self.D_params.append(param)
203 | self.D_params_dict.update({param.name:param})
204 | print "Build graph complete"
205 |
206 | def rollout_update(self):
207 | for r, g in zip(self.R_params, self.G_params):
208 | assign_op = r.assign(g)
209 | self.sess.run(assign_op)
210 | def discriminator(self, batch_size, images, text, length, name="discriminator", reuse=False):
211 |
212 | ### sentence: B, S
213 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
214 | with tf.variable_scope(name):
215 | if reuse:
216 | tf.get_variable_scope().reuse_variables()
217 | with tf.variable_scope("lstm"):
218 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.D_hidden_size, state_is_tuple=True)
219 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate)
220 | with tf.device('/cpu:0'), tf.variable_scope("embedding"):
221 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.D_hidden_size], "float32", random_uniform_init)
222 | with tf.variable_scope("text_emb"):
223 | text_W = tf.get_variable("text_W", [2*self.D_hidden_size, self.D_hidden_size],"float32", random_uniform_init)
224 | text_b = tf.get_variable("text_b", [self.D_hidden_size], "float32", random_uniform_init)
225 | with tf.variable_scope("images_emb"):
226 | images_W = tf.get_variable("images_W", [self.img_dims, self.D_hidden_size],"float32", random_uniform_init)
227 | images_b = tf.get_variable("images_b", [self.D_hidden_size], "float32", random_uniform_init)
228 | with tf.variable_scope("scores_emb"):
229 | # "generator/scores"
230 | scores_W = tf.get_variable("scores_W", [self.D_hidden_size, 3], "float32", random_uniform_init)
231 | scores_b = tf.get_variable("scores_b", [3], "float32", random_uniform_init)
232 |
233 | state = lstm1.zero_state(batch_size, 'float32')
234 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[batch_size])
235 | # VQA use states
236 | state_list = []
237 | for j in range(self.max_words+1):
238 | if j > 0:
239 | tf.get_variable_scope().reuse_variables()
240 | with tf.device('/cpu:0'):
241 | if j ==0:
242 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
243 | else:
244 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, text[:,j-1])
245 | with tf.variable_scope("lstm"):
246 | # "generator/lstm"
247 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
248 | # append state from index 1 (the start of the word)
249 | if j > 0:
250 | state_list.append(tf.concat(1,[state[0], state[1]]))
251 |
252 | state_list = tf.pack(state_list) # S,B,2H
253 | state_list = tf.transpose(state_list, [1,0,2]) # B,S,2H
254 | state_flatten = tf.reshape(state_list, [-1, 2*self.D_hidden_size]) # B*S, 2H
255 | # length-1 => index start from 0
256 | # need to prevent length = 0
257 | length_index = length-1
258 | condition = tf.greater_equal(length_index, 0) # B
259 | length_index = tf.select(condition, length_index, tf.constant(0, dtype=tf.int32, shape=[batch_size]))
260 | idx = tf.range(batch_size)*self.max_words + length_index # B
261 | state_gather = tf.gather(state_flatten, idx) # B, 2H
262 | # text embedding
263 | text_emb = tf.matmul(state_gather, text_W) + text_b # B,H
264 | text_emb = tf.nn.tanh(text_emb)
265 | # images embedding
266 | images_emb = tf.matmul(images, images_W) + images_b # B,H
267 | images_emb = tf.nn.tanh(images_emb)
268 | # embed to score
269 | logits = tf.mul(text_emb, images_emb) # B,H
270 | score = tf.matmul(logits, scores_W) + scores_b
271 |
272 | #return tf.nn.sigmoid(score), score
273 | return tf.nn.softmax(score), score
274 |
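The `tf.range(batch_size)*self.max_words + length_index` / `tf.gather` pair above picks, for every example, the LSTM state at its last valid word, clamping empty sentences to index 0. The same trick in NumPy (`gather_last_states` is a hypothetical helper):

    import numpy as np

    def gather_last_states(states, lengths):
        # states: (B, S, 2H) per-step LSTM states; lengths: (B,) word counts
        B, S, H2 = states.shape
        idx = np.arange(B) * S + np.maximum(lengths - 1, 0)  # clamp length 0
        return states.reshape(-1, H2)[idx]                   # (B, 2H)
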
275 |
276 | def text_discriminator(self, sentence, info, name="text_discriminator", reuse=False):
277 | ### sentence: B, S
278 | hidden_size = self.D_hidden_size
279 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
280 | with tf.variable_scope(name):
281 | if reuse:
282 | tf.get_variable_scope().reuse_variables()
283 | with tf.device('/cpu:0'), tf.variable_scope("embedding"):
284 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, hidden_size], "float32", random_uniform_init)
285 | embedded_chars = tf.nn.embedding_lookup(word_emb_W, sentence) # B,S,H
286 | embedded_chars_expanded = tf.expand_dims(embedded_chars, -1) # B,S,H,1
287 | with tf.variable_scope("output"):
288 | output_W = tf.get_variable("output_W", [info["num_filters_total"], self.num_domains],
289 | "float32", random_uniform_init)
290 | output_b = tf.get_variable("output_b", [self.num_domains], "float32", random_uniform_init)
291 | # Create a convolution + maxpool layer for each filter size
292 | pooled_outputs = []
293 | # Keeping track of l2 regularization loss (optional)
294 | l2_loss = tf.constant(0.0)
295 | for filter_size, num_filter in zip(info["filter_sizes"], info["num_filters"]):
296 | with tf.variable_scope("conv-maxpool-%s" % filter_size):
297 | # Convolution Layer
298 | filter_shape = [filter_size, hidden_size, 1, num_filter]
299 | W = tf.get_variable("W", filter_shape, "float32", random_uniform_init)
300 | b = tf.get_variable("b", [num_filter], "float32", random_uniform_init)
301 | conv = tf.nn.conv2d(
302 | embedded_chars_expanded,
303 | W,
304 | strides=[1, 1, 1, 1],
305 | padding="VALID",
306 | name="conv")
307 | # Apply nonlinearity
308 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
309 | # Maxpooling over the outputs
310 | pooled = tf.nn.max_pool(
311 | h,
312 | ksize=[1, info["sentence_length"] - filter_size + 1, 1, 1],
313 | strides=[1, 1, 1, 1],
314 | padding='VALID',
315 | name="pool")
316 | pooled_outputs.append(pooled)
317 | h_pool = tf.concat(3, pooled_outputs) # B,1,1,total filters
318 | h_pool_flat = tf.reshape(h_pool, [-1, info["num_filters_total"]]) # b, total filters
319 |
320 | # Add highway
321 | with tf.variable_scope("highway"):
322 | h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0)
323 | with tf.variable_scope("output"):
324 | l2_loss += tf.nn.l2_loss(output_W)
325 | l2_loss += tf.nn.l2_loss(output_b)
326 | logits = tf.nn.xw_plus_b(h_highway, output_W, output_b, name="logits")
327 | logits_softmax = tf.nn.softmax(logits)
328 | predictions = tf.argmax(logits_softmax, 1, name="predictions")
329 | return predictions, logits_softmax, logits, l2_loss
330 |
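For one conv-maxpool branch above, a VALID convolution of the (B, S, H, 1) embedding tensor with an (f, H, 1, n) filter leaves (B, S-f+1, 1, n); max-pooling over the remaining time axis collapses that to (B, 1, 1, n), and concatenating the branches gives feature width num_filters_total = sum(num_filters). A tiny sketch of that shape bookkeeping (illustrative only):

    def conv_maxpool_shapes(B, S, H, f, n):
        # VALID conv with a full-height (f, H) filter, then max-pool over time
        conv_out = (B, S - f + 1, 1, n)
        pool_out = (B, 1, 1, n)
        return conv_out, pool_out
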
331 | def domain_classifier(self, images, name="G", reuse=False):
332 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
333 | with tf.variable_scope(name):
334 | tf.get_variable_scope().reuse_variables()
335 | with tf.variable_scope("images"):
336 | # "generator/images"
337 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init)
338 | images_emb = tf.matmul(images, images_W) # B,H
339 |
340 | l2_loss = tf.constant(0.0)
341 | with tf.variable_scope("domain"):
342 | if reuse:
343 | tf.get_variable_scope().reuse_variables()
344 | with tf.variable_scope("output"):
345 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.num_domains],
346 | "float32", random_uniform_init)
347 | output_b = tf.get_variable("output_b", [self.num_domains], "float32", random_uniform_init)
348 | l2_loss += tf.nn.l2_loss(output_W)
349 | l2_loss += tf.nn.l2_loss(output_b)
350 | logits = tf.nn.xw_plus_b(images_emb, output_W, output_b, name="logits")
351 | predictions = tf.argmax(logits, 1, name="predictions")
352 |
353 | return predictions, logits, l2_loss
354 |
355 |
356 | def rollout(self, predict_words, state_list, name="R"):
357 |
358 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
359 | with tf.variable_scope(name):
360 | tf.get_variable_scope().reuse_variables()
361 | with tf.variable_scope("images"):
362 | # "generator/images"
363 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init)
364 | with tf.variable_scope("lstm"):
365 | # WONT BE CREATED HERE
366 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True)
367 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate)
368 | with tf.device("/cpu:0"), tf.variable_scope("embedding"):
369 | # "R/embedding"
370 | word_emb_W = tf.get_variable("word_emb_W",[self.dict_size, self.G_hidden_size], "float32", random_uniform_init)
371 | with tf.variable_scope("output"):
372 | # "R/output"
373 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init)
374 | rollout_list = []
375 | length_mask_list = []
376 | # rollout for the first time step
377 | for step in range(self.max_words):
378 | sample_words = predict_words[step]
379 | state = state_list[step]
380 | rollout_step_list = []
381 | mask = tf.constant(True, "bool", [self.batch_size])
382 | # used to calculate the length of the rollout sentence
383 | length_mask_step = []
384 | for j in range(step+1):
385 | mask_out_word = tf.select(mask, predict_words[j],
386 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size]))
387 | rollout_step_list.append(mask_out_word)
388 | length_mask_step.append(mask)
389 | prev_mask = mask
390 | mask_step = tf.not_equal(predict_words[j], self.END) # B
391 | mask = tf.logical_and(prev_mask, mask_step)
392 | for j in range(self.max_words-step-1):
393 | if step != 0 or j != 0:
394 | tf.get_variable_scope().reuse_variables()
395 | with tf.device("/cpu:0"):
396 | sample_words_emb = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_words))
397 | with tf.variable_scope("lstm"):
398 | output, state = lstm1(sample_words_emb, state, scope=tf.get_variable_scope()) # output: B,H
399 | logits = tf.matmul(output, output_W)
400 | # add 1e-8 to prevent log(0)
401 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D
402 | sample_words = tf.squeeze(tf.multinomial(log_probs,1))
403 | mask_out_word = tf.select(mask, sample_words,
404 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size]))
405 | rollout_step_list.append(mask_out_word)
406 | length_mask_step.append(mask)
407 | prev_mask = mask
408 | mask_step = tf.not_equal(sample_words, self.END) # B
409 | mask = tf.logical_and(prev_mask, mask_step)
410 |
411 | length_mask_step = tf.pack(length_mask_step) # S,B
412 | length_mask_step = tf.transpose(length_mask_step, [1,0]) # B,S
413 | length_mask_list.append(length_mask_step)
414 | rollout_step_list = tf.pack(rollout_step_list) # S,B
415 | rollout_step_list = tf.transpose(rollout_step_list, [1,0]) # B,S
416 | rollout_list.append(rollout_step_list)
417 |
418 | length_mask_list = tf.pack(length_mask_list) # S,B,S
419 | length_mask_list = tf.reshape(length_mask_list, [-1, self.max_words]) # S*B,S
420 | rollout_list = tf.pack(rollout_list) # S,B,S
421 | rollout_list = tf.reshape(rollout_list, [-1, self.max_words]) # S*B, S
422 | rollout_length = tf.to_int32(tf.reduce_sum(tf.to_float(length_mask_list),1))
423 | return rollout_list, rollout_length
424 |
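The boolean masks built in `rollout` serve two purposes: they overwrite every word after a sampled `<END>` with `<NOT>`, and their per-row sums are the rollout lengths. A NumPy sketch of that length bookkeeping (`rollout_lengths` is a hypothetical helper):

    import numpy as np

    def rollout_lengths(words, end_id):
        # words: (N, S) token ids; a step counts while no <END> was emitted
        alive = np.ones(words.shape[0], dtype=bool)
        length = np.zeros(words.shape[0], dtype=np.int32)
        for t in range(words.shape[1]):
            length += alive                      # step t is still "alive"
            alive &= (words[:, t] != end_id)     # dead after emitting <END>
        return length
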
425 | def Teacher_Forcing(self, target_sentence, mask, name='generator', reuse=False):
426 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
427 | with tf.variable_scope(name):
428 | if reuse:
429 | tf.get_variable_scope().reuse_variables()
430 | with tf.variable_scope("images"):
431 | # "generator/images"
432 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init)
433 | with tf.variable_scope("lstm"):
434 | # "generator/lstm"
435 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True)
436 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate)
437 | with tf.device("/cpu:0"), tf.variable_scope("embedding"):
438 | # "generator/embedding"
439 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.G_hidden_size], "float32", random_uniform_init)
440 | with tf.variable_scope("output"):
441 | # "generator/output"
442 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init)
443 |
444 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size])
445 | state = lstm1.zero_state(self.batch_size, 'float32')
446 | teacher_loss = 0.
447 | for j in range(self.lstm_steps):
448 | if j == 0:
449 | images_emb = tf.matmul(self.images, images_W) # B,H
450 | lstm1_in = images_emb
451 | else:
452 | tf.get_variable_scope().reuse_variables()
453 | with tf.device("/cpu:0"):
454 | if j == 1:
455 | # <START>
456 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
457 | else:
458 | # teacher forcing: always feed the ground-truth previous word
459 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, target_sentence[:,j-2])
460 |
461 | with tf.variable_scope("lstm"):
462 | # "generator/lstm"
463 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
464 |
465 | if j > 0:
466 | logits = tf.matmul(output, output_W) # B,D
467 | # calculate loss
468 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D
469 | action_picked = tf.range(self.batch_size)*(self.dict_size) + target_sentence[:,j-1]
470 | log_probs_action_picked = tf.mul(tf.gather(tf.reshape(log_probs, [-1]), action_picked), mask[:,j-1])
471 | loss_t = (-1)*tf.reduce_sum(log_probs_action_picked*tf.ones(self.batch_size))
472 | teacher_loss += loss_t
473 |
474 | teacher_loss /= tf.reduce_sum(mask)
475 | teacher_loss_sum = tf.scalar_summary("teacher_loss", teacher_loss)
476 |
477 | return teacher_loss, teacher_loss_sum
478 |
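`Teacher_Forcing` extracts each target word's log-probability with the flatten-and-gather trick (`tf.range(batch_size)*dict_size + target`). The NumPy equivalent, for clarity (hypothetical helper):

    import numpy as np

    def picked_log_probs(log_probs, targets):
        # log_probs: (B, D) log-probabilities; targets: (B,) word ids
        B, D = log_probs.shape
        return log_probs.reshape(-1)[np.arange(B) * D + targets]  # (B,)
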
479 | def generator(self, name='generator', reuse=False):
480 |
481 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
482 | with tf.variable_scope(name):
483 | if reuse:
484 | tf.get_variable_scope().reuse_variables()
485 | with tf.variable_scope("images"):
486 | # "generator/images"
487 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init)
488 | #images_b = tf.get_variable("images_b", [self.G_hidden_size], "float32", random_uniform_init)
489 | with tf.variable_scope("lstm"):
490 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True)
491 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate)
492 | with tf.device("/cpu:0"), tf.variable_scope("embedding"):
493 | # "generator/embedding"
494 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.G_hidden_size], "float32", random_uniform_init)
495 | with tf.variable_scope("output"):
496 | # "generator/output"
497 | # dict size minus 1 => remove
498 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init)
499 |
500 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size])
501 | state = lstm1.zero_state(self.batch_size, 'float32')
502 | mask = tf.constant(True, "bool", [self.batch_size])
503 | log_probs_action_picked_list = []
504 | predict_words = []
505 | state_list = []
506 | predict_mask_list = []
507 | for j in range(self.max_words+1):
508 | if j == 0:
509 | #images_emb = tf.matmul(self.images, images_W) + images_b # B,H
510 | images_emb = tf.matmul(self.images, images_W)
511 | lstm1_in = images_emb
512 | else:
513 | tf.get_variable_scope().reuse_variables()
514 | with tf.device("/cpu:0"):
515 | if j == 1:
516 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
517 | else:
518 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_words))
519 | with tf.variable_scope("lstm"):
520 | # "generator/lstm"
521 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
522 | if j > 0:
523 | logits = tf.matmul(output, output_W)
524 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D
525 | # word drawn from the multinomial distribution
526 | sample_words = tf.reshape(tf.multinomial(log_probs,1), [self.batch_size])
527 | mask_out_word = tf.select(mask, sample_words,
528 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size]))
529 | predict_words.append(mask_out_word)
530 | #predict_words.append(sample_words)
531 | # the mask should be dynamic
532 | # if the sentence is: This is a dog <END>
533 | # the predict_mask_list is: 1,1,1,1,1,0,0,.....
534 | predict_mask_list.append(tf.to_float(mask))
535 | action_picked = tf.range(self.batch_size)*(self.dict_size) + tf.to_int32(sample_words) # B
536 | # mask out the words beyond the <END>
537 | log_probs_action_picked = tf.mul(tf.gather(tf.reshape(log_probs, [-1]), action_picked), tf.to_float(mask))
538 | log_probs_action_picked_list.append(log_probs_action_picked)
539 | prev_mask = mask
540 | mask_step = tf.not_equal(sample_words, self.END) # B
541 | mask = tf.logical_and(prev_mask, mask_step)
542 | state_list.append(state)
543 |
544 | predict_mask_list = tf.pack(predict_mask_list) # S,B
545 | predict_mask_list = tf.transpose(predict_mask_list, [1,0]) # B,S
546 | log_probs_action_picked_list = tf.pack(log_probs_action_picked_list) # S,B
547 | log_probs_action_picked_list = tf.reshape(log_probs_action_picked_list, [-1]) # S*B
548 | return state_list, predict_words, log_probs_action_picked_list, None, predict_mask_list
549 |
550 | def generator_test(self, name='generator', reuse=False):
551 |
552 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
553 | with tf.variable_scope(name):
554 | if reuse:
555 | tf.get_variable_scope().reuse_variables()
556 | with tf.variable_scope("images"):
557 | # "generator/images"
558 | images_W = tf.get_variable("images_W", [self.img_dims, self.G_hidden_size], "float32", random_uniform_init)
559 | with tf.variable_scope("lstm"):
560 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.G_hidden_size, state_is_tuple=True)
561 | with tf.device("/cpu:0"), tf.variable_scope("embedding"):
562 | # "generator/embedding"
563 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.G_hidden_size], "float32", random_uniform_init)
564 | with tf.variable_scope("output"):
565 | # "generator/output"
566 | # dict size minus 1 => remove
567 | output_W = tf.get_variable("output_W", [self.G_hidden_size, self.dict_size], "float32", random_uniform_init)
568 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size])
569 | state = lstm1.zero_state(self.batch_size, 'float32')
570 | mask = tf.constant(True, "bool", [self.batch_size])
571 | log_probs_action_picked_list = []
572 | predict_words = []
573 | state_list = []
574 | predict_mask_list = []
575 | for j in range(self.max_words+1):
576 | if j == 0:
577 | images_emb = tf.matmul(self.images, images_W)
578 | lstm1_in = images_emb
579 | else:
580 | tf.get_variable_scope().reuse_variables()
581 | with tf.device("/cpu:0"):
582 | if j == 1:
583 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
584 | else:
585 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_words))
586 | with tf.variable_scope("lstm"):
587 | # "generator/lstm"
588 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
589 | if j > 0:
590 | #logits = tf.matmul(output, output_W) + output_b # B,D
591 | logits = tf.matmul(output, output_W)
592 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) # B,D
593 | # word drawn from the multinomial distribution
594 | sample_words = tf.argmax(log_probs, 1) # B
595 | mask_out_word = tf.select(mask, sample_words,
596 | tf.constant(self.NOT, dtype=tf.int64, shape=[self.batch_size]))
597 | predict_words.append(mask_out_word)
598 | # the mask should be dynamic
599 | # if the sentence is: This is a dog <END>
600 | # the predict_mask_list is: 1,1,1,1,1,0,0,.....
601 | predict_mask_list.append(tf.to_float(mask))
602 | action_picked = tf.range(self.batch_size)*(self.dict_size) + tf.to_int32(sample_words) # B
603 | # mask out the words beyond the <END>
604 | log_probs_action_picked = tf.mul(tf.gather(tf.reshape(log_probs, [-1]), action_picked), tf.to_float(mask))
605 | log_probs_action_picked_list.append(log_probs_action_picked)
606 | prev_mask = mask
607 | mask_step = tf.not_equal(sample_words, self.END) # B
608 | mask = tf.logical_and(prev_mask, mask_step)
609 | state_list.append(state)
610 |
611 | predict_mask_list = tf.pack(predict_mask_list) # S,B
612 | predict_mask_list = tf.transpose(predict_mask_list, [1,0]) # B,S
613 | log_probs_action_picked_list = tf.pack(log_probs_action_picked_list) # S,B
614 | log_probs_action_picked_list = tf.reshape(log_probs_action_picked_list, [-1]) # S*B
615 | return state_list, predict_words, log_probs_action_picked_list, None, predict_mask_list
616 |
617 |
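`generator` and `generator_test` differ only in how the next word is drawn: multinomial sampling from the softmax during adversarial training versus argmax decoding at test time. A compact sketch of the two decoding rules (`decode_step` is a hypothetical name):

    import numpy as np

    def decode_step(log_probs, greedy=True, rng=np.random):
        # log_probs: (B, D). greedy=True mirrors generator_test (argmax);
        # greedy=False mirrors generator (sample from the softmax).
        if greedy:
            return log_probs.argmax(axis=1)
        probs = np.exp(log_probs)
        probs /= probs.sum(axis=1, keepdims=True)
        return np.array([rng.choice(probs.shape[1], p=p) for p in probs])
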
618 | def train(self):
619 |
620 | self.G_train_op = self.G_optim.minimize(self.G_loss, var_list=self.G_params)
621 | self.G_hat_train_op = self.T_optim.minimize(self.teacher_loss, var_list=self.G_params)
622 | self.D_train_op = self.D_optim.minimize(self.D_loss, var_list=self.D_params)
623 | self.Domain_text_train_op = self.Domain_text_optim.minimize(self.D_text_loss)
624 | log_dir = os.path.join('.', 'logs', self.model_name)
625 | if not os.path.exists(log_dir):
626 | os.makedirs(log_dir)
627 | #### Old version
628 | self.writer = tf.train.SummaryWriter(os.path.join(log_dir, "SeqGAN_sample"), self.sess.graph)
629 | self.summary_op = tf.merge_all_summaries()
630 | tf.initialize_all_variables().run()
631 | if self.load_pretrain:
632 | print "[@] Load the pretrained model"
633 | self.G_saver = tf.train.Saver(self.G_params_dict)
634 | self.G_saver.restore(self.sess, "./checkpoint/mscoco/G_pretrained/G_Pretrained-39000")
635 |
636 | self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)
637 | count = 0
638 | D_count = 0
639 | G_count = 0
640 | for idx in range(self.max_iter//250):
641 | self.save(self.checkpoint_dir, count)
642 | self.evaluate(count)
643 | for _ in tqdm(range(250)):
644 | tgt_image_feature = self.dataset.flickr_sequential_sample(self.batch_size)
645 | tgt_text = self.dataset.flickr_caption_sequential_sample(self.batch_size)
646 | image_feature, right_text, _ = self.dataset.sequential_sample(self.batch_size)
647 | nonENDs = np.array(map(lambda x: (x != self.NOT).sum(), right_text))
648 | mask_t = np.zeros([self.batch_size, self.max_words])
649 | for ind, row in enumerate(mask_t):
650 | # mask out the <NOT> padding
651 | row[0:nonENDs[ind]] = 1
652 |
653 | wrong_text = self.dataset.get_wrong_text(self.batch_size)
654 | right_length = np.sum((right_text!=self.NOT)+0, 1)
655 | wrong_length = np.sum((wrong_text!=self.NOT)+0, 1)
656 | for _ in range(1): # g_step
657 | # update G
658 | feed_dict = {self.images: tgt_image_feature}
659 | _, G_loss = self.sess.run([self.G_train_op, self.G_loss], feed_dict)
660 | G_count += 1
661 | for _ in range(20): # d_step
662 | # update D
663 | feed_dict = {self.images: image_feature,
664 | self.right_text:right_text,
665 | self.wrong_text:wrong_text,
666 | self.right_length:right_length,
667 | self.wrong_length:wrong_length,
668 | self.mask: mask_t,
669 | self.src_images: image_feature,
670 | self.tgt_images: tgt_image_feature,
671 | self.src_text: right_text,
672 | self.tgt_text: tgt_text}
673 |
674 | _, D_loss = self.sess.run([self.D_train_op, self.D_loss], feed_dict)
675 | D_count += 1
676 | _, D_text_loss = self.sess.run([self.Domain_text_train_op, self.D_text_loss], \
677 | {self.src_text: right_text,
678 | self.tgt_text: tgt_text,
679 | self.images: tgt_image_feature
680 | })
681 |
682 | count += 1
683 |
684 | def evaluate(self, count):
685 |
686 | samples = []
687 | samples_index = []
688 | image_feature, image_id, test_annotation = self.dataset.get_test_for_eval()
689 | num_samples = self.dataset.num_test_images
690 | samples_index = np.full([self.batch_size*(num_samples//self.batch_size), self.max_words], self.NOT)
691 | for i in range(num_samples//self.batch_size):
692 | image_feature_test = image_feature[i*self.batch_size:(i+1)*self.batch_size]
693 | feed_dict = {self.images: image_feature_test}
694 | predict_words = self.sess.run(self.predict_words_argmax, feed_dict)
695 | for j in range(self.batch_size):
696 | samples.append([self.dataset.decode(predict_words[j, :], type='string', remove_END=True)[0]])
697 | sample_index = self.dataset.decode(predict_words[j, :], type='index', remove_END=False)[0]
698 | samples_index[i*self.batch_size+j][:len(sample_index)] = sample_index
699 | # predict from samples
700 | samples = np.asarray(samples)
701 | samples_index = np.asarray(samples_index)
702 | print '[%] Sentence:', samples[0]
703 | meteor_pd = {}
704 | meteor_id = []
705 | for j in range(len(samples)):
706 | if image_id[j] == 0:
707 | break
708 | meteor_pd[str(int(image_id[j]))] = [{'image_id':str(int(image_id[j])), 'caption':samples[j][0]}]
709 | meteor_id.append(str(int(image_id[j])))
710 | scorer = COCOEvalCap(test_annotation, meteor_pd, meteor_id)
711 | scorer.evaluate(verbose=True)
712 | sample_dir = os.path.join("./SeqGAN_samples_sample", self.model_name)
713 | if not os.path.exists(sample_dir):
714 | os.makedirs(sample_dir)
715 | file_name = "%s_%s" % (self.dataset.dataset_name, str(count))
716 | np.savez(os.path.join(sample_dir, file_name), string=samples, index=samples_index, id=meteor_id)
717 |
718 | def save(self, checkpoint_dir, step):
719 | model_name = "SeqGAN_sample"
720 | model_dir = "%s" % (self.dataset.dataset_name)
721 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir, self.model_name)
722 | if not os.path.exists(checkpoint_dir):
723 | os.makedirs(checkpoint_dir)
724 | self.saver.save(self.sess,
725 | os.path.join(checkpoint_dir, model_name),
726 | global_step=step)
727 |
728 | def load(self, checkpoint_dir):
729 | print(" [*] Reading checkpoints...")
730 |
731 | model_dir = "%s" % (self.dataset.dataset_name)
732 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
733 |
734 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
735 | if ckpt and ckpt.model_checkpoint_path:
736 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
737 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
738 | return True
739 | else:
740 | return False
741 |
--------------------------------------------------------------------------------
/show-adapt-tell/pretrain_CNN_D.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import time
4 | import tensorflow as tf
5 | import numpy as np
6 | from tqdm import tqdm
7 | from highway import *
8 | import pdb
9 |
10 | class D_pretrained():
11 | def __init__(self, sess, dataset, negative_dataset, D_info, conf=None, l2_reg_lambda=0.2):
12 |
13 | self.sess = sess
14 | self.batch_size = conf.batch_size
15 | self.max_iter = conf.max_iter
16 | self.num_train = dataset.num_train
17 | self.hidden_size = conf.D_hidden_size # 512
18 | self.dict_size = dataset.dict_size
19 | self.max_words = dataset.max_words
20 | self.dataset = dataset
21 | self.negative_dataset = negative_dataset
22 | self.checkpoint_dir = conf.checkpoint_dir
23 | self.global_step = tf.get_variable('global_step', [],initializer=tf.constant_initializer(0), trainable=False)
24 | self.optim = tf.train.AdamOptimizer(conf.learning_rate)
25 | self.filter_sizes = D_info['filter_sizes']
26 | self.num_filters = D_info['num_filters']
27 | self.num_filters_total = sum(self.num_filters)
28 | self.num_classes = D_info['num_classes']
29 | self.l2_reg_lambda = l2_reg_lambda
30 | self.START = self.dataset.word2ix[u'<START>']
31 | self.END = self.dataset.word2ix[u'<END>']
32 | self.UNK = self.dataset.word2ix[u'<UNK>']
33 | self.NOT = self.dataset.word2ix[u'<NOT>']
34 | # placeholder
35 | self.text = tf.placeholder(tf.int32, [None, self.max_words], name="text")
36 | self.label = tf.placeholder(tf.float32, [None, self.num_classes], name="label")
37 | self.images = tf.placeholder(tf.float32, [None, self.dataset.img_dims], name="images")
38 |
39 | self.loss, self.pred = self.build_Discriminator(self.images, self.text, self.label, name='D')
40 | self.loss_sum = tf.scalar_summary("loss", self.loss)
41 |
42 | params = tf.trainable_variables()
43 | self.D_params_dict = {}
44 | self.D_params_train = []
45 | for param in params:
46 | self.D_params_dict.update({param.name:param})
47 | if "embedding" in param.name:
48 | embedding_matrix = np.load("embedding-42000.npy")
49 | self.embedding_assign_op = param.assign(tf.Variable(embedding_matrix, trainable=False))
50 | else:
51 | self.D_params_train.append(param)
52 |
53 | def build_Discriminator(self, images, text, label, name="discriminator", reuse=False):
54 |
55 | ### sentence: B, S
56 | hidden_size = self.hidden_size
57 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
58 | with tf.variable_scope(name):
59 | if reuse:
60 | tf.get_variable_scope().reuse_variables()
61 | with tf.device('/cpu:0'), tf.variable_scope("embedding"):
62 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, hidden_size], "float32", random_uniform_init)
63 | embedded_chars = tf.nn.embedding_lookup(word_emb_W, text) # B,S,H
64 | embedded_chars_expanded = tf.expand_dims(embedded_chars, -1) # B,S,H,1
65 | with tf.variable_scope("output"):
66 | output_W = tf.get_variable("output_W", [hidden_size, self.num_classes],
67 | "float32", random_uniform_init)
68 | output_b = tf.get_variable("output_b", [self.num_classes], "float32", random_uniform_init)
69 | with tf.variable_scope("images"):
70 | images_W = tf.get_variable("images_W", [self.dataset.img_dims, hidden_size],
71 | "float32", random_uniform_init)
72 | images_b = tf.get_variable("images_b", [hidden_size], "float32", random_uniform_init)
73 | with tf.variable_scope("text"):
74 | text_W = tf.get_variable("text_W", [self.num_filters_total, hidden_size],
75 | "float32", random_uniform_init)
76 | text_b = tf.get_variable("text_b", [hidden_size], "float32", random_uniform_init)
77 |
78 | # Create a convolution + maxpool layer for each filter size
79 | pooled_outputs = []
80 | # Keeping track of l2 regularization loss (optional)
81 | l2_loss = tf.constant(0.0)
82 | for filter_size, num_filter in zip(self.filter_sizes, self.num_filters):
83 | with tf.variable_scope("conv-maxpool-%s" % filter_size):
84 | # Convolution Layer
85 | filter_shape = [filter_size, hidden_size, 1, num_filter]
86 | W = tf.get_variable("W", filter_shape, "float32", random_uniform_init)
87 | b = tf.get_variable("b", [num_filter], "float32", random_uniform_init)
88 | conv = tf.nn.conv2d(
89 | embedded_chars_expanded,
90 | W,
91 | strides=[1, 1, 1, 1],
92 | padding="VALID",
93 | name="conv")
94 | # Apply nonlinearity
95 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
96 | # Maxpooling over the outputs
97 | pooled = tf.nn.max_pool(
98 | h,
99 | ksize=[1, self.max_words - filter_size + 1, 1, 1],
100 | strides=[1, 1, 1, 1],
101 | padding='VALID',
102 | name="pool")
103 | pooled_outputs.append(pooled)
104 | h_pool = tf.concat(3, pooled_outputs) # B,1,1,total filters
105 | h_pool_flat = tf.reshape(h_pool, [-1, self.num_filters_total]) # b, total filters
106 | # Add highway
107 | with tf.variable_scope("highway"):
108 | h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0)
109 | with tf.variable_scope("text"):
110 | text_emb = tf.nn.xw_plus_b(h_highway, text_W, text_b, name="text_emb")
111 | with tf.variable_scope("images"):
112 | images_emb = tf.nn.xw_plus_b(images, images_W, images_b, name="images_emb")
113 | with tf.variable_scope("output"):
114 | fusing_vec = tf.mul(text_emb, images_emb)
115 | l2_loss += tf.nn.l2_loss(output_W)
116 | l2_loss += tf.nn.l2_loss(output_b)
117 | logits = tf.nn.xw_plus_b(fusing_vec, output_W, output_b, name="logits")
118 | ypred_for_auc = tf.nn.softmax(logits)
119 | predictions = tf.argmax(logits, 1, name="predictions")
120 | #predictions = tf.nn.sigmoid(logits, name="predictions")
121 | # Calculate Mean cross-entropy loss
122 | with tf.variable_scope("loss"):
123 | losses = tf.nn.softmax_cross_entropy_with_logits(logits, label)
124 | #losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.squeeze(logits), self.input_y)
125 | loss = tf.reduce_mean(losses) + self.l2_reg_lambda * l2_loss
126 |
127 | return loss, predictions
128 |
129 | def train(self):
130 |
131 | self.train_op = self.optim.minimize(self.loss, global_step=self.global_step, var_list=self.D_params_train)
132 | #self.train_op = self.optim.minimize(self.loss, global_step=self.global_step)
133 | self.writer = tf.train.SummaryWriter("./logs/D_CNN_pretrained_sample", self.sess.graph)
134 | tf.initialize_all_variables().run()
135 | self.saver = tf.train.Saver(var_list=self.D_params_dict, max_to_keep=30)
136 | # assign the G embedding matrix to the D pretraining model
137 | self.sess.run(self.embedding_assign_op)
138 | count = 0
139 | for idx in range(self.max_iter//3000):
140 | self.save(self.checkpoint_dir, count)
141 | self.evaluate('test', count)
142 | self.evaluate('train', count)
143 | for k in tqdm(range(3000)):
144 | right_images, right_text, _ = self.dataset.sequential_sample(self.batch_size)
145 | fake_images, fake_text, _ = self.negative_dataset.sequential_sample(self.batch_size)
146 | wrong_text = self.dataset.get_wrong_text(self.batch_size)
147 |
148 | images = np.concatenate((right_images, right_images, fake_images), axis=0)
149 | text = np.concatenate((right_text, wrong_text, fake_text.astype('int32')), axis=0)
150 | label = np.zeros((text.shape[0], self.num_classes))
151 | # right -> first entry
152 | # wrong -> second entry
153 | # fake -> third entry
154 | label[:self.batch_size, 0] = 1
155 | label[self.batch_size:2*self.batch_size, 1] = 1
156 | label[2*self.batch_size:, 2] = 1
157 | _, loss, summary_str = self.sess.run([self.train_op, self.loss, self.loss_sum],{
158 | self.text: text.astype('int32'),
159 | self.images: images,
160 | self.label: label
161 | })
162 | self.writer.add_summary(summary_str, count)
163 | count += 1
164 |
165 | def evaluate(self, split, count):
166 |
167 | if split == 'test':
168 | num_test_pair = -1
169 | elif split == 'train':
170 | num_test_pair = 5000
171 | right_images, right_text, _ = self.dataset.get_paired_data(num_test_pair, phase=split)
172 | # the true paired data we get
173 | num_test_pair = len(right_images)
174 | fake_images, fake_text, _ = self.negative_dataset.get_paired_data(num_test_pair, phase=split)
175 | random_idx = range(num_test_pair)
176 | np.random.shuffle(random_idx)
177 | wrong_text = np.squeeze(right_text[random_idx, :])
178 | count = 0.
179 | loss_t = []
180 | right_acc_t = []
181 | wrong_acc_t = []
182 | fake_acc_t = []
183 | for i in range(num_test_pair//self.batch_size):
184 | right_images_batch = right_images[i*self.batch_size:(i+1)*self.batch_size,:]
185 | fake_images_batch = fake_images[i*self.batch_size:(i+1)*self.batch_size,:]
186 | right_text_batch = right_text[i*self.batch_size:(i+1)*self.batch_size,:]
187 | fake_text_batch = fake_text[i*self.batch_size:(i+1)*self.batch_size,:]
188 | wrong_text_batch = wrong_text[i*self.batch_size:(i+1)*self.batch_size,:]
189 | text_batch = np.concatenate((right_text_batch, wrong_text_batch, fake_text_batch.astype('int32')), axis=0)
190 | images_batch = np.concatenate((right_images_batch, right_images_batch, fake_images_batch), axis=0)
191 | label = np.zeros((text_batch.shape[0], self.num_classes))
192 | # right -> first entry
193 | # wrong -> second entry
194 | # fake -> third entry
195 | label[:self.batch_size, 0] = 1
196 | label[self.batch_size:2*self.batch_size, 1] = 1
197 | label[2*self.batch_size:, 2] = 1
198 | feed_dict = {self.images:images_batch, self.text:text_batch, self.label:label}
199 | loss, pred, loss_str = self.sess.run([self.loss, self.pred, self.loss_sum], feed_dict)
200 | loss_t.append(loss)
201 | right_acc_t.append(np.sum((np.argmax(label[:self.batch_size],1)==pred[:self.batch_size])+0))
202 | wrong_acc_t.append(np.sum((np.argmax(label[self.batch_size:2*self.batch_size],1)==pred[self.batch_size:2*self.batch_size])+0))
203 | fake_acc_t.append(np.sum((np.argmax(label[2*self.batch_size:],1)==pred[2*self.batch_size:])+0))
204 | count += self.batch_size
205 | print "Phase =", split.capitalize()
206 | print "======================= Loss ====================="
207 | print '[$] Loss =', np.mean(loss_t)
208 | print "======================= Acc ======================"
209 | print '[$] Right Pair Acc. =', sum(right_acc_t)/count
210 | print '[$] Wrong Pair Acc. =', sum(wrong_acc_t)/count
211 | print '[$] Fake Pair Acc. =', sum(fake_acc_t)/count
212 |
213 | def save(self, checkpoint_dir, step):
214 | model_name = "D_Pretrained"
215 | model_dir = "%s" % (self.dataset.dataset_name)
216 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir, "D_CNN_pretrained_sample")
217 | if not os.path.exists(checkpoint_dir):
218 | os.makedirs(checkpoint_dir)
219 | self.saver.save(self.sess,
220 | os.path.join(checkpoint_dir, model_name),
221 | global_step=step)
222 |
223 | def load(self, checkpoint_dir):
224 | print(" [*] Reading checkpoints...")
225 |
226 | model_dir = "%s" % (self.dataset.dataset_name)
227 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
228 |
229 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
230 | if ckpt and ckpt.model_checkpoint_path:
231 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
232 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
233 | return True
234 | else:
235 | return False
236 |
237 |
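Both `train` and `evaluate` above assemble the same three-way labels for (right, wrong, fake) triplets stacked along the batch axis. A compact sketch of that construction (hypothetical helper):

    import numpy as np

    def three_way_labels(batch_size, num_classes=3):
        label = np.zeros((3 * batch_size, num_classes))
        label[:batch_size, 0] = 1                # right image-caption pair
        label[batch_size:2 * batch_size, 1] = 1  # wrong (mismatched) pair
        label[2 * batch_size:, 2] = 1            # fake (generated) caption
        return label
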
--------------------------------------------------------------------------------
/show-adapt-tell/pretrain_G.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import time
4 | import tensorflow as tf
5 | import numpy as np
6 | from tqdm import tqdm
7 | from coco_spice.pycocoevalcap.eval import COCOEvalCap
8 | import pdb
9 |
10 | class G_pretrained():
11 | def __init__(self, sess, dataset, conf=None):
12 | self.sess = sess
13 | self.batch_size = conf.batch_size
14 | self.max_iter = conf.max_iter
15 | self.num_train = dataset.num_train
16 | self.hidden_size = conf.G_hidden_size # 512
17 | self.dict_size = dataset.dict_size
18 | self.max_words = dataset.max_words
19 | self.dataset = dataset
20 | self.load_ckpt = conf.load_ckpt
21 | self.is_train = conf.is_train
22 | if self.is_train:
23 | self.drop_out_rate = conf.drop_out_rate
24 | else:
25 | self.drop_out_rate = 0
26 |
27 | self.init_lr = conf.init_lr
28 | self.lr_decay = conf.lr_decay
29 | self.lr_decay_every = conf.lr_decay_every
30 | self.ss_ascent = conf.ss_ascent
31 | self.ss_ascent_every = conf.ss_ascent_every
32 | self.ss_max = conf.ss_max
33 | # train pretrained model -> no need to add START_TOKEN
34 | # -> need to add END_TOKEN
35 | self.img_dims = self.dataset.img_dims
36 | self.lstm_steps = self.max_words+1
37 | self.global_step = tf.get_variable('global_step', [],initializer=tf.constant_initializer(0), trainable=False)
38 | #self.optim = tf.train.AdamOptimizer(conf.learning_rate)
39 | self.checkpoint_dir = conf.checkpoint_dir
40 | self.START = self.dataset.word2ix[u'<START>']
41 | self.END = self.dataset.word2ix[u'<END>']
42 | self.UNK = self.dataset.word2ix[u'<UNK>']
43 | self.NOT = self.dataset.word2ix[u'<NOT>']
44 |
45 | self.coins = tf.placeholder('bool', [self.batch_size, self.max_words-1])
46 | self.images_one = tf.placeholder('float32', [100, self.img_dims])
47 | self.images = tf.placeholder('float32', [self.batch_size, self.img_dims])
48 | self.target_sentence = tf.placeholder('int32', [self.batch_size, self.max_words])
49 | self.mask = tf.placeholder('float32', [self.batch_size, self.max_words]) # mask out the loss
50 | self.build_Generator(name='G')
51 | self._predict_words_argmax = []
52 | self._predict_words_sample = []
53 | self._predict_words_argmax = self.build_Generator_test(100, self._predict_words_argmax, type='max', name='G')
54 | self._predict_words_sample = self.build_Generator_test(100, self._predict_words_sample, type='sample', name='G')
55 |
56 | self.lr = tf.Variable(self.init_lr, trainable=False)
57 | self.optim = tf.train.AdamOptimizer(self.lr)
58 |
59 | params = tf.trainable_variables()
60 | self.G_params_dict = {}
61 | for param in params:
62 | self.G_params_dict.update({param.name:param})
63 |
64 | def build_Generator_test(self, batch_size=100, predict_words=None, type='max', name='generator'):
65 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
66 | with tf.variable_scope(name):
67 | tf.get_variable_scope().reuse_variables()
68 | with tf.variable_scope("images"):
69 | # "generator/images"
70 | images_W = tf.get_variable("images_W", [self.img_dims, self.hidden_size], "float32", random_uniform_init)
71 | with tf.variable_scope("lstm"):
72 | # WONT BE CREATED HERE
73 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_size, state_is_tuple=True)
74 | # lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate)
75 | with tf.device("/cpu:0"), tf.variable_scope("embedding"):
76 | # "generator/embedding"
77 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.hidden_size], "float32", random_uniform_init)
78 | with tf.variable_scope("output"):
79 | # "generator/output"
80 | output_W = tf.get_variable("output_W", [self.hidden_size, self.dict_size], "float32", random_uniform_init)
81 |
82 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[batch_size])
83 | state = lstm1.zero_state(batch_size, 'float32')
84 | for j in range(self.lstm_steps):
85 | tf.get_variable_scope().reuse_variables()
86 | if j == 0:
87 | images_emb = tf.matmul(self.images_one, images_W) # B,H
88 | lstm1_in = images_emb
89 | elif j == 1:
90 | with tf.device("/cpu:0"):
91 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
92 | else:
93 | with tf.device("/cpu:0"):
94 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, sample_words)
95 | with tf.variable_scope("lstm"):
96 | # "generator/lstm"
97 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
98 | if j > 0:
99 | logits = tf.matmul(output, output_W) # B,D
100 | #log_probs = tf.log(tf.nn.softmax(logits)) # B,D
101 | # word drawn from the multinomial distribution
102 | #sample_words = tf.reshape(tf.multinomial(log_probs,1), [batch_size])
103 | sample_words = tf.argmax(logits, 1)
104 | predict_words.append(sample_words)
105 |
106 | predict_words = tf.pack(predict_words)
107 | predict_words = tf.transpose(predict_words, [1,0])
108 | return predict_words
109 |
110 | def build_Generator(self, name='generator'):
111 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
112 | with tf.variable_scope(name):
113 | with tf.variable_scope("images"):
114 | # "generator/images"
115 | images_W = tf.get_variable("images_W", [self.img_dims, self.hidden_size], "float32", random_uniform_init)
116 | with tf.variable_scope("lstm"):
117 | # "generator/lstm"
118 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_size, state_is_tuple=True)
119 | lstm1 = tf.nn.rnn_cell.DropoutWrapper(lstm1, output_keep_prob=1-self.drop_out_rate)
120 | with tf.device("/cpu:0"), tf.variable_scope("embedding"):
121 | # "generator/embedding"
122 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.hidden_size], "float32", random_uniform_init)
123 | with tf.variable_scope("output"):
124 | # "generator/output"
125 | output_W = tf.get_variable("output_W", [self.hidden_size, self.dict_size], "float32", random_uniform_init)
126 |
127 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size])
128 | state = lstm1.zero_state(self.batch_size, 'float32')
129 | self.pretrained_loss = 0.
130 | for j in range(self.lstm_steps):
131 | if j == 0:
132 | images_emb = tf.matmul(self.images, images_W) # B,H
133 | lstm1_in = images_emb
134 | else:
135 | tf.get_variable_scope().reuse_variables()
136 | with tf.device("/cpu:0"):
137 | if j == 1:
138 | # <START>
139 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
140 | else:
141 | # schedule sampling
142 | word = tf.select(self.coins[:,j-2], self.target_sentence[:,j-2], tf.stop_gradient(word_predict))
143 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, word)
144 |
145 | with tf.variable_scope("lstm"):
146 | # "generator/lstm"
147 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
148 |
149 | if j > 0:
150 | logits = tf.matmul(output, output_W) # B,D
151 | # calculate loss
152 | pretrained_loss_t = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, self.target_sentence[:,j-1])
153 | pretrained_loss_t = tf.reduce_sum(tf.mul(pretrained_loss_t, self.mask[:,j-1]))
154 | self.pretrained_loss += pretrained_loss_t
155 | word_predict = tf.to_int32(tf.argmax(logits, 1)) # B
156 |
157 |
158 | self.pretrained_loss /= tf.reduce_sum(self.mask)
159 | self.pretrained_loss_sum = tf.scalar_summary("pretrained_loss", self.pretrained_loss)
160 |
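`build_Generator` implements scheduled sampling: at step j, the boolean `coins[:, j-2]` decides per example whether the LSTM sees the ground-truth word or its own previous argmax prediction. The coins are drawn in `train` below with an element-wise loop; an equivalent vectorized sketch (assuming `current_ss` is the probability of feeding the model's own word):

    import numpy as np

    def sample_coins(batch_size, max_words, current_ss, rng=np.random):
        # True -> feed the ground-truth word; False -> feed the model's own
        coins = rng.rand(batch_size, max_words - 1) >= current_ss
        coins[:, 0] = True    # the first fed word is always ground truth
        return coins
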
161 | def train(self):
162 | '''
163 | Train a caption generator with XE
164 | with learning rate decay and schedule sampling
165 | '''
166 |
167 | self.train_op = self.optim.minimize(self.pretrained_loss, global_step=self.global_step)
168 | self.writer = tf.train.SummaryWriter("./logs/G_pretrained", self.sess.graph)
169 | tf.initialize_all_variables().run()
170 | self.saver = tf.train.Saver(var_list=self.G_params_dict, max_to_keep=30)
171 | try:
172 | self.saver.restore(self.sess, self.load_ckpt)
173 | print "[#] Restore", self.load_ckpt
174 | except:
175 | print "[#] Fail to restore"
176 |
177 | self.current_lr = self.init_lr
178 | self.current_ss = 0.
179 | self.tr_count = 0
180 | for idx in range(self.max_iter//3000):
181 | print "Evaluate source test set..."
182 | self.evaluate('test', self.tr_count)
183 | print "Evaluate target test set..."
184 | self.evaluate('target_test', self.tr_count)
185 | self.evaluate('train', self.tr_count, eval_algo='max')
186 | self.evaluate('train', self.tr_count, eval_algo='sample')
187 | self.save(self.checkpoint_dir, self.tr_count)
188 | for k in tqdm(range(3000)):
189 | tgt_text = self.dataset.flickr_caption_sequential_sample(self.batch_size)
190 | image_feature, target, img_idx = self.dataset.sequential_sample(self.batch_size)
191 | # dummy_feature = np.zeros(image_feature.shape)
192 | nonENDs = np.array(map(lambda x: (x != self.NOT).sum(), target))
193 | mask = np.zeros([self.batch_size, self.max_words])
194 | tgt_mask = np.zeros([self.batch_size, self.max_words])
195 | for ind, row in enumerate(mask):
196 | # mask out the <NOT> padding
197 | row[0:nonENDs[ind]] = 1
198 |
199 | for ind, row in enumerate(tgt_mask):
200 | row[0:nonENDs[ind]] = 1
201 | # schedule sampling condition
202 | coins = np.zeros([self.batch_size, self.max_words-1])
203 | for (x,y), value in np.ndenumerate(coins):
204 | if y==0:
205 | coins[x][y] = True
206 | elif np.random.rand() < self.current_ss:
207 | coins[x][y] = False
208 | else:
209 | coins[x][y] = True
210 |
211 |
212 | _, loss, summary_str = self.sess.run([self.train_op, self.pretrained_loss, self.pretrained_loss_sum],{
213 | self.images: image_feature,
214 | self.target_sentence: target,
215 | self.mask: mask,
216 | self.coins: coins
217 | })
218 | # _, dummy_loss, _ = self.sess.run([self.train_op, self.pretrained_loss, self.pretrained_loss_sum],{
219 | # self.images: dummy_feature,
220 | # self.target_sentence: tgt_text,
221 | # self.mask: tgt_mask,
222 | # self.coins: coins
223 | # })
224 |
225 | self.writer.add_summary(summary_str, self.tr_count)
226 | self.tr_count += 1
227 |
228 | #if k%1000 == 0:
229 | # print " [*] Iter {}, lr={}, ss={}, loss={}".format(self.tr_count, self.current_lr, self.current_ss, loss)
230 |
231 | if idx == 0 and k != 0 and k%1000 == 0:
232 | self.evaluate('train', self.tr_count, eval_algo='max')
233 | self.evaluate('train', self.tr_count, eval_algo='sample')
234 | self.evaluate('test', self.tr_count)
235 | self.evaluate('target_test', self.tr_count)
236 | # schedule sampling
237 | if (self.tr_count+1)%self.ss_ascent_every == 0 and self.current_ss < self.ss_max:
238 | self.current_ss += self.ss_ascent
--------------------------------------------------------------------------------
/show-adapt-tell/pretrain_LSTM_D.py:
--------------------------------------------------------------------------------
27 | self.START = self.dataset.word2ix[u'<START>']
28 | self.END = self.dataset.word2ix[u'<END>']
29 | self.UNK = self.dataset.word2ix[u'<UNK>']
30 | self.NOT = self.dataset.word2ix[u'<NOT>']
31 |
32 | self.global_step = tf.get_variable('global_step', [],initializer=tf.constant_initializer(0), trainable=False)
33 | self.optim = tf.train.AdamOptimizer(conf.learning_rate)
34 |
35 | # placeholder
36 | self.fake_images = tf.placeholder(tf.float32, [self.batch_size, self.img_dims], name="fake_images")
37 | self.wrong_images = tf.placeholder(tf.float32, [self.batch_size, self.img_dims], name="wrong_images")
38 | self.right_images = tf.placeholder(tf.float32, [self.batch_size, self.img_dims], name="right_images")
39 |
40 | self.fake_text = tf.placeholder(tf.int32, [self.batch_size, self.max_words], name="fake_text")
41 | self.wrong_text = tf.placeholder(tf.int32, [self.batch_size, self.max_words], name="wrong_text")
42 | self.right_text = tf.placeholder(tf.int32, [self.batch_size, self.max_words], name="right_text")
43 |
44 | self.fake_length = tf.placeholder(tf.int32, [self.batch_size], name="fake_length")
45 | self.wrong_length = tf.placeholder(tf.int32, [self.batch_size], name="wrong_length")
46 | self.right_length = tf.placeholder(tf.int32, [self.batch_size], name="right_length")
47 |
48 | # build graph
49 | self.D_fake, D_fake_logits = self.build_Discriminator(self.fake_images, self.fake_text, self.fake_length,
50 | name="D", reuse=False)
51 | self.D_wrong, D_wrong_logits = self.build_Discriminator(self.wrong_images, self.wrong_text, self.wrong_length,
52 | name="D", reuse=True)
53 | self.D_right, D_right_logits = self.build_Discriminator(self.right_images, self.right_text, self.right_length,
54 | name="D", reuse=True)
55 | # loss
56 | self.D_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logits, tf.zeros_like(self.D_fake)))
57 | self.D_wrong_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_wrong_logits, tf.zeros_like(self.D_wrong)))
58 | self.D_right_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_right_logits, tf.ones_like(self.D_right)))
59 | self.loss = self.D_fake_loss+self.D_wrong_loss+self.D_right_loss
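  | # The three terms train D to accept matched pairs and reject both
  | # generated captions ("fake") and mismatched real captions ("wrong").
  | # With sigmoid cross-entropy this is, as a per-batch mean:
  | #   L = -log(1 - D(fake)) - log(1 - D(wrong)) - log(D(right))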
60 | # Summary
61 | self.D_fake_loss_sum = tf.scalar_summary("fake_loss", self.D_fake_loss)
62 | self.D_wrong_loss_sum = tf.scalar_summary("wrong_loss", self.D_wrong_loss)
63 | self.D_right_loss_sum = tf.scalar_summary("right_loss", self.D_right_loss)
64 | self.loss_sum = tf.scalar_summary("train_loss", self.loss)
65 |
66 | self.D_params_dict = {}
67 | params = tf.trainable_variables()
68 | for param in params:
69 | self.D_params_dict.update({param.name:param})
70 |
71 | def build_Discriminator(self, images, text, length, name="discriminator", reuse=False):
72 |
73 | ### sentence: B, S
74 | random_uniform_init = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
75 | with tf.variable_scope(name):
76 | if reuse:
77 | tf.get_variable_scope().reuse_variables()
78 | with tf.variable_scope("lstm"):
79 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_size, state_is_tuple=True)
80 | with tf.device('/cpu:0'), tf.variable_scope("embedding"):
81 | word_emb_W = tf.get_variable("word_emb_W", [self.dict_size, self.hidden_size], "float32", random_uniform_init)
82 | with tf.variable_scope("text_emb"):
83 | text_W = tf.get_variable("text_W", [2*self.hidden_size, self.hidden_size],"float32", random_uniform_init)
84 | text_b = tf.get_variable("text_b", [self.hidden_size], "float32", random_uniform_init)
85 | with tf.variable_scope("images_emb"):
86 | images_W = tf.get_variable("images_W", [self.img_dims, self.hidden_size],"float32", random_uniform_init)
87 | images_b = tf.get_variable("images_b", [self.hidden_size], "float32", random_uniform_init)
88 | with tf.variable_scope("scores_emb"):
89 | # scope: "D/scores_emb"
90 | scores_W = tf.get_variable("scores_W", [self.hidden_size, 1], "float32", random_uniform_init)
91 | scores_b = tf.get_variable("scores_b", [1], "float32", random_uniform_init)
92 |
93 | state = lstm1.zero_state(self.batch_size, 'float32')
94 | start_token = tf.constant(self.START, dtype=tf.int32, shape=[self.batch_size])
95 | # as in VQA models, use the LSTM states (not only the outputs)
96 | state_list = []
97 | for j in range(self.lstm_steps):
98 | if j > 0:
99 | tf.get_variable_scope().reuse_variables()
100 | with tf.device('/cpu:0'):
101 | if j ==0:
102 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token)
103 | else:
104 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, text[:,j-1])
105 | with tf.variable_scope("lstm"):
106 | # scope: "D/lstm"
107 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) # output: B,H
108 | # append states from step 1 onward (the first real word)
109 | if j > 0:
110 | state_list.append(tf.concat(1,[state[0], state[1]]))
111 |
112 | state_list = tf.pack(state_list) # S,B,2H
113 | state_list = tf.transpose(state_list, [1,0,2]) # B,S,2H
114 | state_flatten = tf.reshape(state_list, [-1, 2*self.hidden_size]) # B*S, 2H
115 | # length-1: lengths are 1-based, gather indices are 0-based
116 | idx = tf.range(self.batch_size)*self.max_words + (length-1) # B
117 | state_gather = tf.gather(state_flatten, idx) # B, 2H
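  | # The flatten+gather pattern above selects, for each example b, the row
  | # b*max_words + (length[b]-1), i.e. the LSTM state at the last real token.
  | # E.g. with max_words=4 and length=[2,4] it picks flattened rows 1 and 7.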
118 |
119 | # text embedding
120 | text_emb = tf.matmul(state_gather, text_W) + text_b # B,H
121 | text_emb = tf.nn.tanh(text_emb)
122 | # images embedding
123 | images_emb = tf.matmul(images, images_W) + images_b # B,H
124 | images_emb = tf.nn.tanh(images_emb)
125 | # embed to score
126 | logits = tf.mul(text_emb, images_emb) # B,H
127 | score = tf.matmul(logits, scores_W) + scores_b
128 |
129 | return tf.nn.sigmoid(score), score
130 |
131 | def train(self):
132 |
133 | self.train_op = self.optim.minimize(self.loss, global_step=self.global_step)
134 | self.writer = tf.train.SummaryWriter("./logs/D_pretrained", self.sess.graph)
135 | self.summary_op = tf.merge_all_summaries()
136 | tf.initialize_all_variables().run()
137 | self.saver = tf.train.Saver(var_list=self.D_params_dict, max_to_keep=self.max_to_keep)
138 | count = 0
139 | for idx in range(self.max_iter//3000):
140 | self.save(self.checkpoint_dir, count)
141 | self.evaluate('test', count)
142 | self.evaluate('train', count)
143 | for k in tqdm(range(3000)):
144 | right_images, right_text, _ = self.dataset.sequential_sample(self.batch_size)
145 | right_length = np.sum((right_text!=self.NOT)+0, 1)
146 | fake_images, fake_text, _ = self.negative_dataset.sequential_sample(self.batch_size)
147 | fake_length = np.sum((fake_text!=self.NOT)+0, 1)
148 | wrong_text = self.dataset.get_wrong_text(self.batch_size)
149 | wrong_length = np.sum((wrong_text!=self.NOT)+0, 1)
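  | # Each step builds three pair types: "right" = a matched image/caption
  | # pair from the source dataset; "fake" = a sample from the negative
  | # dataset; "wrong" = a real source caption deliberately mismatched with
  | # the image (the feed below pairs wrong_text with right_images).
  | # Lengths count the non-<NOT> (padding) tokens per caption.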
150 | feed_dict = {self.right_images:right_images, self.right_text:right_text, self.right_length:right_length,
151 | self.fake_images:fake_images, self.fake_text:fake_text, self.fake_length:fake_length,
152 | self.wrong_images:right_images, self.wrong_text:wrong_text, self.wrong_length:wrong_length}
153 | _, loss, summary_str = self.sess.run([self.train_op, self.loss, self.summary_op], feed_dict)
154 | self.writer.add_summary(summary_str, count)
155 | count += 1
156 |
157 | def evaluate(self, split, count):
158 |
159 | if split == 'test':
160 | num_test_pair = -1 # -1: evaluate on the entire test split
161 | elif split == 'train':
162 | num_test_pair = 5000 # evaluate on a 5000-pair subset of the training data
163 | right_images, right_text, _ = self.dataset.get_paired_data(num_test_pair, phase=split)
164 | # get_paired_data may return fewer pairs than requested; use the actual count
165 | num_test_pair = len(right_images)
166 | fake_images, fake_text, _ = self.negative_dataset.get_paired_data(num_test_pair, phase=split)
167 | random_idx = range(num_test_pair)
168 | np.random.shuffle(random_idx)
169 | wrong_text = np.squeeze(right_text[random_idx, :]) # real captions permuted to mismatch the images
170 | D_right_loss_t = []
171 | D_fake_loss_t = []
172 | D_wrong_loss_t = []
173 | D_right_acc_t = []
174 | D_fake_acc_t = []
175 | D_wrong_acc_t = []
176 | count = 0. # running number of evaluated pairs (shadows the unused step argument)
177 | for i in range(num_test_pair//self.batch_size):
178 | right_images_batch = right_images[i*self.batch_size:(i+1)*self.batch_size,:]
179 | fake_images_batch = fake_images[i*self.batch_size:(i+1)*self.batch_size,:]
180 | right_text_batch = right_text[i*self.batch_size:(i+1)*self.batch_size,:]
181 | fake_text_batch = fake_text[i*self.batch_size:(i+1)*self.batch_size,:]
182 | wrong_text_batch = wrong_text[i*self.batch_size:(i+1)*self.batch_size,:]
183 | right_length_batch = np.sum((right_text_batch!=self.NOT)+0, 1)
184 | fake_length_batch = np.sum((fake_text_batch!=self.NOT)+0, 1)
185 | wrong_length_batch = np.sum((wrong_text_batch!=self.NOT)+0, 1)
186 | feed_dict = {self.right_images:right_images_batch, self.right_text:right_text_batch,
187 | self.right_length:right_length_batch, self.fake_images:fake_images_batch,
188 | self.fake_text:fake_text_batch, self.fake_length:fake_length_batch,
189 | self.wrong_images:right_images_batch, self.wrong_text:wrong_text_batch,
190 | self.wrong_length:wrong_length_batch}
191 | D_right, D_fake, D_wrong, D_right_loss, D_fake_loss, D_wrong_loss = self.sess.run([self.D_right, self.D_fake,
192 | self.D_wrong, self.D_right_loss, self.D_fake_loss, self.D_wrong_loss], feed_dict)
193 | D_right_loss_t.append(D_right_loss)
194 | D_fake_loss_t.append(D_fake_loss)
195 | D_wrong_loss_t.append(D_wrong_loss)
196 | D_right_acc_t.append(np.sum((D_right>0.5)+0))
197 | D_fake_acc_t.append(np.sum((D_fake<0.5)+0))
198 | D_wrong_acc_t.append(np.sum((D_wrong<0.5)+0))
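  | # Accuracy convention: a right pair is correct when D > 0.5, a fake or
  | # wrong pair when D < 0.5; the sums are normalized in the prints below.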
199 | count += self.batch_size
200 |
201 | print "Phase =", split.capitalize()
202 | print "======================= Loss ====================="
203 | print '[$] Right Pair Loss =', sum(D_right_loss_t)/len(D_right_loss_t) # entries are per-batch means
204 | print '[$] Wrong Pair Loss =', sum(D_wrong_loss_t)/len(D_wrong_loss_t)
205 | print '[$] Fake Pair Loss =', sum(D_fake_loss_t)/len(D_fake_loss_t)
206 | print "======================= Acc ======================"
207 | print '[$] Right Pair Acc. =', sum(D_right_acc_t)/count # entries are per-batch counts
208 | print '[$] Wrong Pair Acc. =', sum(D_wrong_acc_t)/count
209 | print '[$] Fake Pair Acc. =', sum(D_fake_acc_t)/count
210 |
211 | def save(self, checkpoint_dir, step):
212 | model_name = "D_Pretrained"
213 | model_dir = "%s" % (self.dataset.dataset_name)
214 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir, "D_pretrained")
215 | if not os.path.exists(checkpoint_dir):
216 | os.makedirs(checkpoint_dir)
217 | self.saver.save(self.sess,
218 | os.path.join(checkpoint_dir, model_name),
219 | global_step=step)
220 |
221 | def load(self, checkpoint_dir):
222 | print(" [*] Reading checkpoints...")
223 |
224 | model_dir = "%s" % (self.dataset.dataset_name)
225 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
226 |
227 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
228 | if ckpt and ckpt.model_checkpoint_path:
229 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
230 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
231 | return True
232 | else:
233 | return False
234 |
235 |
--------------------------------------------------------------------------------
/show-adapt-tell/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import time
3 | import json
4 | import h5py
5 | from functools import reduce
6 | from tensorflow.contrib.layers.python.layers import initializers
7 | import cPickle
8 | import numpy as np
9 |
10 |
11 | def load_h5(file):
12 | train_data = {}
13 | with h5py.File(file,'r') as hf:
14 | for k in hf.keys():
15 | tem = hf.get(k)
16 | train_data[k] = np.array(tem)
17 | return train_data
18 |
19 | def load_json(file):
20 | # context manager handles closing; avoid shadowing the builtin `dict`
21 | with open(file, 'r') as fo:
22 | data = json.load(fo)
23 | return data
24 |
25 | def unpickle(file):
26 | # same pattern; pickled files must be opened in binary mode
27 | with open(file, 'rb') as fo:
28 | data = cPickle.load(fo)
29 | return data
30 |
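  | # cPickle is Python 2 only; a rough Python 3 equivalent (untested sketch)
  | # would be:
  | #   import pickle
  | #   with open(file, 'rb') as fo:
  | #       data = pickle.load(fo, encoding='latin1')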
31 | def load_h5py(file, key=None):
32 | if key is not None:
33 | with h5py.File(file, 'r') as hf:
34 | data = np.asarray(hf.get(key)) # materialize before the file closes
35 | return data
36 | else:
37 | print '[-] Cannot load file: no key was given'
38 |
--------------------------------------------------------------------------------