├── LICENSE
├── README.md
├── bipartite_utils.py
├── data
└── README.md
├── eval_utils.py
├── image_feature_extract
├── extract.py
└── make_python_image_info.py
├── image_utils.py
├── model_utils.py
├── paper_commands
├── generate_training_commands_paper.py
├── generate_training_commands_paper_finetune.py
└── generate_training_commands_training_dynamics.py
├── requirements.txt
├── summary.png
├── text_utils.py
├── train_doc.py
├── training_utils.py
├── visualize_predictions_graph.py
└── visualize_predictions_html.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Jack Hessel
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## What's in here?
2 |
3 |
4 | This repository contains code to accompany "Unsupervised Discovery of
5 | Multimodal Links in Multi-image/Multi-sentence Documents." (EMNLP 2019; [link](https://arxiv.org/abs/1904.07826))
6 |
7 |
8 |
9 |
10 |
11 | If you find the code, data, or paper useful, please consider citing
12 | ```
13 | @inproceedings{hessel-lee-mimno-2019unsupervised,
14 | title={Unsupervised Discovery of Multimodal Links in Multi-Image, Multi-Sentence Documents},
15 | author={Hessel, Jack and Lee, Lillian and Mimno, David},
16 | booktitle={EMNLP},
17 | year={2019}
18 | }
19 | ```
20 |
21 | *note: I recently upgraded the implementation of this paper to TF2. If you're interested in the exact code used for the EMNLP paper for reproduction purposes, you should check out the tf1 branch and run with those requirements. However --- I'd highly recommend using the main tf2 branch. It is much faster, and I've been able to reproduce the paper results with it.*
22 |
23 | ## Requirements
24 | This code requires python3 and several python
25 | libraries. You can install the python requirements with:
26 |
27 | ```
28 | pip3 install -r requirements.txt
29 | ```
30 |
31 | Also --- it helps performance to initialize the word embedding
32 | matrices with word2vec embeddings. You can download those embeddings
33 | [here](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit)
34 | (be sure to extract them). When you run the training command, it is
35 | recommended to use the option `--word2vec_binary XXX` where XXX is the
36 | path to the extracted/downloaded word embeddings.
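
To sanity-check that the embeddings downloaded and extracted correctly, a minimal sketch (assuming the binary sits at `data/GoogleNews-vectors-negative300.bin`, as in the training commands below) is:

```
from gensim.models.keyedvectors import KeyedVectors

w2v = KeyedVectors.load_word2vec_format(
    'data/GoogleNews-vectors-negative300.bin', binary=True)
print(w2v['frisbee'].shape)  # expect a 300-dimensional vector
```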
37 |
38 | *A note about evaluating with MT metrics:* the machine translation
39 | metrics, with the exception of sacrebleu, are based on
40 | [pycocoevalcap](https://github.com/salaniz/pycocoevalcap) which itself
41 | has several dependencies. In particular, it requires java 1.8+, and
42 | for permissions to be set so that temporary files can be
43 | written wherever pip installs pycocoevalcap. If you don't have these
44 | additional dependencies, only BLEU will be computed, and a warning will print.
45 |
46 | ## How to run
47 |
48 | ### Preparing the dataset
49 |
50 | The training script takes three inputs:
51 |
52 | 1. A json of training/validation/test documents. This json stores a dictionary with three keys: `train`, `val`, and `test`. Each of the keys maps to a list of documents. A document is a list containing 3 things: `[list_of_images, list_of_sentences, metadata]`.
53 | - `list_of_images` is a list of `(identifier, label_text_idx)` tuples, where the identifier is the name of the image, and `label_text_idx` is an integer indicating the index of the corresponding ground-truth sentence in `list_of_sentences`. If there are no labels in the corpus, this index can be set to `None`. If there are labels, but this particular image doesn't correspond to a sentence, you can set the index to `-1`.
54 | - `list_of_sentences` is a list of `(sentence, label_image_idx)` tuples, where sentence is the sentence, and `label_image_idx` is an integer indicating the index of the corresponding ground-truth image in `list_of_images`. If there are no labels in the corpus, this index can be set to `None`. If there are labels, but this particular image doesn't correspond to a sentence, you can set the index to `-1`.
55 | - `metadata` is an optional document identifier.
56 | 2. A json mapping image ids (see `list_of_images`) to row indices in the features matrix.
57 | 3. An image feature matrix, where `matrix[id2row[img_id]]` is the image feature vector corresponding to the image with image id `img_id` and `id2row` is the dictionary stored in the previously described json mapping file.
58 |
59 | Here is an example document from the MSCOCO dataset.
60 | ```
61 | [[['000000074794', -1],
62 | ['000000339384', 9],
63 | ['000000100064', -1],
64 | ['000000072850', 8],
65 | ['000000046251', -1],
66 | ['000000531828', -1],
67 | ['000000574207', 0],
68 | ['000000185258', 5],
69 | ['000000416357', 1],
70 | ['000000490222', -1]],
71 | [['Two street signs at an intersection on a cloudy day.', -1],
72 | ['A man holding a tennis racquet on a tennis court.', -1],
73 | ['A seagull opens its mouth while standing on a beach.', -1],
74 | ['a man reaching up to hit a tennis ball', -1],
75 | ['A horse sticks his head out of an open stable door. ', -1],
76 | ['Couple standing on a pier with a lot of flags.', -1],
77 | ['A man is riding a skateboard on a ramp.', -1],
78 | ['A man on snow skis leans on his ski poles as he stands in the snow and '
79 | 'gazes into the distance.',
80 | -1],
81 | ['a close up of a baseball player with a ball and glove', -1],
82 | ['four people jumping in the air and reaching for a frisbee.', -1]],
83 | 'na']
84 | ```
85 |
86 | The [image with ID](http://cocodataset.org/#explore?id=339384)
87 | `000000339384` in the MSCOCO dataset corresponds to the sentence with
88 | index 9 in this document, "four people jumping in the air and reaching
89 | for a frisbee.". The underlying graph is undirected, so the labels are
90 | stored only in the image list (though, if you like, you could
91 | redundantly store them on the text side). For the MSCOCO dataset, the
92 | metadata is unused.
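
As a quick sanity check on your own data, a minimal sketch along these lines (the paths are the MSCOCO ones used in the training commands below) verifies that the three inputs line up:

```
import json
import numpy as np

docs = json.load(open('data/mscoco/docs.json'))
id2row = json.load(open('data/mscoco/id2row.json'))
features = np.load('data/mscoco/features.npy')

images, sentences, metadata = docs['train'][0]
image_id, label_text_idx = images[0]
print(features[id2row[image_id]].shape)   # feature vector for the first image
if label_text_idx not in (None, -1):
    print(sentences[label_text_idx][0])   # its ground-truth sentence, if labeled
```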
93 |
94 | The exact train/val/test splits we used, along with pre-extracted
95 | image features, are available for download (see below). You can download
96 | these and extract them in the `data` folder.
97 |
98 | ### Extracting image features for a new dataset
99 |
100 | If you would like to extract image features for a new dataset, there
101 | are a number of existing codebases for that, depending on what neural
102 | network you would like to use. We have included the script that we
103 | used to do that, if you'd like to use ours. In particular, you should:
104 |
105 | 1. Get all of the images of interest into a single folder. Your images should all have unique filenames, as the scripts assume the image filename (minus its extension) is the identifier; e.g., `my_images/000000072850.jpg`'s identifier will be `000000072850`.
106 | 2. Create a text file with the full path of each image, one per line (see the sketch below).
107 | 3. Call `python3 image_feature_extract/extract.py [filenames text file] extracted_features`
109 | 4. Call `python3 image_feature_extract/make_python_image_info.py extracted_features [filenames text file]`
109 |
110 | This will output a feature matrix (in npy format) and an id2row json
111 | file. These are two of the three inputs the training script takes.
112 | Note --- you may need to modify `make_python_image_info.py` if your
113 | images live in different folders, or if you have multiple images with
114 | the same name but different extensions, e.g., `id.jpg` and `id.png`
115 | will both erroneously be mapped to `id`. I may add support for this
116 | later (in addition to cleaning up these scripts...).
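
A minimal sketch of the surrounding steps (writing the filenames file for step 2, then loading the two outputs; `my_images` and `filenames.txt` are just example names) is:

```
import glob, json
import numpy as np

# step 2: one full image path per line
with open('filenames.txt', 'w') as f:
    for path in sorted(glob.glob('my_images/*.jpg')):
        f.write(path + '\n')

# ... run extract.py and make_python_image_info.py as above ...

id2row = json.load(open('id2row.json'))
features = np.load('extracted_features.npy')
print(features[id2row['000000072850']].shape)
```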
117 |
118 | ### Training the model
119 |
120 | An example training command for the mscoco dataset with reasonable settings is:
121 | ```
122 | python3 train_doc.py data/mscoco/docs.json \
123 | --image_id2row data/mscoco/id2row.json \
124 | --image_features data/mscoco/features.npy \
125 | --word2vec_binary data/GoogleNews-vectors-negative300.bin \
126 | --cached_word_embeddings mscoco_cached_word_embs.json \
127 | --print_metrics 1 \
128 | --output mscoco_results.pkl
129 | ```
130 |
131 | Note that even though metrics are printed during training when you use
132 | `--print_metrics 1`, there is no early stopping or other supervision
133 | happening on the labels during training.
134 |
135 | You can run the following to get more information about the available training options:
136 | ```
137 | python3 train_doc.py --help
138 | ```
139 |
140 | From the paper, here's an example of running with hard negative
141 | mining, the AP similarity function, and 20 negative samples
142 | ```
143 | python3 train_doc.py data/mscoco/docs.json --image_id2row data/mscoco/id2row.json \
144 | --image_features data/mscoco/features.npy \
145 | --word2vec_binary data/GoogleNews-vectors-negative300.bin \
146 | --cached_word_embeddings mscoco_cached_word_embs.json \
147 | --print_metrics 1 \
148 | --output mscoco_results.pkl \
149 | --sim_mode AP \
150 | --docs_per_batch 21 \
151 | --cached_vocab mscoco_vocab.json
152 | ```
153 |
154 | ## How to reproduce the results of the paper
155 |
156 | The datasets we use with specific splits/pre-extracted image features
157 | are available for download. If you are just using the datasets, please
158 | cite the original creators of the datasets. Furthermore, all datasets
159 | are subsets of their original creators' releases; please use the
160 | versions from the original links if you are looking for more complete
161 | datasets!
162 |
163 | - MSCOCO ([original source](http://cocodataset.org/#home)) [link](https://drive.google.com/open?id=1LGqUst-BB8N4nFPNGHD0uVa3x_cAZ7UV)
164 | - DII ([original source](http://visionandlanguage.net/VIST/dataset.html)) [link](https://drive.google.com/open?id=1zFouzVhXvnK19zv3AYT-wZJt8SFTcRXY)
165 | - SIS ([original source](http://visionandlanguage.net/VIST/dataset.html)) [link](https://drive.google.com/open?id=1MN6gPGhymAHvPJL6dRTu-VbXfYlI0L7-)
166 | - DII-Stress ([original source](http://visionandlanguage.net/VIST/dataset.html)) [link](https://drive.google.com/open?id=1vLOMftRh8U5r3sn29X2l8XxVXQppsLYS)
167 | - RQA ([original source](https://hucvl.github.io/recipeqa/)) [link](https://drive.google.com/open?id=1BbD1OnV4h02QUk1eZT1hFWWKlDwUyz3O)
168 | - DIY (we collected this dataset) [link](https://drive.google.com/open?id=1EdgL2VYrVTLccP8wHpynpFhv3PNuZiOv)
169 | - WIKI ([original source](https://www.imageclef.org/wikidata)) [link](https://drive.google.com/open?id=1Ecb1LkTXX4sskx-PLB2o3vMru-8I8rEy)
170 |
171 | In addition, we have included scripts that generate the exact training commands executed in the paper itself. These are located in the `paper_commands` directory.
172 | *Note, however, that the code used for the paper is now located in the tf1 branch. The main branch has been ported to TF2.*
173 |
--------------------------------------------------------------------------------
/bipartite_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lapjv import lapjv
3 |
4 | def generate_fast_hungarian_solving_function():
5 | def base_solve(W):
6 | orig_shape = W.shape
7 | if orig_shape[0] != orig_shape[1]:
8 | if orig_shape[0] > orig_shape[1]:
9 | pad_idxs = [[0, 0], [0, W.shape[0]-W.shape[1]]]
10 | col_pad = True
11 | else:
12 | pad_idxs = [[0, W.shape[1]-W.shape[0]], [0, 0]]
13 | col_pad = False
14 | W = np.pad(W, pad_idxs, 'constant', constant_values=-100)
15 | sol, _, cost = lapjv(-W)
16 |
17 | i_s = np.arange(len(sol))
18 | j_s = sol[i_s]
19 |
20 | sort_idxs = np.argsort(-W[i_s, j_s])
21 | i_s, j_s = map(lambda x: x[sort_idxs], [i_s, j_s])
22 |
23 | if orig_shape[0] != orig_shape[1]:
24 | if col_pad:
25 | valid_idxs = np.where(j_s < orig_shape[1])[0]
26 | else:
27 | valid_idxs = np.where(i_s < orig_shape[0])[0]
28 | i_s, j_s = i_s[valid_idxs], j_s[valid_idxs]
29 |
30 | indices = np.hstack([np.expand_dims(i_s, -1), np.expand_dims(j_s, -1)]).astype(np.int32)
31 | return indices
32 |
33 | def hungarian_solve(W, k, max_val=1000):
34 | min_dim = min(*W.shape)
35 | if k <= 0 or k >= min_dim:
36 | return base_solve(W)
37 |
38 | add_rows = W.shape[0] > W.shape[1]
39 | add_len = min_dim-k
40 |
41 | if add_rows:
42 | add_mat = np.zeros((add_len, W.shape[1]))
43 | add_mat[:] = max_val
44 | new_W = np.vstack([W, add_mat])
45 | else:
46 | add_mat = np.zeros((W.shape[0], add_len))
47 | add_mat[:] = max_val
48 | new_W = np.hstack([W, add_mat])
49 |
50 | indices = base_solve(new_W)
51 | indices = indices[-k:, :]
52 | return indices
53 |
54 | return hungarian_solve
55 |
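
# Minimal usage sketch (illustrative only): recover the top-2 highest-weight
# matches from a random 3x5 similarity matrix; k <= 0 (or k >= min(W.shape))
# returns the full one-to-one assignment instead.
if __name__ == '__main__':
    np.random.seed(0)
    solve = generate_fast_hungarian_solving_function()
    W = np.random.randn(3, 5)
    print(solve(W, 2))   # two (row, col) pairs, sorted by descending weight
    print(solve(W, -1))  # all three rows assigned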
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## What goes in here?
2 |
3 | Dataset info (document jsons, image features, image mappings, raw images)
--------------------------------------------------------------------------------
/eval_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from sklearn.metrics import roc_auc_score
4 | import sacrebleu
5 |
6 | # pycocoevalcap imports
7 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
8 | from pycocoevalcap.bleu.bleu import Bleu
9 | from pycocoevalcap.meteor.meteor import Meteor
10 | from pycocoevalcap.rouge.rouge import Rouge
11 | from pycocoevalcap.cider.cider import Cider
12 |
13 |
14 | import collections
15 | import tqdm
16 | import image_utils
17 | import text_utils
18 | import os
19 | import sklearn.preprocessing
20 | import json
21 |
22 |
23 | def compute_match_metrics_doc(docs,
24 | image_matrix,
25 | image_idx2row,
26 | vocab,
27 | text_trans,
28 | image_trans,
29 | args,
30 | ks=[1,2,3,4,5,6,7,8,9,10]):
31 |
32 | all_aucs, all_match_metrics = [], collections.defaultdict(list)
33 | all_texts, all_images = [], []
34 |
35 |
36 | #for MT metrics
37 | all_refs = []
38 | all_sys = []
39 |
40 | # suppress a warning that I believe is thrown by keras' timedistributed
41 | # over dynamic tensors...
42 | tf.get_logger().setLevel('ERROR')
43 | for images, text, meta in tqdm.tqdm(docs):
44 | # text by image matrix
45 | cur_text = [t[0] for t in text]
46 | cur_images = [t[0] for t in images]
47 | n_text, n_image = map(len, [cur_text, cur_images])
48 |
49 | cur_text = text_utils.text_to_matrix(cur_text, vocab, max_len=args.seq_len)
50 |
51 | if args.end2end:
52 | cur_images = image_utils.images_to_images(cur_images, False, args)
53 | else:
54 | cur_images = image_utils.images_to_matrix(cur_images, image_matrix, image_idx2row)
55 |
56 | cur_text = np.expand_dims(cur_text, 0)
57 | cur_images = np.expand_dims(cur_images, 0)
58 |
59 | text_vec = text_trans.predict(cur_text)
60 | image_vec = image_trans.predict(cur_images)
61 |
62 | text_vec = text_vec[0,:n_text,:]
63 | image_vec = image_vec[0,:n_image,:]
64 |
65 | pred_adj = text_vec.dot(image_vec.transpose())
66 | true_adj = np.zeros((len(text), len(images)))
67 |
68 | # for MT metrics, for each image with a ground-truth sentence,
69 | # extract predicted sentences.
70 | im2best_text_idxs = np.argmax(pred_adj, axis=0)
71 | im2all_predicted_captions = [text[idx][0] for idx in im2best_text_idxs]
72 |
73 | im2predicted_captions = {}
74 | im2ground_truth_captions = collections.defaultdict(list)
75 |
76 | for text_idx, t in enumerate(text):
77 | if t[1] == -1: continue
78 | true_adj[text_idx, t[1]] = 1
79 | for image_idx, t in enumerate(images):
80 | if t[1] == -1: continue
81 | true_adj[t[1], image_idx] = 1
82 |
83 | if np.sum(true_adj.flatten()) == 0:
84 | continue
85 | if np.sum(true_adj.flatten()) == len(true_adj.flatten()):
86 | continue
87 |
88 | for text_gt_idx, image_gt_idx in zip(*np.where(true_adj==1)):
89 | im2predicted_captions[image_gt_idx] = im2all_predicted_captions[image_gt_idx]
90 | im2ground_truth_captions[image_gt_idx].append(text[text_gt_idx][0])
91 |
92 | for img_idx, pred in im2predicted_captions.items():
93 | all_refs.append(im2ground_truth_captions[img_idx])
94 | all_sys.append(pred)
95 |
96 | pred_adj = pred_adj.flatten()
97 | true_adj = true_adj.flatten()
98 |
99 |
100 | pred_order = true_adj[np.argsort(-pred_adj)]
101 | for k in ks:
102 | all_match_metrics[k].append(np.mean(pred_order[:k]))
103 |
104 | all_aucs.append(roc_auc_score(true_adj, pred_adj))
105 | tf.get_logger().setLevel('INFO')
106 |
107 | # give each instance a unique IDX for the metric computation...
108 | all_refs = {idx: refs for idx, refs in enumerate(all_refs)}
109 | all_sys = {idx: pred for idx, pred in enumerate(all_sys)}
110 |
111 | if len(all_refs) > 0:
112 | all_mt_metrics = compute_mt_metrics(all_sys, all_refs, args)
113 | else:
114 | all_mt_metrics = {}
115 | return all_aucs, all_match_metrics, all_mt_metrics
116 |
117 |
118 | def compute_mt_metrics(all_sys, all_refs, args):
119 | '''
120 | We need dictionaries mapping:
121 | all_sys maps {unique_idx --> pred}
122 | all_refs maps {unique_idx --> [ground truths]}
123 | '''
124 | res_dict = {}
125 |
126 | # need all cases to have the same number of references for
127 | # sacrebleu. however --- this will not always be the case in our
128 | # data, e.g., if an image has two ground truths. If there's an
129 | # image with 3 ground truth links, we can repeat ground truths
130 | # until all have 3. However --- this duplication modifies some
131 | # of the corpus-level normalizations, in particular, the number
132 | # of ground truth tokens. So --- we should prefer the MSCOCO
133 | # bleu scorer in this case. However --- we'll still compute
134 | # the sacrebleu scores anyway, but include a flag that says
135 | # to not trust them.
136 |
137 | # add in repeated references for cases with less than
138 | # the maximum number of references:
139 |
140 | n_refs_max = np.max([len(r) for r in all_refs.values()])
141 | n_refs_min = np.min([len(r) for r in all_refs.values()])
142 |
143 | print('Using {} maximum references for MT metrics'.format(n_refs_max))
144 |
145 | trust_sacrebleu = n_refs_max == n_refs_min
146 |
147 | all_refs_sacrebleu = []
148 | for outer_idx in range(n_refs_max):
149 | cur_refs = [all_refs[inner_idx][min(outer_idx, len(all_refs[inner_idx])-1)]
150 | for inner_idx in range(len(all_refs))]
151 | all_refs_sacrebleu.append(cur_refs)
152 |
153 | sacre_bleu = sacrebleu.corpus_bleu([all_sys[idx] for idx in range(len(all_sys))], all_refs_sacrebleu)
154 | res_dict['sacre_bleu'] = sacre_bleu.score
155 | res_dict['can_trust_sacrebleu_with_global_counts'] = trust_sacrebleu
156 |
157 | if not args.compute_mscoco_eval_metrics:
158 | return res_dict
159 |
160 | try:
161 | tokenizer = PTBTokenizer()
162 | coco_sys = {idx: [{'caption': ' '.join(r.split())}] for idx, r in all_sys.items()}
163 | coco_sys = tokenizer.tokenize(coco_sys)
164 |
165 | coco_ref = {idx: [{'caption': ' '.join(r.split())} for r in refs] for idx, refs in all_refs.items()}
166 | coco_ref = tokenizer.tokenize(coco_ref)
167 |
168 | scorers = [
169 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
170 | (Meteor(), "METEOR"),
171 | (Rouge(), "ROUGE_L"),
172 | (Cider(), "CIDEr")
173 | ]
174 |
175 | for scorer, method in scorers:
176 | score, _ = scorer.compute_score(coco_ref, coco_sys)
177 | if type(method) is list:
178 | for s, m in zip(score, method):
179 | res_dict['MSCOCO_{}'.format(m)] = s
180 | else:
181 | res_dict['MSCOCO_{}'.format(method)] = score
182 | except Exception as e:
183 | print('Unable to compute MSCOCO metrics: {}'.format(e))
184 | print('continuing nonetheless')
185 |
186 | return res_dict
187 |
188 |
189 | def compute_feat_spread(feats_in):
190 | feats_in = sklearn.preprocessing.normalize(feats_in)
191 | mean = np.expand_dims(np.mean(feats_in, axis=0), 0)
192 | feats_in = feats_in - mean
193 | norms = np.linalg.norm(feats_in, axis=1)
194 | norms = norms * norms
195 | return np.mean(norms)
196 |
197 |
198 | def save_predictions(docs,
199 | image_matrix,
200 | image_idx2row,
201 | vocab,
202 | text_trans,
203 | image_trans,
204 | output_dir,
205 | args,
206 | limit=None):
207 |
208 | if limit is not None:
209 | docs = docs[:limit]
210 |
211 | for idx, (images, text, meta) in tqdm.tqdm(enumerate(docs)):
212 | # text by image matrix
213 | identifier = meta if meta else str(idx)
214 | identifier = identifier.replace('/', '_')
215 |
216 | cur_text = [t[0] for t in text]
217 | cur_images = [t[0] for t in images]
218 |
219 | n_text, n_image = map(len, [cur_text, cur_images])
220 |
221 | cur_text = text_utils.text_to_matrix(cur_text, vocab, max_len=args.seq_len)
222 |
223 | if args.end2end:
224 | cur_images = image_utils.images_to_images(cur_images, False, args)
225 | else:
226 | cur_images = image_utils.images_to_matrix(cur_images, image_matrix, image_idx2row)
227 |
228 | cur_images = np.expand_dims(cur_images, 0)
229 | cur_text = np.expand_dims(cur_text, 0)
230 |
231 | text_vec = text_trans.predict(cur_text)
232 | image_vec = image_trans.predict(cur_images)
233 |
234 | text_vec = text_vec[0,:n_text,:]
235 | image_vec = image_vec[0,:n_image,:]
236 |
237 | text_spread, image_spread = map(compute_feat_spread,
238 | [text_vec, image_vec])
239 |
240 | cur_out_dir = output_dir + '/{}_textspread_{:.4f}_imagespread_{:.4f}/'.format(
241 | identifier, text_spread, image_spread)
242 |
243 | if not os.path.exists(cur_out_dir):
244 | os.makedirs(cur_out_dir)
245 | pred_adj = text_vec.dot(image_vec.transpose())
246 | np.save(cur_out_dir + '/pred_weights.npy',
247 | pred_adj)
248 | with open(cur_out_dir + '/doc.json', 'w') as f:
249 | f.write(json.dumps((images, text, meta)))
250 |
251 |
252 | def print_all_metrics(data,
253 | image_features,
254 | image_idx2row,
255 | word2idx,
256 | single_text_doc_model,
257 | single_img_doc_model,
258 | args):
259 | metrics_dict = {}
260 | aucs, rec2prec, mt_metrics = compute_match_metrics_doc(data,
261 | image_features,
262 | image_idx2row,
263 | word2idx,
264 | single_text_doc_model,
265 | single_img_doc_model,
266 | args)
267 | print('Validation AUC={:.2f}'.format(100*np.mean(aucs)))
268 | metrics_dict['aucs'] = 100 * np.mean(aucs)
269 | prec = {}
270 | prec_str = 'Validation '
271 | ks = sorted(list(rec2prec.keys()))
272 | for k in ks:
273 | res = np.mean(rec2prec[k])*100
274 | metrics_dict['p@{}'.format(k)] = res
275 | prec_str += 'p@{}={:.2f} '.format(k, res)
276 | print(prec_str.strip())
277 | print('Machine translation metrics: {}'.format(str(mt_metrics)))
278 | metrics_dict['mt_metrics'] = mt_metrics
279 | return metrics_dict
280 |
--------------------------------------------------------------------------------
/image_feature_extract/extract.py:
--------------------------------------------------------------------------------
1 | from keras.applications.densenet import preprocess_input as preprocess_input_dn
2 | from keras.applications.densenet import DenseNet169
3 |
4 | from keras.preprocessing import image
5 | from keras.models import Sequential, Model
6 | from keras.layers import Flatten
7 | import os, sys, numpy as np
8 | import keras
9 | import tensorflow as tf
10 | import keras.backend as K
11 |
12 |
13 | def load_images(image_list):
14 | images = []
15 | for i in image_list:
16 | c_img = np.expand_dims(image.img_to_array(
17 | image.load_img(i, target_size = (256, 256))), axis=0)
18 | images.append(c_img)
19 | return np.vstack(images)
20 |
21 |
22 | def image_generator(fnames, batch_size):
23 | cfns = []
24 | for i, p in enumerate(fnames):
25 | cfns.append(p)
26 | if len(cfns) == batch_size:
27 | yield load_images(cfns)
28 | cfns = []
29 | if len(cfns) != 0:
30 | yield load_images(cfns)
31 | cfns = []
32 |
33 |
34 | if __name__ == "__main__":
35 |
36 | if len(sys.argv) != 3:
37 | print("usage: [image list] [output name]")
38 | quit()
39 |
40 |
41 | keras.backend.set_learning_phase(0)
42 | do_center_crop = True
43 |
44 | file_list = sys.argv[1]
45 | fpath = sys.argv[2]
46 |
47 | all_paths = []
48 | with open(file_list) as f:
49 | for i,line in enumerate(f):
50 | all_paths.append(line.strip())
51 |
52 | print("Extracting from {} files".format(len(all_paths)))
53 |
54 | print('Saving to {}'.format(fpath))
55 | base_model = DenseNet169(include_top=False, input_shape=(224,224,3))
56 | base_model.trainable = False
57 | base_model.summary()
58 |
59 | m_image = keras.layers.Input((256, 256, 3), dtype='float32')
60 |
61 | if do_center_crop:
62 | crop = keras.layers.Lambda(lambda x: tf.image.central_crop(x, .875),
63 | output_shape=(224, 224, 3))(m_image)
64 | else:
65 | crop = keras.layers.Lambda(lambda x: tf.map_fn(lambda y: tf.random_crop(y, [224, 224, 3]), x),
66 | output_shape=(224, 224, 3))(m_image)
67 |
68 | crop = keras.layers.Lambda(lambda x: preprocess_input_dn(x))(crop)
69 | trans = base_model(crop)
70 | trans = keras.layers.GlobalAveragePooling2D()(trans)
71 | model = Model(inputs=m_image,
72 | outputs=trans)
73 | model.summary()
74 |
75 | batch_size = 100
76 | print('Getting started...')
77 | gen = image_generator(all_paths, batch_size)
78 |
79 | feats = model.predict_generator(gen,
80 | np.ceil(len(all_paths)/batch_size),
81 | use_multiprocessing=False,
82 | workers=1,
83 | verbose=1)
84 | print(feats.shape)
85 | print("Saving to {}".format(fpath))
86 | with open(fpath, 'w') as f:
87 | for i in range(feats.shape[0]):
88 | f.write(" ".join(["{:.10f}".format(x) for x in feats[i,:]]) + "\n")
89 |
--------------------------------------------------------------------------------
/image_feature_extract/make_python_image_info.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import numpy as np
4 | import tqdm
5 |
6 |
7 | if len(sys.argv) != 3:
8 | print('usage: [text file of features] [filenames text file]')
9 | quit()
10 |
11 |
12 | def load_matrix(fname, n_rows):
13 | vecs = []
14 | with open(fname) as f:
15 | for idx in tqdm.tqdm(range(n_rows)):
16 | vecs.append(np.array([float(x) for x in f.readline().split()]))
17 | return np.vstack(vecs)
18 |
19 |
20 | def load_ordered_ids(fname):
21 | ids = []
22 | with open(fname) as f:
23 | for line in f:
24 | if line.strip():
25 | ids.append('.'.join(line.split('/')[-1].split('.')[:-1]))
26 | return ids
27 |
28 |
29 | ordered_ids = load_ordered_ids(sys.argv[2])
30 | id2row = {str(idx): row for row, idx in enumerate(ordered_ids)}
31 | print('loading {} rows of matrix...'.format(len(ordered_ids)))
32 | m_mat = load_matrix(sys.argv[1], len(ordered_ids))
33 |
34 | with open('id2row.json', 'w') as f:
35 | f.write(json.dumps(id2row))
36 |
37 | np.save(sys.argv[1].split('.')[0], m_mat)
38 |
39 |
--------------------------------------------------------------------------------
/image_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.keras.preprocessing import image
4 | from tensorflow.keras.applications.nasnet import preprocess_input
5 | import scipy.misc
6 |
7 | data_augmentor = image.ImageDataGenerator(rotation_range=20,
8 | horizontal_flip=True,
9 | fill_mode='reflect')
10 |
11 | def crop(image_in, targ, center=False):
12 | x, y = image_in.shape[:2]
13 | if center:
14 | x_off, y_off = (x-targ) // 2, (y-targ) // 2
15 | else:
16 | x_off, y_off = np.random.randint(0, (x-targ)), np.random.randint(0, (y-targ))
17 | res = image_in[x_off:x_off+targ, y_off:y_off+targ, ...]
18 | return res
19 |
20 |
21 | def augment_image(image_in):
22 | image_out = data_augmentor.random_transform(image_in)
23 | return np.expand_dims(crop(image_out, 224), axis=0)
24 |
25 |
26 | def images_to_matrix(image_list, image_matrix, image_idx2row):
27 | rows = []
28 | for img in image_list:
29 | rows.append(image_idx2row[str(img)])
30 |
31 | if type(image_matrix) is list:
32 | image_matrix_choices = np.random.choice(len(image_matrix), size=len(rows))
33 | all_features = [image_matrix[image_matrix_choices[idx]][r,:] for idx, r in enumerate(rows)]
34 | return np.vstack(all_features)
35 | else:
36 | return image_matrix[np.array(rows), :]
37 |
38 |
39 | def load_images(image_list, preprocess=True, target_size=(224, 224)):
40 | images = []
41 | for i in image_list:
42 | c_img = np.expand_dims(image.img_to_array(
43 | image.load_img(i, target_size=target_size)), axis=0)
44 | images.append(c_img)
45 | images = np.vstack(images)
46 | if preprocess:
47 | images = preprocess_input(images)
48 | return images
49 |
50 |
51 | def images_to_images(image_list, augment, args):
52 | if args.full_image_paths:
53 | images = load_images([args.image_dir + '/' + img for img in image_list], target_size=(256, 256))
54 | else:
55 | images = load_images([args.image_dir + '/' + img + '.jpg' for img in image_list], target_size=(256, 256))
56 | if augment:
57 | images = np.vstack(list(map(augment_image, images)))
58 | else:
59 | images = np.vstack(list(map(lambda x: np.expand_dims(crop(x, 224, center=True), axis=0), images)))
60 | return images
61 |
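
# Minimal usage sketch, assuming the pre-extracted MSCOCO features from the
# README have been downloaded into data/mscoco/: look up the feature rows for
# two image ids via the id2row mapping.
if __name__ == '__main__':
    import json
    features = np.load('data/mscoco/features.npy')
    id2row = json.load(open('data/mscoco/id2row.json'))
    feats = images_to_matrix(['000000339384', '000000072850'], features, id2row)
    print(feats.shape)  # (2, feature_dim)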
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Model utility functions
3 | '''
4 | import tensorflow as tf
5 | import collections
6 | import sys
7 |
8 |
9 | def make_get_pos_neg_sims(args, sim_fn):
10 |
11 | def get_pos_neg_sims(inp):
12 | '''
13 | Applies the similarity function between all text_idx, img_idx pairs.
14 |
15 | inp is a list of three arguments:
16 | - sims: the stacked similarity matrix
17 | - text_n_inp: how many sentences are in each document
18 | - img_n_inp: how many images are in each document
19 | '''
20 |
21 | sims, text_n_inp, img_n_inp = inp
22 | text_index_borders = tf.dtypes.cast(tf.cumsum(text_n_inp), tf.int32)
23 | img_index_borders = tf.dtypes.cast(tf.cumsum(img_n_inp), tf.int32)
24 | zero = tf.expand_dims(tf.expand_dims(tf.constant(0, dtype=tf.int32), axis=-1), axis=-1)
25 |
26 | # these give the indices of the borders between documents in our big sim matrix...
27 | text_index_borders = tf.concat([zero, text_index_borders], axis=0)
28 | img_index_borders = tf.concat([zero, img_index_borders], axis=0)
29 |
30 | doc2pos_sim = {}
31 | doc2neg_img_sims = collections.defaultdict(list)
32 | doc2neg_text_sims = collections.defaultdict(list)
33 |
34 | # for each pair of text set and image set...
35 | for text_idx in range(args.docs_per_batch):
36 | for img_idx in range(args.docs_per_batch):
37 | text_start = tf.squeeze(text_index_borders[text_idx])
38 | text_end = tf.squeeze(text_index_borders[text_idx+1])
39 | img_start = tf.squeeze(img_index_borders[img_idx])
40 | img_end = tf.squeeze(img_index_borders[img_idx+1])
41 | cur_sims = sims[text_start:text_end, img_start:img_end]
42 | sim = sim_fn(cur_sims)
43 | if text_idx == img_idx:
44 | doc2pos_sim[text_idx] = sim
45 | else: # negative cases
46 | doc2neg_img_sims[text_idx].append(sim)
47 | doc2neg_text_sims[img_idx].append(sim)
48 |
49 | pos_sims, neg_img_sims, neg_text_sims = [], [], []
50 | for idx in range(args.docs_per_batch):
51 | pos_sims.append(doc2pos_sim[idx])
52 | neg_img_sims.append(tf.stack(doc2neg_img_sims[idx]))
53 | neg_text_sims.append(tf.stack(doc2neg_text_sims[idx]))
54 |
55 | pos_sims = tf.expand_dims(tf.stack(pos_sims), -1)
56 | neg_img_sims = tf.stack(neg_img_sims)
57 | neg_text_sims = tf.stack(neg_text_sims)
58 |
59 | return [pos_sims, neg_img_sims, neg_text_sims]
60 |
61 | return get_pos_neg_sims
62 |
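
# Minimal usage sketch (illustrative only; this args object and sim_fn are
# stand-ins for what train_doc.py builds): split a stacked sentence-by-image
# similarity matrix for a 2-document batch into positive and negative
# document-level similarities.
if __name__ == '__main__':
    import argparse
    args = argparse.Namespace(docs_per_batch=2)
    # toy similarity function: mean over each sentence's best-matching image
    sim_fn = lambda m: tf.reduce_mean(tf.reduce_max(m, axis=1))
    get_sims = make_get_pos_neg_sims(args, sim_fn)
    sims = tf.random.uniform((5, 6))   # 2 + 3 sentences by 4 + 2 images, stacked
    text_n = tf.constant([[2], [3]])   # sentences per document
    img_n = tf.constant([[4], [2]])    # images per document
    pos, neg_img, neg_text = get_sims([sims, text_n, img_n])
    print(pos.shape, neg_img.shape, neg_text.shape)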
--------------------------------------------------------------------------------
/paper_commands/generate_training_commands_paper.py:
--------------------------------------------------------------------------------
1 | '''
2 | This script generates the training commands for the results reported in the paper
3 | '''
4 | import collections
5 | import numpy as np
6 | import os
7 |
8 | np.random.seed(1)
9 | if not os.path.exists('paper_res'):
10 | os.makedirs('paper_res')
11 | if not os.path.exists('paper_checkpoints'):
12 | os.makedirs('paper_checkpoints')
13 |
14 | n_parallel = 4
15 |
16 | all_commands = []
17 |
18 | for batch_size in [11]:#21,31]:
19 | for dset in ['diy', 'mscoco', 'rqa', 'sis', 'dii-r', 'wiki']:
20 | for alg in ['DC','TK','AP','NoStruct']:
21 | for k in [-1] if alg in ['DC', 'NoStruct'] else [-1, 2]:
22 | for neg_mining in ['negative_sample', 'hard_negative']:
23 |
24 | # for wiki, just do limited settings
25 | if dset == 'wiki' and (alg not in ['AP', 'NoStruct'] or k != -1 or neg_mining != 'hard_negative' or batch_size != 11): continue
26 |
27 | identifier_str = '{}+{}+{}+{}+{}'.format(dset, batch_size, alg, k, neg_mining)
28 | output_f = 'paper_res/{}.pkl'.format(identifier_str)
29 |
30 | cmd = 'python3 train_doc.py data/{}/docs.json --image_id2row data/{}/id2row.json --image_features data/{}/features.npy '.format(
31 | dset, dset, dset)
32 | if alg == 'NoStruct':
33 | cmd += '--subsample_image 1 --subsample_text 1 '
34 |
35 | if dset in ['rqa','diy','wiki']:
36 | cmd += '--seq_len 50 '
37 |
38 | cmd += '--gpu_memory_frac .45 --sim_mode {} --neg_mining {} --sim_mode_k {} '.format(alg if alg != 'NoStruct' else 'DC', neg_mining, k)
39 | cmd += '--cached_vocab {}_vocab.json --word2vec_binary GoogleNews-vectors-negative300.bin '.format(dset)
40 | cmd += '--cached_word_embeddings {} --output {} '.format(dset, output_f)
41 | cmd += '--checkpoint_dir {}/{} '.format('paper_checkpoints', identifier_str)
42 |
43 | if os.path.exists(output_f): continue
44 | all_commands.append(cmd)
45 |
46 | files = [open('{}_commands.txt'.format(idx+1), 'w') for idx in range(n_parallel)]
47 | idx = 0
48 | for cmd in all_commands:
49 | files[idx].write(cmd + '\n')
50 | idx += 1
51 | idx = idx % n_parallel
52 |
53 | for f in files:
54 | f.close()
55 |
--------------------------------------------------------------------------------
/paper_commands/generate_training_commands_paper_finetune.py:
--------------------------------------------------------------------------------
1 | '''
2 | This script generates the training commands for the results reported in the paper
3 | '''
4 | import collections
5 | import numpy as np
6 | import os
7 |
8 | np.random.seed(1)
9 | if not os.path.exists('paper_res'):
10 | os.makedirs('paper_res')
11 | if not os.path.exists('paper_checkpoints'):
12 | os.makedirs('paper_checkpoints')
13 |
14 | n_parallel = 1
15 |
16 | all_commands = []
17 |
18 | for batch_size, dset, alg, subsample_img, subsample_txt, neg_mining in [
19 | (11, 'rqa', 'AP', 10, 10, 'hard_negative'),
20 | (11, 'diy', 'AP', 10, 10, 'hard_negative'),
21 | (11, 'wiki', 'AP', 10, 10, 'hard_negative')]:
22 | identifier_str = '{}+{}+{}+{}+{}+FT'.format(dset, batch_size, alg, -1, neg_mining)
23 | output_f = 'paper_res/{}.pkl'.format(identifier_str)
24 | #if os.path.exists(output_f): continue
25 | if 'wiki' != dset:
26 | cmd = 'python3 train_doc.py data/{}/docs.json --image_dir data/{}/images '.format(dset, dset)
27 | else:
28 | cmd = 'python3 train_doc.py data/{}/docs.json --image_dir data/{} '.format(dset, dset)
29 | cmd += '--sim_mode {} --neg_mining {} --sim_mode_k {} '.format(alg, neg_mining, -1)
30 | cmd += '--cached_vocab {}_vocab.json --word2vec_binary GoogleNews-vectors-negative300.bin '.format(dset)
31 | cmd += '--cached_word_embeddings {} --output {} '.format(dset, output_f)
32 | cmd += '--checkpoint_dir {}/{} '.format('paper_checkpoints', identifier_str)
33 | cmd += '--subsample_text {} '.format(subsample_txt)
34 | cmd += '--subsample_image {} '.format(subsample_img)
35 | cmd += '--end2end 1 '
36 | cmd += '--seq_len 50 '
37 | cmd += '--force 1 '
38 | if dset == 'wiki':
39 | cmd += '--full_image_paths 1'
40 | all_commands.append(cmd)
41 |
42 | files = [open('{}_commands_FT.txt'.format(idx+1), 'w') for idx in range(n_parallel)]
43 | idx = 0
44 | for cmd in all_commands:
45 | files[idx].write(cmd + '\n')
46 | idx += 1
47 | idx = idx % n_parallel
48 |
49 | for f in files:
50 | f.close()
51 |
--------------------------------------------------------------------------------
/paper_commands/generate_training_commands_training_dynamics.py:
--------------------------------------------------------------------------------
1 | '''
2 | This script generates the training commands for the results reported in the paper
3 | '''
4 | import collections
5 | import numpy as np
6 | import os
7 |
8 | np.random.seed(1)
9 | if not os.path.exists('paper_res'):
10 | os.makedirs('paper_res')
11 | if not os.path.exists('paper_checkpoints'):
12 | os.makedirs('paper_checkpoints')
13 |
14 | n_parallel = 4
15 |
16 | all_commands = []
17 |
18 | for batch_size in [11]:
19 | for dset in ['dii', 'diy', 'mscoco', 'rqa', 'sis', 'dii-r']:
20 | for alg in ['AP']:
21 | for neg_mining in ['hard_negative']:
22 | k = -1
23 | identifier_str = '{}+{}+{}+{}+{}+DYNAMICS'.format(dset, batch_size, alg, k, neg_mining)
24 | output_f = 'paper_res/{}.pkl'.format(identifier_str)
25 | cmd = 'python3 train_doc.py data/{}/docs.json --image_id2row data/{}/id2row.json --image_features data/{}/features.npy '.format(
26 | dset, dset, dset)
27 | if dset in ['rqa','diy',]:
28 | cmd += '--seq_len 50 '
29 | cmd += '--print_metrics 1 '
30 | cmd += '--gpu_memory_frac .45 --sim_mode {} --neg_mining {} --sim_mode_k {} '.format(alg if alg != 'NoStruct' else 'DC', neg_mining, k)
31 | cmd += '--cached_vocab {}_vocab.json --word2vec_binary GoogleNews-vectors-negative300.bin '.format(dset)
32 | cmd += '--cached_word_embeddings {} --output {} '.format(dset, output_f)
33 | cmd += '--checkpoint_dir {}/{} '.format('paper_checkpoints', identifier_str)
34 |
35 | if os.path.exists(output_f): continue
36 | all_commands.append(cmd)
37 |
38 | files = [open('{}_commands_training_dynamics.txt'.format(idx+1), 'w') for idx in range(n_parallel)]
39 | idx = 0
40 | for cmd in all_commands:
41 | files[idx].write(cmd + '\n')
42 | idx += 1
43 | idx = idx % n_parallel
44 |
45 | for f in files:
46 | f.close()
47 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/src-d/lapjv
2 | numpy
3 | scipy
4 | scikit-learn
5 | tqdm
6 | spacy
7 | gensim
8 | nltk
9 | sacrebleu
10 | tensorflow>=2.1
11 | git+https://github.com/jmhessel/pycocoevalcap
12 |
--------------------------------------------------------------------------------
/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmhessel/multi-retrieval/20178dfda368ed4117f3aaee7b0bd2946fcdf9eb/summary.png
--------------------------------------------------------------------------------
/text_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import collections
4 | import tensorflow as tf
5 | import tqdm
6 | import numpy as np
7 | import pprint
8 | import warnings
9 |
10 | from gensim.models.keyedvectors import KeyedVectors
11 |
12 | from nltk.tokenize import TweetTokenizer
13 | global _TOKENIZER
14 | _TOKENIZER = TweetTokenizer()
15 |
16 |
17 | def preprocess_caption(cap_in):
18 | return ' '.join(_TOKENIZER.tokenize(cap_in)).lower()
19 |
20 |
21 | def get_vocab(data, min_count=5, cached=None):
22 | if cached is None or not os.path.exists(cached):
23 | voc_counter = collections.Counter()
24 | for c in tqdm.tqdm(data):
25 | voc_counter.update(preprocess_caption(c).split())
26 | word2idx = {'<pad>': 0, '<unk>': 1}
27 | idx = len(word2idx)
28 | for v, c in sorted(voc_counter.items(), key=lambda x: x[1], reverse=True):
29 | if c < min_count:
30 | break
31 | word2idx[v] = idx
32 | idx += 1
33 | if cached is not None:
34 | with open(cached, 'w') as f:
35 | f.write(json.dumps(word2idx))
36 | else:
37 | with open(cached) as f:
38 | word2idx = json.loads(f.read())
39 | return word2idx
40 |
41 |
42 | def get_word2vec_matrix(vocab, cache_file, word2vec_binary):
43 | if cache_file is None and word2vec_binary is None:
44 | return None
45 | if cache_file is None or not os.path.exists(cache_file):
46 | print('Loading word2vec binary...')
47 | word2vec = KeyedVectors.load_word2vec_format(word2vec_binary, binary=True)
48 | word2vec_cachable = {}
49 | for w, idx in vocab.items():
50 | if w in word2vec:
51 | word2vec_cachable[w] = list([float(x) for x in word2vec[w]])
52 | if cache_file is not None:
53 | with open(cache_file, 'w') as f:
54 | f.write(json.dumps(word2vec_cachable))
55 | else:
56 | with open(cache_file) as f:
57 | word2vec_cachable = json.loads(f.read())
58 | word2vec = {w:np.array(v) for w, v in word2vec_cachable.items()}
59 | m_matrix = np.random.uniform(-.2, .2, size=(len(vocab), 300))
60 | for w, idx in vocab.items():
61 | if w in word2vec:
62 | m_matrix[idx, :] = word2vec[w]
63 | return m_matrix
64 |
65 |
66 | def text_to_matrix(captions, vocab, max_len=15, padding='post'):
67 | seqs = []
68 | for c in captions:
69 | tokens = preprocess_caption(c).split()
70 |
71 | # for reasons I don't understand, the new version of CUDNN
72 | # doesn't play nice with padding, etc. After painstakingly
73 | # narrowing down why this happens, CUDNN errors when:
74 |
75 | # 1) you're using RNN
76 | # 2) your batch consists of fully-padded sequences and non-padded sequences only
77 |
78 | # I filed a tensorflow issue:
79 | # see https://github.com/tensorflow/tensorflow/issues/36139
80 |
81 | # Upon a tensorflower's reply, CuDNN currently doesn't like empty sequences,
82 | # and the CuDNN kernel only gets called when things are right-padded
83 |
84 | # so, for now, until this bug is fixed, here's how we can still use CuDNN:
85 | # 1) we will right/post-pad
86 | # 2) in the data iterator, for padding sequences, we will prepend with
87 | # an unk. These sentences don't affect the gradient, and we are
88 | # expecting CuDNN to return junk anyway in those cases, so this
89 | # should be fine, but I will experimentally verify
90 | idxs = [vocab[v] if v in vocab else vocab['<unk>'] for v in tokens]
91 |
92 | if len(idxs) == 0:
93 | warnings.warn(
94 | 'Warning: detected at least one zero-length sentence. '
95 | 'Running will continue, but check your inputs.')
96 | idxs = [vocab['<unk>']]
97 |
98 | seqs.append(idxs)
99 | m_mat = tf.keras.preprocessing.sequence.pad_sequences(seqs, maxlen=max_len,
100 | padding=padding, truncating='post',
101 | value=0)
102 | return m_mat
103 |
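
# Minimal usage sketch (illustrative only): build a vocabulary from a couple of
# captions and turn them into a padded matrix of word indices.
if __name__ == '__main__':
    captions = ['A man is riding a skateboard on a ramp.',
                'Two street signs at an intersection on a cloudy day.']
    vocab = get_vocab(captions, min_count=1)
    print(text_to_matrix(captions, vocab, max_len=8))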
--------------------------------------------------------------------------------
/train_doc.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code to accompany
3 | "Unsupervised Discovery of Multimodal Links in Multi-Sentence/Multi-Image Documents."
4 | https://github.com/jmhessel/multi-retrieval
5 |
6 | This is a work-in-progress TF2.0 port.
7 | '''
8 | import argparse
9 | import collections
10 | import json
11 | import tensorflow as tf
12 | import numpy as np
13 | import os
14 | import sys
15 | import tqdm
16 | import text_utils
17 | import image_utils
18 | import eval_utils
19 | import model_utils
20 | import training_utils
21 | import bipartite_utils
22 | import pickle
23 | import sklearn.preprocessing
24 | from pprint import pprint
25 |
26 |
27 | def load_data(fname):
28 | with open(fname) as f:
29 | return json.loads(f.read())
30 |
31 |
32 | def parse_args():
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('documents',
35 | help='json of train/val/test documents.')
36 | parser.add_argument('--image_features',
37 | help='path to pre-extracted image-feature numpy array.')
38 | parser.add_argument('--image_id2row',
39 | help='path to mapping from image id --> numpy row for image features.')
40 | parser.add_argument('--joint_emb_dim',
41 | type=int,
42 | help='Embedding dimension of the shared, multimodal space.',
43 | default=1024)
44 | parser.add_argument('--margin',
45 | type=float,
46 | help='Margin for computing hinge loss.',
47 | default=.2)
48 | parser.add_argument('--seq_len',
49 | type=int,
50 | help='Maximum token sequence length for each sentence before truncation.',
51 | default=20)
52 | parser.add_argument('--docs_per_batch',
53 | type=int,
54 | help='How many docs per batch? 11 docs = 10 negative samples per doc.',
55 | default=11)
56 | parser.add_argument('--neg_mining',
57 | help='What type of negative mining?',
58 | default='hard_negative',
59 | choices=['negative_sample', 'hard_negative'],
60 | type=str)
61 | parser.add_argument('--sim_mode',
62 | help='What similarity function should we use?',
63 | default='AP',
64 | choices=['DC','TK','AP'],
65 | type=str)
66 | parser.add_argument('--sim_mode_k',
67 | help='If --sim_mode=TK/AP, what should the k be? '
68 | 'k=-1 for dynamic k = min(n_images, n_sentences); '
69 | 'if k > 0, then k=ceil(1./k * min(n_images, n_sentences))',
70 | default=-1,
71 | type=float)
72 | parser.add_argument('--lr',
73 | type=float,
74 | help='Starting learning rate',
75 | default=.0002)
76 | parser.add_argument('--n_epochs',
77 | type=int,
78 | help='How many epochs to run for?',
79 | default=60)
80 | parser.add_argument('--checkpoint_dir',
81 | type=str,
82 | help='What directory to save checkpoints in?',
83 | default='checkpoints')
84 | parser.add_argument('--word2vec_binary',
85 | type=str,
86 | help='If cached word embeddings have not been generated, '
87 | 'what is the location of the word2vec binary?',
88 | default=None)
89 | parser.add_argument('--cached_word_embeddings',
90 | type=str,
91 | help='Where are/will the cached word embeddings saved?',
92 | default='cached_word2vec.json')
93 | parser.add_argument('--print_metrics',
94 | type=int,
95 | help='Should we print the metrics if there are ground-truth '
96 | 'labels, or no?',
97 | default=0)
98 | parser.add_argument('--cached_vocab',
99 | type=str,
100 | help='Where should we cache the vocab, if anywhere '
101 | '(None means no caching)',
102 | default=None)
103 | parser.add_argument('--output',
104 | type=str,
105 | default=None,
106 | help='If output is set, we will save a pkl file '
107 | 'with the validation/test metrics.')
108 | parser.add_argument('--seed',
109 | type=int,
110 | help='Random seed',
111 | default=1)
112 | parser.add_argument('--dropout',
113 | type=float,
114 | default=0.5,
115 | help='How much dropout should we apply?')
116 | parser.add_argument('--subsample_image',
117 | type=int,
118 | default=-1,
119 | help='Should we subsample images to constant lengths? '
120 | 'This option is useful if the model is being trained end2end '
121 | 'and there are memory issues.')
122 | parser.add_argument('--subsample_text',
123 | type=int,
124 | default=-1,
125 | help='Should we subsample sentences to constant lengths? '
126 | 'This option is useful if the model is being trained end2end '
127 | 'and there are memory issues.')
128 | parser.add_argument('--rnn_type',
129 | type=str,
130 | default='GRU',
131 | help='What RNN should we use')
132 | parser.add_argument('--end2end',
133 | type=int,
134 | default=0,
135 | help='Should we backprop through the whole vision network?')
136 | parser.add_argument('--image_dir',
137 | type=str,
138 | default=None,
139 | help='What image dir should we use, if end2end?')
140 | parser.add_argument('--lr_patience',
141 | type=int,
142 | default=3,
143 | help='What learning rate patience should we use?')
144 | parser.add_argument('--lr_decay',
145 | type=float,
146 | default=.2,
147 | help='What learning rate decay factor should we use?')
148 | parser.add_argument('--min_lr',
149 | type=float,
150 | default=.0000001,
151 | help='What minimum learning rate should we use?')
152 | parser.add_argument('--full_image_paths',
153 | type=int,
154 | default=0,
155 | help='For end2end training, should we use full image paths '
156 | '(i.e., is the file extension already on images?)?')
157 | parser.add_argument('--test_eval',
158 | type=int,
159 | help='(DEBUG OPTION) If test_eval >= 1, then training '
160 | 'only happens over this many batches',
161 | default=-1)
162 | parser.add_argument('--force',
163 | type=int,
164 | default=0,
165 | help='Should we force the run if the output exists?')
166 | parser.add_argument('--save_predictions',
167 | type=str,
168 | default=None,
169 | help='Should we save the train/val/test predictions? '
170 | 'If so --- they will be saved in this directory.')
171 | parser.add_argument('--image_model_checkpoint',
172 | type=str,
173 | default=None,
174 | help='If set, the image model will be initialized from '
175 | 'this model checkpoint.')
176 | parser.add_argument('--text_model_checkpoint',
177 | type=str,
178 | default=None,
179 | help='If set, the text model will be initialized from '
180 | 'this model checkpoint.')
181 | parser.add_argument('--loss_mode',
182 | help='What loss function should we use?',
183 | default='hinge',
184 | choices=['hinge', 'logistic', 'softmax'],
185 | type=str)
186 | parser.add_argument('--compute_mscoco_eval_metrics',
187 | help='Should we compute the mscoco MT metrics?',
188 | default=0,
189 | type=int)
190 | parser.add_argument('--compute_metrics_train',
191 | help='Should we also compute metrics over the training set?',
192 | default=1,
193 | type=int)
194 | parser.add_argument('--lr_warmup_steps',
195 | help='If positive value, we will warmup the learning rate linearly '
196 | 'over this many steps.',
197 | default=-1,
198 | type=int)
199 | parser.add_argument('--l2_norm',
200 | help='If 1, we will l2 normalize extracted features, else, no normalization.',
201 | default=1,
202 | type=int)
203 | parser.add_argument('--n_layers',
204 | help='How many layers in the encoders?',
205 | default=1,
206 | type=int,
207 | choices=[1,2,3])
208 | parser.add_argument('--scale_image_features',
209 | help='Should we standard scale image features?',
210 | default=0,
211 | type=int)
212 | args = parser.parse_args()
213 |
214 | # check to make sure that various flags are set correctly
215 | if args.end2end:
216 | assert args.image_dir is not None
217 | if not args.end2end:
218 | assert args.image_features is not None and args.image_id2row is not None
219 |
220 | # print out some info about the run's inputs/outputs
221 | if args.output and '.pkl' not in args.output:
222 | args.output += '.pkl'
223 |
224 | if args.output:
225 | print('Output will be saved to {}'.format(args.output))
226 | print('Model checkpoints will be saved in {}'.format(args.checkpoint_dir))
227 |
228 | if args.output and os.path.exists(args.output) and not args.force:
229 | print('{} already done! If you want to force it, set --force 1'.format(args.output))
230 | quit()
231 |
232 | if not os.path.exists(args.checkpoint_dir):
233 | os.makedirs(args.checkpoint_dir)
234 |
235 | if args.save_predictions:
236 | if not os.path.exists(args.save_predictions):
237 | os.makedirs(args.save_predictions)
238 | os.makedirs(args.save_predictions + '/train')
239 | os.makedirs(args.save_predictions + '/val')
240 | os.makedirs(args.save_predictions + '/test')
241 |
242 | return args
243 |
244 |
245 | def main():
246 | args = parse_args()
247 | np.random.seed(args.seed)
248 |
249 | data = load_data(args.documents)
250 | train, val, test = data['train'], data['val'], data['test']
251 |
252 | np.random.shuffle(train); np.random.shuffle(val); np.random.shuffle(test)
253 | max_n_sentence, max_n_image = -1, -1
254 | for d in train + val + test:
255 | imgs, sents, meta = d
256 | max_n_sentence = max(max_n_sentence, len(sents))
257 | max_n_image = max(max_n_image, len(imgs))
258 |
259 | # remove zero image/zero sentence cases:
260 | before_lens = list(map(len, [train, val, test]))
261 |
262 | train = [t for t in train if len(t[0]) > 0 and len(t[1]) > 0]
263 | val = [t for t in val if len(t[0]) > 0 and len(t[1]) > 0]
264 | test = [t for t in test if len(t[0]) > 0 and len(t[1]) > 0]
265 |
266 | after_lens = list(map(len, [train, val, test]))
267 | for bl, al, split in zip(before_lens, after_lens, ['train', 'val', 'test']):
268 | if bl == al: continue
269 | print('Removed {} documents from {} split that had zero images and/or sentences'.format(
270 | bl-al, split))
271 |
272 | print('Max n sentence={}, max n image={}'.format(max_n_sentence, max_n_image))
273 | if args.cached_vocab:
274 | print('Saving/loading vocab from {}'.format(args.cached_vocab))
275 |
276 | # create vocab from training documents:
277 | flattened_train_sents = []
278 | for _, sents, _ in train:
279 | flattened_train_sents.extend([s[0] for s in sents])
280 | word2idx = text_utils.get_vocab(flattened_train_sents, cached=args.cached_vocab)
281 | print('Vocab size was {}'.format(len(word2idx)))
282 |
283 | if args.word2vec_binary:
284 | we_init = text_utils.get_word2vec_matrix(
285 | word2idx, args.cached_word_embeddings, args.word2vec_binary)
286 | else:
287 | we_init = np.random.uniform(low=-.02, high=.02, size=(len(word2idx), 300))
288 |
289 | if args.end2end:
290 | image_features = None
291 | image_idx2row = None
292 | else:
293 | image_features = np.load(args.image_features)
294 | image_idx2row = load_data(args.image_id2row)
295 |
296 | if args.scale_image_features:
297 | ss = sklearn.preprocessing.StandardScaler()
298 | all_train_images = []
299 | for img, txt, cid in train:
300 | all_train_images.extend([x[0] for x in img])
301 | print('standard scaling with {} images total'.format(len(all_train_images)))
302 | all_train_rows = [image_idx2row[cid] for cid in all_train_images]
303 | ss.fit(image_features[np.array(all_train_rows)])
304 | image_features = ss.transform(image_features)
305 |
306 | word_emb_dim = 300
307 |
308 | if val[0][0][0][1] is not None:
309 | ground_truth = True
310 | print('The input has ground truth, so AUC will be computed.')
311 | else:
312 | ground_truth = False
313 |
314 | # Step 1: Specify model inputs/outputs:
315 |
316 | # (n docs, n sent, max n words,)
317 | text_inp = tf.keras.layers.Input((None, args.seq_len))
318 |
319 | # this input tells you how many sentences are really in each doc
320 | text_n_inp = tf.keras.layers.Input((1,), dtype='int32')
321 | if args.end2end:
322 | # (n docs, n image, x, y, color)
323 | img_inp = tf.keras.layers.Input((None, 224, 224, 3))
324 | else:
325 | # (n docs, n image, feature dim)
326 | img_inp = tf.keras.layers.Input((None, image_features.shape[1]))
327 | # this input tells you how many images are really in each doc
328 | img_n_inp = tf.keras.layers.Input((1,), dtype='int32')
329 |
330 | # Step 2: Define transformations to shared multimodal space.
331 |
332 | # Step 2.1: The text model:
333 | if args.text_model_checkpoint:
334 | print('Loading pretrained text model from {}'.format(
335 | args.text_model_checkpoint))
336 | single_text_doc_model = tf.keras.models.load_model(args.text_model_checkpoint)
337 | extracted_text_features = single_text_doc_model(text_inp)
338 | else:
339 | word_embedding = tf.keras.layers.Embedding(
340 | len(word2idx),
341 | word_emb_dim,
342 | weights=[we_init] if we_init is not None else None,
343 | mask_zero=True)
344 | element_dropout = tf.keras.layers.SpatialDropout1D(args.dropout)
345 |
346 | if args.rnn_type == 'GRU':
347 | rnn_maker = tf.keras.layers.GRU
348 | else:
349 | rnn_maker = tf.keras.layers.LSTM
350 |
351 | enc_layers = []
352 | for idx in range(args.n_layers):
353 | if idx == args.n_layers-1:
354 | enc_layers.append(rnn_maker(args.joint_emb_dim))
355 | else:
356 | enc_layers.append(rnn_maker(args.joint_emb_dim, return_sequences=True))
357 |
358 | embedded_text_inp = word_embedding(text_inp)
359 | extracted_text_features = tf.keras.layers.TimeDistributed(element_dropout)(embedded_text_inp)
360 |
361 | for l in enc_layers:
362 | extracted_text_features = tf.keras.layers.TimeDistributed(l)(extracted_text_features)
363 |
364 | # extracted_text_features is now (n docs, max n sentences, multimodal dim)
365 | if args.l2_norm:
366 | l2_norm_layer = tf.keras.layers.Lambda(lambda x: tf.nn.l2_normalize(x, axis=-1))
367 | extracted_text_features = l2_norm_layer(extracted_text_features)
368 |
369 | single_text_doc_model = tf.keras.models.Model(
370 | inputs=text_inp,
371 | outputs=extracted_text_features)
372 |
373 | # Step 2.2: The image model:
374 | if args.image_model_checkpoint:
375 | print('Loading pretrained image model from {}'.format(
376 | args.image_model_checkpoint))
377 | single_img_doc_model = tf.keras.models.load_model(args.image_model_checkpoint)
378 | extracted_img_features = single_img_doc_model(img_inp)
379 | else:
380 | if args.end2end:
381 | img_projection = tf.keras.layers.Dense(args.joint_emb_dim)
382 |             # NASNetMobile CNN backbone with global average pooling encodes each image
383 | cnn = tf.keras.applications.nasnet.NASNetMobile(
384 | include_top=False, input_shape=(224, 224, 3), pooling='avg')
385 |
386 | extracted_img_features = tf.keras.layers.TimeDistributed(cnn)(img_inp)
387 | if args.dropout > 0.0:
388 | extracted_img_features = tf.keras.layers.TimeDistributed(
389 | tf.keras.layers.Dropout(args.dropout))(extracted_img_features)
390 |             extracted_img_features = tf.keras.layers.TimeDistributed(img_projection)(
391 | extracted_img_features)
392 | else:
393 | extracted_img_features = tf.keras.layers.Masking()(img_inp)
394 | if args.dropout > 0.0:
395 | extracted_img_features = tf.keras.layers.TimeDistributed(
396 | tf.keras.layers.Dropout(args.dropout))(extracted_img_features)
397 |
398 | enc_layers = []
399 | for idx in range(args.n_layers):
400 | if idx == args.n_layers-1:
401 | enc_layers.append(tf.keras.layers.Dense(args.joint_emb_dim))
402 | else:
403 | enc_layers.append(tf.keras.layers.Dense(args.joint_emb_dim, activation='relu'))
404 | enc_layers.append(tf.keras.layers.BatchNormalization())
405 |
406 | for l in enc_layers:
407 | extracted_img_features = tf.keras.layers.TimeDistributed(l)(extracted_img_features)
408 |
409 | # extracted_img_features is now (n docs, max n images, multimodal dim)
410 | if args.l2_norm:
411 | l2_norm_layer = tf.keras.layers.Lambda(lambda x: tf.nn.l2_normalize(x, axis=-1))
412 | extracted_img_features = l2_norm_layer(extracted_img_features)
413 |
414 | single_img_doc_model = tf.keras.models.Model(
415 | inputs=img_inp,
416 | outputs=extracted_img_features)
417 |
418 | # Step 3: Extract/stack the non-padding image/sentence representations
419 | def mask_slice_and_stack(inp):
420 | stacker = []
421 | features, n_inputs = inp
422 | n_inputs = tf.dtypes.cast(n_inputs, tf.int32)
423 | # for each document, we will extract the portion of input features that are not padding
424 | # this means, for features[doc_idx], we will take the first n_inputs[doc_idx] rows.
425 | # we stack them into one big array so we can do a big cosine sim dot product between all
426 |     # sentence-image pairs in parallel. We'll slice this array back up later.
427 | for idx in range(args.docs_per_batch):
428 | cur_valid_idxs = tf.range(n_inputs[idx,0])
429 | cur_valid_features = features[idx]
430 | feats = tf.gather(cur_valid_features, cur_valid_idxs)
431 | stacker.append(feats)
432 | return tf.concat(stacker, axis=0)
433 |
434 | # extracted text/img features are (n_docs, max_in_seq, dim)
435 | # we want to compute cosine sims between all (sent, img) pairs quickly
436 | # so we will stack them into new tensors ...
437 | # text_enc has shape (total number of sent in batch, dim)
438 | # img_enc has shape (total number of image in batch, dim)
439 | text_enc = mask_slice_and_stack([extracted_text_features, text_n_inp])
440 | img_enc = mask_slice_and_stack([extracted_img_features, img_n_inp])
441 |
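    # Document-level similarity functions over a single document's
    # (n sentences x n images) cosine-similarity matrix:
    #   DC: mean of each sentence's best image similarity, plus the symmetric
    #       image-to-sentence term.
    #   TK: like DC, but averaging only the top-k per-row/per-column maxima.
    #   AP: solve a bipartite assignment between sentences and images and
    #       average the similarities of the matched pairs.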
442 | def DC_sim(sim_matrix):
443 | text2im_S = tf.reduce_mean(tf.reduce_max(sim_matrix, 1))
444 | im2text_S = tf.reduce_mean(tf.reduce_max(sim_matrix, 0))
445 | return text2im_S + im2text_S
446 |
447 | def get_k(sim_matrix):
448 | k = tf.minimum(tf.shape(sim_matrix)[0], tf.shape(sim_matrix)[1])
449 | if args.sim_mode_k > 0:
450 | k = tf.dtypes.cast(k, tf.float32)
451 |             k = tf.math.ceil(tf.math.divide(k, args.sim_mode_k))
452 | k = tf.dtypes.cast(k, tf.int32)
453 | return k
454 |
455 | def TK_sim(sim_matrix):
456 | k = get_k(sim_matrix)
457 | im2text_S, text2im_S = tf.reduce_max(sim_matrix, 0), tf.reduce_max(sim_matrix, 1)
458 | text2im_S = tf.reduce_mean(tf.math.top_k(text2im_S, k=k)[0], axis=-1)
459 | im2text_S = tf.reduce_mean(tf.math.top_k(im2text_S, k=k)[0], axis=-1)
460 | return text2im_S + im2text_S
461 |
462 | bipartite_match_fn = bipartite_utils.generate_fast_hungarian_solving_function()
463 | def AP_sim(sim_matrix):
464 | k = get_k(sim_matrix)
465 | sol = tf.numpy_function(bipartite_match_fn, [sim_matrix, k], tf.int32)
466 | return tf.reduce_mean(tf.gather_nd(sim_matrix, sol))
467 |
468 | if args.sim_mode == 'DC':
469 | sim_fn = DC_sim
470 | elif args.sim_mode == 'TK':
471 | sim_fn = TK_sim
472 | elif args.sim_mode == 'AP':
473 | sim_fn = AP_sim
474 | else:
475 |         raise NotImplementedError('{} is not an implemented similarity function'.format(args.sim_mode))
476 |
477 | def make_sims(inp):
478 | sims = tf.keras.backend.dot(inp[0], tf.keras.backend.transpose(inp[1]))
479 | return sims
480 |
481 | all_sims = make_sims([text_enc, img_enc])
482 | get_pos_neg_sims = model_utils.make_get_pos_neg_sims(
483 | args,
484 | sim_fn)
485 |
486 | pos_sims, neg_img_sims, neg_text_sims = tf.keras.layers.Lambda(
487 | get_pos_neg_sims)([all_sims, text_n_inp, img_n_inp])
488 |
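    # pos_sims: each document's similarity between its own sentences and images;
    # neg_img_sims / neg_text_sims: similarities computed against negative
    # images / sentences drawn from other documents in the batch (see
    # model_utils.make_get_pos_neg_sims).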
489 | if args.loss_mode == 'hinge':
490 | def per_neg_loss(inp):
491 | pos_s, neg_s = inp
492 | return tf.math.maximum(neg_s - pos_s + args.margin, 0)
493 | elif args.loss_mode == 'logistic':
494 | def per_neg_loss(inp):
495 | pos_s, neg_s = inp
496 | return tf.nn.sigmoid_cross_entropy_with_logits(
497 | labels=tf.ones_like(neg_s),
498 | logits=pos_s - neg_s)
499 | elif args.loss_mode == 'softmax':
500 | def per_neg_loss(inp):
501 | pos_s, neg_s = inp
502 | pos_s -= args.margin
503 | pos_l, neg_l = tf.ones_like(pos_s), tf.zeros_like(neg_s)
504 | return tf.nn.softmax_cross_entropy_with_logits(
505 |                 labels=tf.concat([pos_l, neg_l], axis=1),
506 |                 logits=tf.concat([pos_s, neg_s], axis=1))
507 |
508 | neg_img_losses = per_neg_loss([pos_sims, neg_img_sims])
509 | neg_text_losses = per_neg_loss([pos_sims, neg_text_sims])
510 |
511 | if args.loss_mode != 'softmax':
512 | if args.neg_mining == 'negative_sample':
513 | pool_fn = lambda x: tf.reduce_mean(x, axis=1, keepdims=True)
514 | elif args.neg_mining == 'hard_negative':
515 | pool_fn = lambda x: tf.reduce_max(x, axis=1, keepdims=True)
516 | else:
517 |             raise NotImplementedError('{} is not a valid value for args.neg_mining'.format(
518 | args.neg_mining))
519 |
520 | neg_img_loss = tf.keras.layers.Lambda(pool_fn, name='neg_img')(neg_img_losses)
521 | neg_text_loss = tf.keras.layers.Lambda(pool_fn, name='neg_text')(neg_text_losses)
522 | else:
523 | neg_img_loss = neg_img_losses
524 | neg_text_loss = neg_text_losses
525 |
526 | inputs = [text_inp,
527 | img_inp,
528 | text_n_inp,
529 | img_n_inp]
530 |
531 | model = tf.keras.models.Model(inputs=inputs,
532 | outputs=[neg_img_loss, neg_text_loss])
533 |
534 | opt = tf.keras.optimizers.Adam(args.lr)
535 |
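    # The model's two outputs are already per-document loss terms, so compilation
    # uses an identity "loss" that simply averages the predictions; the all-zero
    # dummy targets produced by DocumentSequence are ignored.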
536 | def identity(y_true, y_pred):
537 | del y_true
538 | return tf.reduce_mean(y_pred, axis=-1)
539 |
540 | model.compile(opt, loss=identity)
541 |
542 | if args.test_eval > 0:
543 | train = train[:args.test_eval * args.docs_per_batch]
544 | val = val[:args.test_eval * args.docs_per_batch]
545 | test = test[:args.test_eval * args.docs_per_batch]
546 |
547 | train_seq = training_utils.DocumentSequence(
548 | train,
549 | image_features,
550 | image_idx2row,
551 | max_n_sentence,
552 | max_n_image,
553 | word2idx,
554 | args=args,
555 | shuffle_docs=True,
556 | shuffle_sentences=False,
557 | shuffle_images=True)
558 |
559 | val_seq = training_utils.DocumentSequence(
560 | val,
561 | image_features,
562 | image_idx2row,
563 | max_n_sentence,
564 | max_n_image,
565 | word2idx,
566 | args=args,
567 | augment=False,
568 | shuffle_sentences=False,
569 | shuffle_docs=False,
570 | shuffle_images=False)
571 |
572 | sdm = training_utils.SaveDocModels(
573 | args.checkpoint_dir,
574 | single_text_doc_model,
575 | single_img_doc_model)
576 |
577 | if args.loss_mode == 'hinge':
578 | val_loss_thresh = 2 * args.margin # constant prediction performance
579 | else:
580 | val_loss_thresh = np.inf
581 |
582 | reduce_lr = training_utils.ReduceLROnPlateauAfterValLoss(
583 | activation_val_loss=val_loss_thresh,
584 | factor=args.lr_decay,
585 | patience=args.lr_patience,
586 | min_lr=args.min_lr,
587 | verbose=True)
588 |
589 | callbacks = [reduce_lr, sdm]
590 |
591 | if args.print_metrics:
592 | metrics_printer = training_utils.PrintMetrics(
593 | val,
594 | image_features,
595 | image_idx2row,
596 | word2idx,
597 | single_text_doc_model,
598 | single_img_doc_model,
599 | args)
600 | callbacks.append(metrics_printer)
601 |
602 | if args.lr_warmup_steps > 0:
603 | warmup_lr = training_utils.LearningRateLinearIncrease(
604 | args.lr,
605 | args.lr_warmup_steps)
606 | callbacks.append(warmup_lr)
607 |
608 | history = model.fit(
609 | train_seq,
610 | epochs=args.n_epochs,
611 | validation_data=val_seq,
612 | callbacks=callbacks)
613 |
614 | if args.output:
615 | best_image_model_str, best_sentence_model_str, best_logs, best_epoch = sdm.best_checkpoints_and_logs
616 |
617 | single_text_doc_model = tf.keras.models.load_model(best_sentence_model_str)
618 |         single_img_doc_model = tf.keras.models.load_model(best_image_model_str)
619 |
620 | if args.scale_image_features:
621 | with open(args.checkpoint_dir + '/image_standardscaler.pkl', 'wb') as f:
622 | pickle.dump(ss, f)
623 |
624 | if ground_truth and args.compute_metrics_train:
625 | train_aucs, train_match_metrics, train_mt_metrics = eval_utils.compute_match_metrics_doc(
626 | train,
627 | image_features,
628 | image_idx2row,
629 | word2idx,
630 | single_text_doc_model,
631 | single_img_doc_model,
632 | args)
633 | else:
634 | train_aucs, train_match_metrics, train_mt_metrics = None, None, None
635 |
636 |
637 | if ground_truth:
638 | val_aucs, val_match_metrics, val_mt_metrics = eval_utils.compute_match_metrics_doc(
639 | val,
640 | image_features,
641 | image_idx2row,
642 | word2idx,
643 | single_text_doc_model,
644 | single_img_doc_model,
645 | args)
646 |
647 | test_aucs, test_match_metrics, test_mt_metrics = eval_utils.compute_match_metrics_doc(
648 | test,
649 | image_features,
650 | image_idx2row,
651 | word2idx,
652 | single_text_doc_model,
653 | single_img_doc_model,
654 | args)
655 |
656 | else:
657 | train_aucs, val_aucs, test_aucs = None, None, None
658 | train_match_metrics, val_match_metrics, test_match_metrics = None, None, None
659 | train_mt_metrics, val_mt_metrics, test_mt_metrics = None, None, None
660 |
661 |
662 | output = {'logs':best_logs,
663 |
664 | 'best_sentence_model_str':best_sentence_model_str,
665 | 'best_image_model_str':best_image_model_str,
666 |
667 | 'train_aucs':train_aucs,
668 | 'train_match_metrics':train_match_metrics,
669 | 'train_mt_metrics':train_mt_metrics,
670 |
671 | 'val_aucs':val_aucs,
672 | 'val_match_metrics':val_match_metrics,
673 | 'val_mt_metrics':val_mt_metrics,
674 |
675 | 'test_aucs':test_aucs,
676 | 'test_match_metrics':test_match_metrics,
677 | 'test_mt_metrics':test_mt_metrics,
678 |
679 | 'args':args,
680 | 'epoch':best_epoch}
681 |
682 | if args.scale_image_features:
683 | output['image_standard_scaler_str'] = args.checkpoint_dir + '/image_standardscaler.pkl'
684 |
685 | for k, v in history.history.items():
686 | output['history_{}'.format(k)] = v
687 | if args.print_metrics:
688 | for k, v in metrics_printer.history.items():
689 | output['metrics_history_{}'.format(k)] = v
690 |
691 | with open(args.output, 'wb') as f:
692 | pickle.dump(output, f, protocol=pickle.HIGHEST_PROTOCOL)
693 | print('saved output to {}'.format(args.output))
694 |
695 | if args.save_predictions:
696 | for d, name in zip([train, val, test], ['train', 'val', 'test']):
697 | out_dir = args.save_predictions + '/' + name
698 | if not os.path.exists(out_dir):
699 | os.makedirs(out_dir)
700 | eval_utils.save_predictions(
701 | d,
702 | image_features,
703 | image_idx2row,
704 | word2idx,
705 | single_text_doc_model,
706 | single_img_doc_model,
707 | out_dir,
708 | args)
709 |
710 |
711 | if __name__ == '__main__':
712 | main()
713 |
--------------------------------------------------------------------------------
/training_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import eval_utils
4 | import text_utils
5 | import image_utils
6 |
7 |
8 | class DocumentSequence(tf.keras.utils.Sequence):
9 | def __init__(self,
10 | data_in,
11 | image_matrix,
12 | image_idx2row,
13 | max_sentences_per_doc,
14 | max_images_per_doc,
15 | vocab,
16 | args=None,
17 | augment=True,
18 | shuffle_sentences=False,
19 | shuffle_images=True,
20 | shuffle_docs=True):
21 | self.data_in = data_in
22 | self.image_matrix = image_matrix
23 | self.image_idx2row = image_idx2row
24 | self.max_sentences_per_doc = max_sentences_per_doc
25 | self.max_images_per_doc = max_images_per_doc
26 | self.vocab = vocab
27 | self.args = args
28 |         self.augment = augment
29 | self.shuffle_sentences = shuffle_sentences
30 | self.shuffle_images = shuffle_images
31 | self.shuffle_docs = shuffle_docs
32 |
33 | def __len__(self):
34 | return int(np.ceil(len(self.data_in) / self.args.docs_per_batch))
35 |
36 | def __getitem__(self, idx):
37 | start = idx * self.args.docs_per_batch
38 | end = (idx + 1) * self.args.docs_per_batch
39 | cur_doc_b = self.data_in[start: end]
40 |
41 |         if idx == len(self) - 1: # the final batch may contain fewer docs than docs_per_batch; wrap around to pad it
42 | docs_to_add = self.args.docs_per_batch - len(cur_doc_b)
43 | cur_doc_b += self.data_in[:docs_to_add]
44 |
45 | images, texts = [], []
46 | image_n_docs, text_n_docs = [], []
47 | for idx, vers in enumerate(cur_doc_b):
48 | cur_images = [img[0] for img in vers[0]]
49 | cur_text = [text[0] for text in vers[1]]
50 |
51 | if self.shuffle_sentences and not (self.args and self.args.subsample_text > 0):
52 | np.random.shuffle(cur_text)
53 |
54 | if self.shuffle_images and not (self.args and self.args.subsample_image > 0):
55 | np.random.shuffle(cur_images)
56 |
57 | if self.args and self.args.subsample_image > 0:
58 | np.random.shuffle(cur_images)
59 | cur_images = cur_images[:self.args.subsample_image]
60 |
61 | if self.args and self.args.subsample_text > 0:
62 | np.random.shuffle(cur_text)
63 | cur_text = cur_text[:self.args.subsample_text]
64 |
65 | if self.args.end2end:
66 |             cur_images = image_utils.images_to_images(cur_images, self.augment, self.args)
67 | if self.args and self.args.subsample_image > 0:
68 | image_padding = np.zeros(
69 | (self.args.subsample_image - cur_images.shape[0], 224, 224, 3))
70 | else:
71 | image_padding = np.zeros(
72 | (self.max_images_per_doc - cur_images.shape[0], 224, 224, 3))
73 | else:
74 | cur_images = image_utils.images_to_matrix(
75 | cur_images, self.image_matrix, self.image_idx2row)
76 | if self.args and self.args.subsample_image > 0:
77 | image_padding = np.zeros(
78 | (self.args.subsample_image - cur_images.shape[0], cur_images.shape[-1]))
79 | else:
80 | image_padding = np.zeros(
81 | (self.max_images_per_doc - cur_images.shape[0], cur_images.shape[-1]))
82 |
83 | cur_text = text_utils.text_to_matrix(cur_text, self.vocab, max_len=self.args.seq_len)
84 | image_n_docs.append(cur_images.shape[0])
85 | text_n_docs.append(cur_text.shape[0])
86 |
87 | if self.args and self.args.subsample_text > 0:
88 | text_padding = np.zeros(
89 | (self.args.subsample_text - cur_text.shape[0], cur_text.shape[-1]))
90 | else:
91 | text_padding = np.zeros(
92 | (self.max_sentences_per_doc - cur_text.shape[0], cur_text.shape[-1]))
93 |
94 |             # cuDNN can't handle empty sequences, so, for now, an UNK token is placed at the start of every padding sequence.
95 | # see comment in text_utils.py for more information.
96 | text_padding[:, 0] = 1
97 |
98 | cur_images = np.vstack([cur_images, image_padding])
99 | cur_text = np.vstack([cur_text, text_padding])
100 |
101 | cur_images = np.expand_dims(cur_images, 0)
102 | cur_text = np.expand_dims(cur_text, 0)
103 |
104 | images.append(cur_images)
105 | texts.append(cur_text)
106 |
107 | images = np.vstack(images)
108 | texts = np.vstack(texts)
109 |
110 | image_n_docs = np.expand_dims(np.array(image_n_docs), -1)
111 | text_n_docs = np.expand_dims(np.array(text_n_docs), -1)
112 |
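        # dummy targets: train_doc.py compiles the model with an identity loss
        # that ignores y_true, so all-zero labels suffice here.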
113 | y = [np.zeros(len(text_n_docs)), np.zeros(len(image_n_docs))]
114 |
115 | return ([texts,
116 | images,
117 | text_n_docs,
118 | image_n_docs], y)
119 |
120 |
121 | def on_epoch_end(self):
122 | if self.shuffle_docs:
123 | np.random.shuffle(self.data_in)
124 |
125 |
126 | class SaveDocModels(tf.keras.callbacks.Callback):
127 |
128 | def __init__(self,
129 | checkpoint_dir,
130 | single_text_doc_model,
131 | single_image_doc_model):
132 | super(SaveDocModels, self).__init__()
133 | self.checkpoint_dir = checkpoint_dir
134 | self.single_text_doc_model = single_text_doc_model
135 | self.single_image_doc_model = single_image_doc_model
136 |
137 |
138 | def on_train_begin(self, logs={}):
139 | self.best_val_loss = np.inf
140 | self.best_checkpoints_and_logs = None
141 |
142 | def on_epoch_end(self, epoch, logs):
143 | if logs['val_loss'] < self.best_val_loss:
144 | print('New best val loss: {:.5f}'.format(logs['val_loss']))
145 | self.best_val_loss = logs['val_loss']
146 | else:
147 | return
148 | image_model_str = self.checkpoint_dir + '/image_model_epoch_{}_val={:.5f}.model'.format(epoch, logs['val_loss'])
149 | sentence_model_str = self.checkpoint_dir + '/text_model_epoch_{}_val={:.5f}.model'.format(epoch, logs['val_loss'])
150 | self.best_checkpoints_and_logs = (image_model_str, sentence_model_str, logs, epoch)
151 |
152 | self.single_text_doc_model.save(sentence_model_str, overwrite=True, save_format='h5')
153 | self.single_image_doc_model.save(image_model_str, overwrite=True, save_format='h5')
154 |
155 |
156 |
157 | class ReduceLROnPlateauAfterValLoss(tf.keras.callbacks.ReduceLROnPlateau):
158 | '''
159 | Delays the normal operation of ReduceLROnPlateau until the validation
160 | loss reaches a given value.
161 | '''
162 | def __init__(self, activation_val_loss=np.inf, *args, **kwargs):
163 | super(ReduceLROnPlateauAfterValLoss, self).__init__(*args, **kwargs)
164 | self.activation_val_loss = activation_val_loss
165 | self.val_threshold_activated = False
166 |
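    # ReduceLROnPlateau skips its plateau bookkeeping while in_cooldown() is
    # True, so reporting "cooldown" until val_loss first drops below
    # activation_val_loss delays LR reduction without reimplementing the parent.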
167 | def in_cooldown(self):
168 | if not self.val_threshold_activated: # check to see if we should activate
169 | if self.current_logs['val_loss'] < self.activation_val_loss:
170 | print('Current validation loss ({}) less than activation val loss ({})'.
171 | format(self.current_logs['val_loss'],
172 | self.activation_val_loss))
173 | print('Normal operation of val LR reduction started.')
174 | self.val_threshold_activated = True
175 | self._reset()
176 |
177 | return self.cooldown_counter > 0 or not self.val_threshold_activated
178 |
179 | def on_epoch_end(self, epoch, logs=None):
180 | self.current_logs = logs
181 | super(ReduceLROnPlateauAfterValLoss, self).on_epoch_end(epoch, logs=logs)
182 |
183 |
184 | class PrintMetrics(tf.keras.callbacks.Callback):
185 | def __init__(self,
186 | val,
187 | image_features,
188 | image_idx2row,
189 | word2idx,
190 | single_text_doc_model,
191 | single_img_doc_model,
192 | args):
193 | super(PrintMetrics, self).__init__()
194 | self.val = val
195 | self.image_features = image_features
196 | self.image_idx2row = image_idx2row
197 | self.word2idx = word2idx
198 | self.single_text_doc_model = single_text_doc_model
199 | self.single_img_doc_model = single_img_doc_model
200 | self.args = args
201 |
202 | def on_train_begin(self, logs=None):
203 | self.epoch = []
204 | self.history = {}
205 |
206 | def on_epoch_end(self, epoch, logs):
207 | metrics = eval_utils.print_all_metrics(
208 | self.val,
209 | self.image_features,
210 | self.image_idx2row,
211 | self.word2idx,
212 | self.single_text_doc_model,
213 | self.single_img_doc_model,
214 | self.args)
215 | self.epoch.append(epoch)
216 | for k, v in metrics.items():
217 | self.history.setdefault(k, []).append(v)
218 |
219 |
220 | class LearningRateLinearIncrease(tf.keras.callbacks.Callback):
221 | def __init__(self, max_lr, warmup_steps, verbose=0):
222 | super(LearningRateLinearIncrease, self).__init__()
223 | self.max_lr = max_lr
224 | self.warmup_steps = warmup_steps
225 | self.verbose = verbose
226 | self.cur_step_count = 0
227 |
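    # The LR starts at 0 and is increased by max_lr / warmup_steps on each of
    # the first warmup_steps batches, after which the regular schedule
    # (e.g. ReduceLROnPlateau) takes over.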
228 | def on_train_begin(self, logs=None):
229 | tf.keras.backend.set_value(self.model.optimizer.lr, 0.0)
230 |
231 | def on_batch_begin(self, batch, logs=None):
232 | if self.cur_step_count >= self.warmup_steps:
233 | return
234 | lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
235 | lr += 1./self.warmup_steps * self.max_lr
236 | if self.verbose and self.cur_step_count % 50 == 0:
237 | print('\n new LR = {}\n'.format(lr))
238 | tf.keras.backend.set_value(self.model.optimizer.lr, lr)
239 | self.cur_step_count += 1
240 |
--------------------------------------------------------------------------------
/visualize_predictions_graph.py:
--------------------------------------------------------------------------------
1 | '''
2 | for i in predictions/test/*; do python visualize_predictions_graph.py $i\/doc.json $i/pred_weights.npy prediction_dir/$i ; done;
3 | '''
4 | import argparse
5 | import numpy as np
6 | import bipartite_utils
7 | import json
8 | import os
9 | import subprocess
10 | import matplotlib.pyplot as plt
11 |
12 | from sklearn.metrics import roc_auc_score
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('document')
17 | parser.add_argument('predictions')
18 | parser.add_argument('output')
19 |     parser.add_argument('--n_to_show', type=int, default=5)
20 | return parser.parse_args()
21 |
22 |
23 | def call(x):
24 | subprocess.call(x, shell=True)
25 |
26 |
27 | def main():
28 | args = parse_args()
29 | pred_adj = np.load(args.predictions)
30 | with open(args.document) as f:
31 | data = json.loads(f.read())
32 |
33 | images, text = data[0], data[1]
34 |
35 | solve_fn = bipartite_utils.generate_fast_hungarian_solving_function()
36 | sol = solve_fn(pred_adj, args.n_to_show)
37 | scores = pred_adj[sol[:,0], sol[:,1]]
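    # each row of sol is a matched (sentence_idx, image_idx) pair from the
    # bipartite matching; scores holds the predicted similarity of each pair.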
38 |
39 | true_adj = np.zeros((len(text), len(images)))
40 | for text_idx, t in enumerate(text):
41 | if t[1] == -1: continue
42 | true_adj[text_idx, t[1]] = 1
43 | for image_idx, t in enumerate(images):
44 | if t[1] == -1: continue
45 | true_adj[t[1], image_idx] = 1
46 |
47 | auc = 100 * roc_auc_score(true_adj.flatten(),
48 | pred_adj.flatten())
49 | print('AUC: {:.2f} {}'.format(auc,
50 | data[-1]))
51 |
52 | ordered_images, ordered_sentences = [], []
53 | for img_idx, sent_idx, sc in sorted(
54 | zip(sol[:,1], sol[:,0], scores), key=lambda x:-x[-1])[:args.n_to_show]:
55 | ordered_images.append(img_idx)
56 | ordered_sentences.append(sent_idx)
57 | print(sc)
58 |
59 | pred_adj_subgraph = pred_adj[np.array(ordered_sentences),:][:,np.array(ordered_images)]
60 | true_adj_subgraph = true_adj[np.array(ordered_sentences),:][:,np.array(ordered_images)]
61 | selected_images = [images[img_idx][0] for img_idx in ordered_images]
62 | selected_sentences = [text[sent_idx][0] for sent_idx in ordered_sentences]
63 |
64 | # normalize predicted sims to have max 1 and min 0
65 | # first, clip out negative values
66 | pred_adj_subgraph = np.clip(pred_adj_subgraph, 0, 1.0)
67 | pred_adj_subgraph -= np.min(pred_adj_subgraph.flatten())
68 | pred_adj_subgraph /= np.max(pred_adj_subgraph.flatten())
69 | assert np.min(pred_adj_subgraph.flatten()) == 0.0
70 | assert np.max(pred_adj_subgraph.flatten()) == 1.0
71 |
72 | print(pred_adj_subgraph.shape)
73 | print(ordered_images)
74 | print(ordered_sentences)
75 | print(selected_images)
76 | print(selected_sentences)
77 |
78 | # each line has ((x1, y1, x2, y2), strength, correctness)
79 | # images go above text
80 | lines_to_plot = []
81 | image_text_gap = 2
82 | same_mode_gap = 2
83 | offdiag_alpha_mul = .5
84 |
85 | def cosine_to_width(cos, exp=2.0, maxwidth=8.0):
86 | return cos**exp * maxwidth
87 | def cosine_to_alpha(cos, exp=1/2., maxalpha=1.0):
88 | return cos**exp * maxalpha
89 |
90 | correct_color, incorrect_color = '#1b7837', '#762a83'
91 | lines_to_plot = []
92 | for text_idx in range(args.n_to_show):
93 | for image_idx in range(args.n_to_show):
94 | coords = (text_idx*same_mode_gap, 0, image_idx*same_mode_gap, image_text_gap)
95 | strength = max(pred_adj_subgraph[text_idx, image_idx], 0)
96 | correctness = true_adj_subgraph[text_idx, image_idx] == 1
97 | lines_to_plot.append((coords, strength, correctness))
98 |
99 | plt.figure(figsize=(args.n_to_show*same_mode_gap, image_text_gap))
100 | for (x1, y1, x2, y2), strength, correct in sorted(lines_to_plot,
101 | key=lambda x: x[1]):
102 | if x1 == x2: continue
103 | plt.plot([x1, x2], [y1, y2],
104 | linewidth=cosine_to_width(strength),
105 | alpha=cosine_to_alpha(strength) * offdiag_alpha_mul,
106 | color=correct_color if correct else incorrect_color)
107 | for (x1, y1, x2, y2), strength, correct in sorted(lines_to_plot,
108 | key=lambda x: x[1]):
109 | if x1 != x2: continue
110 | plt.plot([x1, x2], [y1, y2],
111 | linewidth=cosine_to_width(strength),
112 | color=correct_color if correct else incorrect_color)
113 | plt.axis('off')
114 | plt.tight_layout()
115 |
116 | if not os.path.exists(args.output):
117 | os.makedirs(args.output)
118 | with open(args.output + '/sentences.txt', 'w') as f:
119 | f.write('\n'.join([' '.join(s.split()) for s in selected_sentences]))
120 | with open(args.output + '/images.txt', 'w') as f:
121 | f.write('\n'.join(selected_images))
122 | with open(args.output + '/all_sentences.txt', 'w') as f:
123 | f.write('\n'.join([' '.join(s[0].split()) for s in text]))
124 | with open(args.output + '/all_images.txt', 'w') as f:
125 | f.write('\n'.join([x[0] for x in images]))
126 | with open(args.output + '/auc.txt', 'w') as f:
127 | f.write('{:.4f}'.format(auc))
128 | plt.savefig(args.output + '/graph.png', dpi=300)
129 | call('convert {} -trim {}'.format(args.output + '/graph.png',
130 | args.output + '/graph_cropped.png'))
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/visualize_predictions_html.py:
--------------------------------------------------------------------------------
1 | '''
2 | for i in predictions/train/*; do python visualize_predictions_html.py $i\/doc.json $i/pred_weights.npy $PWD/data/wiki/ prediction_dir/$i ; done;
3 | '''
4 | import argparse
5 | import numpy as np
6 | import bipartite_utils
7 | import json
8 | import os
9 | import subprocess
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('document')
15 | parser.add_argument('predictions')
16 | parser.add_argument('image_dir')
17 | parser.add_argument('output')
18 | return parser.parse_args()
19 |
20 |
21 | def call(x):
22 | subprocess.call(x, shell=True)
23 |
24 |
25 | def main():
26 | args = parse_args()
27 | pred_adj = np.load(args.predictions)
28 | with open(args.document) as f:
29 | data = json.loads(f.read())
30 |
31 | cur_images = [d[0] for d in data[0]]
32 | cur_text = [d[0] for d in data[1]]
33 |
34 |
35 | solve_fn = bipartite_utils.generate_fast_hungarian_solving_function()
36 | sol = solve_fn(pred_adj, max(*pred_adj.shape))
37 | scores = pred_adj[sol[:,0], sol[:,1]]
38 | lines = []
39 | images_to_copy = []
40 | for img_idx, sent_idx, sc in sorted(
41 | zip(sol[:,1], sol[:,0], scores), key=lambda x:-x[-1]):
42 |         lines.append('{} ({:.1f})<br>'.format(cur_text[sent_idx],
43 |                                          sc*100))
44 |         images_to_copy.append('{}/{}'.format(args.image_dir, cur_images[img_idx]))
45 |         lines.append('<img src="{}"><br>'.format(
46 |             cur_images[img_idx].split('/')[-1]))
47 | print(cur_text[sent_idx])
48 | print(cur_images[img_idx])
49 | print(sc)
50 |
51 |     lines.append('Article Text:<br>')
52 |     for sent in cur_text:
53 |         lines.append('{}<br>'.format(sent))
54 |
55 | if not os.path.exists(args.output):
56 | os.makedirs(args.output)
57 |
58 | with open('{}/view.html'.format(args.output), 'w') as f:
59 | f.write('\n'.join(lines))
60 |
61 | for im in images_to_copy:
62 | call('cp {} {}'.format(im, args.output))
63 | print()
64 |
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------