├── LICENSE ├── README.md ├── bipartite_utils.py ├── data └── README.md ├── eval_utils.py ├── image_feature_extract ├── extract.py └── make_python_image_info.py ├── image_utils.py ├── model_utils.py ├── paper_commands ├── generate_training_commands_paper.py ├── generate_training_commands_paper_finetune.py └── generate_training_commands_training_dynamics.py ├── requirements.txt ├── summary.png ├── text_utils.py ├── train_doc.py ├── training_utils.py ├── visualize_predictions_graph.py └── visualize_predictions_html.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jack Hessel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## What's in here? 2 | 3 | 4 | This repository contains code to accompany "Unsupervised Discovery of 5 | Multimodal Links in Multi-image/Multi-sentence Documents." (EMNLP 2019; [link](https://arxiv.org/abs/1904.07826)) 6 | 7 |
8 | ![overview figure](summary.png)
9 | 
10 | 
11 | If you find the code, data, or paper useful, please consider citing
12 | ```
13 | @inproceedings{hessel-lee-mimno-2019unsupervised,
14 | title={Unsupervised Discovery of Multimodal Links in Multi-Image, Multi-Sentence Documents},
15 | author={Hessel, Jack and Lee, Lillian and Mimno, David},
16 | booktitle={EMNLP},
17 | year={2019}
18 | }
19 | ```
20 | 
21 | *note: I recently upgraded the implementation of this paper to TF2. If you're interested in the exact code used for the EMNLP paper for reproduction purposes, you should check out the tf1 branch, and run with those requirements. However --- I'd highly recommend using the main tf2 branch. It is much faster, and I've been able to reproduce the paper results with it.*
22 | 
23 | ## Requirements
24 | This code requires python3 and several python
25 | libraries. You can install the python requirements with:
26 | 
27 | ```
28 | pip3 install -r requirements.txt
29 | ```
30 | 
31 | Also --- it helps performance to initialize the word embedding
32 | matrices with word2vec embeddings. You can download those embeddings
33 | [here](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit)
34 | (be sure to extract them). When you run the training command, it is
35 | recommended to use the option `--word2vec_binary XXX`, where XXX is the
36 | path to the extracted/downloaded word embeddings.
37 | 
38 | *A note about evaluating with MT metrics:* the machine translation
39 | metrics, with the exception of sacrebleu, are based on
40 | [pycocoevalcap](https://github.com/salaniz/pycocoevalcap), which itself
41 | has several dependencies. In particular, it requires Java 1.8+, and
42 | permissions must be set so that temporary files can be
43 | written wherever pip installs pycocoevalcap. If you don't have these
44 | additional dependencies, only BLEU will be computed, and a warning will print.
45 | 
46 | ## How to run
47 | 
48 | ### Preparing the dataset
49 | 
50 | The training script takes three inputs:
51 | 
52 | 1. A json of training/validation/test documents. This json stores a dictionary with three keys: `train`, `val`, and `test`. Each of the keys maps to a list of documents. A document is a list containing 3 things: `[list_of_images, list_of_sentences, metadata]`.
53 | - `list_of_images` is a list of `(identifier, label_text_idx)` tuples, where the identifier is the name of the image, and `label_text_idx` is an integer indicating the index of the corresponding ground-truth sentence in `list_of_sentences`. If there are no labels in the corpus, this index can be set to `None`. If there are labels, but this particular image doesn't correspond to a sentence, you can set the index to `-1`.
54 | - `list_of_sentences` is a list of `(sentence, label_image_idx)` tuples, where `sentence` is the sentence text, and `label_image_idx` is an integer indicating the index of the corresponding ground-truth image in `list_of_images`. If there are no labels in the corpus, this index can be set to `None`. If there are labels, but this particular sentence doesn't correspond to an image, you can set the index to `-1`.
55 | - `metadata` is an optional document identifier.
56 | 2. A json mapping image ids (see `list_of_images`) to row indices in the features matrix.
57 | 3. An image feature matrix, where `matrix[id2row[img_id]]` is the image feature vector corresponding to the image with image id `img_id`, and `id2row` is the dictionary stored in the previously described json mapping file.
58 | 
59 | Here is an example document from the MSCOCO dataset.
60 | ```
61 | [[['000000074794', -1],
62 | ['000000339384', 9],
63 | ['000000100064', -1],
64 | ['000000072850', 8],
65 | ['000000046251', -1],
66 | ['000000531828', -1],
67 | ['000000574207', 0],
68 | ['000000185258', 5],
69 | ['000000416357', 1],
70 | ['000000490222', -1]],
71 | [['Two street signs at an intersection on a cloudy day.', -1],
72 | ['A man holding a tennis racquet on a tennis court.', -1],
73 | ['A seagull opens its mouth while standing on a beach.', -1],
74 | ['a man reaching up to hit a tennis ball', -1],
75 | ['A horse sticks his head out of an open stable door. ', -1],
76 | ['Couple standing on a pier with a lot of flags.', -1],
77 | ['A man is riding a skateboard on a ramp.', -1],
78 | ['A man on snow skis leans on his ski poles as he stands in the snow and '
79 | 'gazes into the distance.',
80 | -1],
81 | ['a close up of a baseball player with a ball and glove', -1],
82 | ['four people jumping in the air and reaching for a frisbee.', -1]],
83 | 'na']
84 | ```
85 | 
86 | The [image with ID](http://cocodataset.org/#explore?id=339384)
87 | `000000339384` in the MSCOCO dataset corresponds to the
88 | sentence with index 9 in this document, "four people jumping in the
89 | air and reaching for a frisbee.". The underlying graph is undirected,
90 | so the labels are stored only in the image list (though, if you like,
91 | you could redundantly store them on the text side). For the MSCOCO
92 | dataset, the metadata is unused.
93 | 
94 | The exact train/val/test splits we used, along with pre-extracted
95 | image features, are available for download (see below). You can download
96 | these and extract them in the `data` folder.
97 | 
98 | ### Extracting image features for a new dataset
99 | 
100 | If you would like to extract image features for a new dataset, there
101 | are a number of existing codebases for that, depending on what neural
102 | network you would like to use. We have included the scripts that we
103 | used to do this, if you'd like to use ours. In particular, you should:
104 | 
105 | 1. Get all of the images of interest into a single folder. Your images should all have unique filenames, as the scripts assume the name of the file (minus its extension) is the identifier; e.g., `my_images/000000072850.jpg`'s identifier will be `000000072850`.
106 | 2. Create a text file with the full path of each image, one path per line.
107 | 3. Call `python3 image_feature_extract/extract.py [filenames text file] extracted_features`
108 | 4. Call `python3 make_python_image_info.py extracted_features [filenames text file]`
109 | 
110 | This will output a feature matrix (in npy format) and an id2row json
111 | file. These are two of the three training inputs described above (a
112 | minimal sketch showing how all three fit together appears below). Note ---
113 | you may need to modify `make_python_image_info.py` if your images live in
114 | different folders, or if you have multiple images with the same name but
115 | different extensions; e.g., `id.jpg` and `id.png` will both erroneously
116 | be mapped to `id`. I may add support for this later (in addition to cleaning up these scripts...).
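
To make these formats concrete, here is a minimal, self-contained sketch (not one of the released scripts) that writes a toy `docs.json`, `id2row.json`, and `features.npy` in the layout described above. The image ids and the random feature matrix are made-up stand-ins for whatever `extract.py` produces:

```
import json
import numpy as np

# one toy document: [list_of_images, list_of_sentences, metadata].
# image entries are [image_id, index_of_linked_sentence]; sentence entries
# are [sentence, index_of_linked_image]; -1 means "no link", and None can
# be used throughout if the corpus has no labels at all.
doc = [
    [['img_000', 1], ['img_001', -1]],
    [['A sentence that is not linked to any image.', -1],
     ['A sentence describing img_000.', 0]],
    'toy_doc_0'
]
with open('docs.json', 'w') as f:
    json.dump({'train': [doc], 'val': [doc], 'test': [doc]}, f)

# map every image id to a row of the feature matrix...
image_ids = ['img_000', 'img_001']
id2row = {img_id: row for row, img_id in enumerate(image_ids)}
with open('id2row.json', 'w') as f:
    json.dump(id2row, f)

# ...and save one feature vector per row (1664-d here, the dimensionality of
# the DenseNet169 features produced by image_feature_extract/extract.py).
features = np.random.randn(len(image_ids), 1664).astype(np.float32)
np.save('features.npy', features)
```

Training on files like these then looks like `python3 train_doc.py docs.json --image_id2row id2row.json --image_features features.npy` (see the commands below).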
117 | 
118 | ## How to run the code
119 | 
120 | An example training command for the mscoco dataset with reasonable settings is:
121 | ```
122 | python3 train_doc.py data/mscoco/docs.json \
123 | --image_id2row data/mscoco/id2row.json \
124 | --image_features data/mscoco/features.npy \
125 | --word2vec_binary data/GoogleNews-vectors-negative300.bin \
126 | --cached_word_embeddings mscoco_cached_word_embs.json \
127 | --print_metrics 1 \
128 | --output mscoco_results.pkl
129 | ```
130 | 
131 | note that even though metrics are printed during training if you use
132 | `--print_metrics 1`, there is no early stopping or supervision
133 | on the labels during training.
134 | 
135 | you can run the following to get more information about particular training options:
136 | ```
137 | python3 train_doc.py --help
138 | ```
139 | 
140 | From the paper, here's an example of running with hard negative
141 | mining, the AP similarity function, and 20 negative samples:
142 | ```
143 | python3 train_doc.py data/mscoco/docs.json --image_id2row data/mscoco/id2row.json \
144 | --image_features data/mscoco/features.npy \
145 | --word2vec_binary data/GoogleNews-vectors-negative300.bin \
146 | --cached_word_embeddings mscoco_cached_word_embs.json \
147 | --print_metrics 1 \
148 | --output mscoco_results.pkl \
149 | --sim_mode AP \
150 | --docs_per_batch 21 \
151 | --cached_vocab mscoco_vocab.json
152 | ```
153 | 
154 | ## How to reproduce the results of the paper:
155 | 
156 | The datasets we use with specific splits/pre-extracted image features
157 | are available for download. If you are just using the datasets, please
158 | cite the original creators of the datasets. Furthermore, all datasets
159 | are subsets of their original creators' releases; please use the
160 | versions from the original links if you are looking for more complete
161 | datasets!
162 | 
163 | - MSCOCO ([original source](http://cocodataset.org/#home)) [link](https://drive.google.com/open?id=1LGqUst-BB8N4nFPNGHD0uVa3x_cAZ7UV)
164 | - DII ([original source](http://visionandlanguage.net/VIST/dataset.html)) [link](https://drive.google.com/open?id=1zFouzVhXvnK19zv3AYT-wZJt8SFTcRXY)
165 | - SIS ([original source](http://visionandlanguage.net/VIST/dataset.html)) [link](https://drive.google.com/open?id=1MN6gPGhymAHvPJL6dRTu-VbXfYlI0L7-)
166 | - DII-Stress ([original source](http://visionandlanguage.net/VIST/dataset.html)) [link](https://drive.google.com/open?id=1vLOMftRh8U5r3sn29X2l8XxVXQppsLYS)
167 | - RQA ([original source](https://hucvl.github.io/recipeqa/)) [link](https://drive.google.com/open?id=1BbD1OnV4h02QUk1eZT1hFWWKlDwUyz3O)
168 | - DIY (we collected this dataset) [link](https://drive.google.com/open?id=1EdgL2VYrVTLccP8wHpynpFhv3PNuZiOv)
169 | - WIKI ([original source](https://www.imageclef.org/wikidata)) [link](https://drive.google.com/open?id=1Ecb1LkTXX4sskx-PLB2o3vMru-8I8rEy)
170 | 
171 | In addition, we have included scripts that generate the exact training commands executed in the paper itself. These are located in the `paper_commands` directory.
172 | *Note, however, that the code used for the paper is now located in the tf1 branch.
The main branch has been ported to TF2.* 173 | -------------------------------------------------------------------------------- /bipartite_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lapjv import lapjv 3 | 4 | def generate_fast_hungarian_solving_function(): 5 | def base_solve(W): 6 | orig_shape = W.shape 7 | if orig_shape[0] != orig_shape[1]: 8 | if orig_shape[0] > orig_shape[1]: 9 | pad_idxs = [[0, 0], [0, W.shape[0]-W.shape[1]]] 10 | col_pad = True 11 | else: 12 | pad_idxs = [[0, W.shape[1]-W.shape[0]], [0, 0]] 13 | col_pad = False 14 | W = np.pad(W, pad_idxs, 'constant', constant_values=-100) 15 | sol, _, cost = lapjv(-W) 16 | 17 | i_s = np.arange(len(sol)) 18 | j_s = sol[i_s] 19 | 20 | sort_idxs = np.argsort(-W[i_s, j_s]) 21 | i_s, j_s = map(lambda x: x[sort_idxs], [i_s, j_s]) 22 | 23 | if orig_shape[0] != orig_shape[1]: 24 | if col_pad: 25 | valid_idxs = np.where(j_s < orig_shape[1])[0] 26 | else: 27 | valid_idxs = np.where(i_s < orig_shape[0])[0] 28 | i_s, j_s = i_s[valid_idxs], j_s[valid_idxs] 29 | 30 | indices = np.hstack([np.expand_dims(i_s, -1), np.expand_dims(j_s, -1)]).astype(np.int32) 31 | return indices 32 | 33 | def hungarian_solve(W, k, max_val=1000): 34 | min_dim = min(*W.shape) 35 | if k <= 0 or k >= min_dim: 36 | return base_solve(W) 37 | 38 | add_rows = W.shape[0] > W.shape[1] 39 | add_len = min_dim-k 40 | 41 | if add_rows: 42 | add_mat = np.zeros((add_len, W.shape[1])) 43 | add_mat[:] = max_val 44 | new_W = np.vstack([W, add_mat]) 45 | else: 46 | add_mat = np.zeros((W.shape[0], add_len)) 47 | add_mat[:] = max_val 48 | new_W = np.hstack([W, add_mat]) 49 | 50 | indices = base_solve(new_W) 51 | indices = indices[-k:, :] 52 | return indices 53 | 54 | return hungarian_solve 55 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## What goes in here? 2 | 3 | Dataset info (document jsons, image features, image mappings, raw images) -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from sklearn.metrics import roc_auc_score 4 | import sacrebleu 5 | 6 | # pycocoevalcap imports 7 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 8 | from pycocoevalcap.bleu.bleu import Bleu 9 | from pycocoevalcap.meteor.meteor import Meteor 10 | from pycocoevalcap.rouge.rouge import Rouge 11 | from pycocoevalcap.cider.cider import Cider 12 | 13 | 14 | import collections 15 | import tqdm 16 | import image_utils 17 | import text_utils 18 | import os 19 | import sklearn.preprocessing 20 | import json 21 | 22 | 23 | def compute_match_metrics_doc(docs, 24 | image_matrix, 25 | image_idx2row, 26 | vocab, 27 | text_trans, 28 | image_trans, 29 | args, 30 | ks=[1,2,3,4,5,6,7,8,9,10]): 31 | 32 | all_aucs, all_match_metrics = [], collections.defaultdict(list) 33 | all_texts, all_images = [], [] 34 | 35 | 36 | #for MT metrics 37 | all_refs = [] 38 | all_sys = [] 39 | 40 | # supress warwning that I believe is thrown by keras' timedistributed 41 | # over dynamic tensors... 
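# The loop below scores one document at a time: it encodes the document's
# sentences and images, forms pred_adj[i, j] = <sentence_i, image_j>, builds
# true_adj from the stored ground-truth links, and accumulates AUC,
# precision@k, and (when links exist) the predicted/reference captions used
# later for the MT metrics.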
42 | tf.get_logger().setLevel('ERROR')
43 | for images, text, meta in tqdm.tqdm(docs):
44 | # text by image matrix
45 | cur_text = [t[0] for t in text]
46 | cur_images = [t[0] for t in images]
47 | n_text, n_image = map(len, [cur_text, cur_images])
48 | 
49 | cur_text = text_utils.text_to_matrix(cur_text, vocab, max_len=args.seq_len)
50 | 
51 | if args.end2end:
52 | cur_images = image_utils.images_to_images(cur_images, False, args)
53 | else:
54 | cur_images = image_utils.images_to_matrix(cur_images, image_matrix, image_idx2row)
55 | 
56 | cur_text = np.expand_dims(cur_text, 0)
57 | cur_images = np.expand_dims(cur_images, 0)
58 | 
59 | text_vec = text_trans.predict(cur_text)
60 | image_vec = image_trans.predict(cur_images)
61 | 
62 | text_vec = text_vec[0,:n_text,:]
63 | image_vec = image_vec[0,:n_image,:]
64 | 
65 | pred_adj = text_vec.dot(image_vec.transpose())
66 | true_adj = np.zeros((len(text), len(images)))
67 | 
68 | # for MT metrics, for each image with a ground-truth sentence,
69 | # extract predicted sentences.
70 | im2best_text_idxs = np.argmax(pred_adj, axis=0)
71 | im2all_predicted_captions = [text[idx][0] for idx in im2best_text_idxs]
72 | 
73 | im2predicted_captions = {}
74 | im2ground_truth_captions = collections.defaultdict(list)
75 | 
76 | for text_idx, t in enumerate(text):
77 | if t[1] == -1: continue
78 | true_adj[text_idx, t[1]] = 1
79 | for image_idx, t in enumerate(images):
80 | if t[1] == -1: continue
81 | true_adj[t[1], image_idx] = 1
82 | 
83 | if np.sum(true_adj.flatten()) == 0:
84 | continue
85 | if np.sum(true_adj.flatten()) == len(true_adj.flatten()):
86 | continue
87 | 
88 | for text_gt_idx, image_gt_idx in zip(*np.where(true_adj==1)):
89 | im2predicted_captions[image_gt_idx] = im2all_predicted_captions[image_gt_idx]  # the caption predicted for this ground-truth image
90 | im2ground_truth_captions[image_gt_idx].append(text[text_gt_idx][0])
91 | 
92 | for img_idx, pred in im2predicted_captions.items():
93 | all_refs.append(im2ground_truth_captions[img_idx])
94 | all_sys.append(pred)
95 | 
96 | pred_adj = pred_adj.flatten()
97 | true_adj = true_adj.flatten()
98 | 
99 | 
100 | pred_order = true_adj[np.argsort(-pred_adj)]
101 | for k in ks:
102 | all_match_metrics[k].append(np.mean(pred_order[:k]))
103 | 
104 | all_aucs.append(roc_auc_score(true_adj, pred_adj))
105 | tf.get_logger().setLevel('INFO')
106 | 
107 | # give each instance a unique IDX for the metric computation...
108 | all_refs = {idx: refs for idx, refs in enumerate(all_refs)}
109 | all_sys = {idx: pred for idx, pred in enumerate(all_sys)}
110 | 
111 | if len(all_refs) > 0:
112 | all_mt_metrics = compute_mt_metrics(all_sys, all_refs, args)
113 | else:
114 | all_mt_metrics = {}
115 | return all_aucs, all_match_metrics, all_mt_metrics
116 | 
117 | 
118 | def compute_mt_metrics(all_sys, all_refs, args):
119 | '''
120 | # we need a dictionary mapping
121 | all_sys maps {unique_idx --> pred}
122 | all_ref maps {unique_idx --> [ground truths]}
123 | '''
124 | res_dict = {}
125 | 
126 | # need all cases to have the same number of references for
127 | # sacrebleu. however --- this will not always be the case in our
128 | # data, e.g., if an image has two ground truths. If there's an
129 | # image with 3 ground truth links, we can repeat ground truths
130 | # until all have 3. However --- this duplication modifies some
131 | # of the corpus-level normalizations, in particular, the number
132 | # of ground truth tokens. So --- we should prefer the MSCOCO
133 | # bleu scorer in this case.
However --- we'll still compute 134 | # the sacrebleu scores anyway, but include a flag that says 135 | # to not trust them. 136 | 137 | # add in repeated predictions for cases will less than 138 | # the maximum number of references: 139 | 140 | n_refs_max = np.max([len(r) for r in all_refs.values()]) 141 | n_refs_min = np.min([len(r) for r in all_refs.values()]) 142 | 143 | print('Using {} maximum references for MT metrics'.format(n_refs_max)) 144 | 145 | trust_sacrebleu = n_refs_max == n_refs_min 146 | 147 | all_refs_sacrebleu = [] 148 | for outer_idx in range(n_refs_max): 149 | cur_refs = [all_refs[inner_idx][min(outer_idx, len(all_refs[inner_idx])-1)] 150 | for inner_idx in range(len(all_refs))] 151 | all_refs_sacrebleu.append(cur_refs) 152 | 153 | sacre_bleu = sacrebleu.corpus_bleu([all_sys[idx] for idx in range(len(all_sys))], all_refs_sacrebleu) 154 | res_dict['sacre_bleu'] = sacre_bleu.score 155 | res_dict['can_trust_sacrebleu_with_global_counts'] = trust_sacrebleu 156 | 157 | if not args.compute_mscoco_eval_metrics: 158 | return res_dict 159 | 160 | try: 161 | tokenizer = PTBTokenizer() 162 | coco_sys = {idx: [{'caption': ' '.join(r.split())}] for idx, r in all_sys.items()} 163 | coco_sys = tokenizer.tokenize(coco_sys) 164 | 165 | coco_ref = {idx: [{'caption': ' '.join(r.split())} for r in refs] for idx, refs in all_refs.items()} 166 | coco_ref = tokenizer.tokenize(coco_ref) 167 | 168 | scorers = [ 169 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 170 | (Meteor(), "METEOR"), 171 | (Rouge(), "ROUGE_L"), 172 | (Cider(), "CIDEr") 173 | ] 174 | 175 | for scorer, method in scorers: 176 | score, _ = scorer.compute_score(coco_ref, coco_sys) 177 | if type(method) is list: 178 | for s, m in zip(score, method): 179 | res_dict['MSCOCO_{}'.format(m)] = s 180 | else: 181 | res_dict['MSCOCO_{}'.format(method)] = score 182 | except Exception as e: 183 | print('Unable to compute MSCOCO metrics: {}'.format(e)) 184 | print('continuing nonetheless') 185 | 186 | return res_dict 187 | 188 | 189 | def compute_feat_spread(feats_in): 190 | feats_in = sklearn.preprocessing.normalize(feats_in) 191 | mean = np.expand_dims(np.mean(feats_in, axis=0), 0) 192 | feats_in = feats_in - mean 193 | norms = np.linalg.norm(feats_in, axis=1) 194 | norms = norms * norms 195 | return np.mean(norms) 196 | 197 | 198 | def save_predictions(docs, 199 | image_matrix, 200 | image_idx2row, 201 | vocab, 202 | text_trans, 203 | image_trans, 204 | output_dir, 205 | args, 206 | limit=None): 207 | 208 | if limit is not None: 209 | docs = docs[:limit] 210 | 211 | for idx, (images, text, meta) in tqdm.tqdm(enumerate(docs)): 212 | # text by image matrix 213 | identifier = meta if meta else str(idx) 214 | identifier = meta.replace('/', '_') 215 | 216 | cur_text = [t[0] for t in text] 217 | cur_images = [t[0] for t in images] 218 | 219 | n_text, n_image = map(len, [cur_text, cur_images]) 220 | 221 | cur_text = text_utils.text_to_matrix(cur_text, vocab, max_len=args.seq_len) 222 | 223 | if args.end2end: 224 | cur_images = image_utils.images_to_images(cur_images, False, args) 225 | else: 226 | cur_images = image_utils.images_to_matrix(cur_images, image_matrix, image_idx2row) 227 | 228 | cur_images = np.expand_dims(cur_images, 0) 229 | cur_text = np.expand_dims(cur_text, 0) 230 | 231 | text_vec = text_trans.predict(cur_text) 232 | image_vec = image_trans.predict(cur_images) 233 | 234 | text_vec = text_vec[0,:n_text,:] 235 | image_vec = image_vec[0,:n_image,:] 236 | 237 | image_spread, text_spread = map(compute_feat_spread, 238 | 
[text_vec, image_vec]) 239 | 240 | cur_out_dir = output_dir + '/{}_textspread_{:.4f}_imagespread_{:.4f}/'.format( 241 | identifier, text_spread, image_spread) 242 | 243 | if not os.path.exists(cur_out_dir): 244 | os.makedirs(cur_out_dir) 245 | pred_adj = text_vec.dot(image_vec.transpose()) 246 | np.save(cur_out_dir + '/pred_weights.npy', 247 | pred_adj) 248 | with open(cur_out_dir + '/doc.json', 'w') as f: 249 | f.write(json.dumps((images, text, meta))) 250 | 251 | 252 | def print_all_metrics(data, 253 | image_features, 254 | image_idx2row, 255 | word2idx, 256 | single_text_doc_model, 257 | single_img_doc_model, 258 | args): 259 | metrics_dict = {} 260 | aucs, rec2prec, mt_metrics = compute_match_metrics_doc(data, 261 | image_features, 262 | image_idx2row, 263 | word2idx, 264 | single_text_doc_model, 265 | single_img_doc_model, 266 | args) 267 | print('Validation AUC={:.2f}'.format(100*np.mean(aucs))) 268 | metrics_dict['aucs'] = 100 * np.mean(aucs) 269 | prec = {} 270 | prec_str = 'Validation ' 271 | ks = sorted(list(rec2prec.keys())) 272 | for k in ks: 273 | res = np.mean(rec2prec[k])*100 274 | metrics_dict['p@{}'.format(k)] = res 275 | prec_str += 'p@{}={:.2f} '.format(k, res) 276 | print(prec_str.strip()) 277 | print('Machine translation metrics: {}'.format(str(mt_metrics))) 278 | metrics_dict['mt_metrics'] = mt_metrics 279 | return metrics_dict 280 | -------------------------------------------------------------------------------- /image_feature_extract/extract.py: -------------------------------------------------------------------------------- 1 | from keras.applications.densenet import preprocess_input as preprocess_input_dn 2 | from keras.applications.densenet import DenseNet169 3 | 4 | from keras.preprocessing import image 5 | from keras.models import Sequential, Model 6 | from keras.layers import Flatten 7 | import os, sys, numpy as np 8 | import keras 9 | import tensorflow as tf 10 | import keras.backend as K 11 | 12 | 13 | def load_images(image_list): 14 | images = [] 15 | for i in image_list: 16 | c_img = np.expand_dims(image.img_to_array( 17 | image.load_img(i, target_size = (256, 256))), axis=0) 18 | images.append(c_img) 19 | return np.vstack(images) 20 | 21 | 22 | def image_generator(fnames, batch_size): 23 | cfns = [] 24 | for i, p in enumerate(fnames): 25 | cfns.append(p) 26 | if len(cfns) == batch_size: 27 | yield load_images(cfns) 28 | cfns = [] 29 | if len(cfns) != 0: 30 | yield load_images(cfns) 31 | cfns = [] 32 | 33 | 34 | if __name__ == "__main__": 35 | 36 | if len(sys.argv) != 3: 37 | print("usage: [image list] [output name]") 38 | quit() 39 | 40 | 41 | keras.backend.set_learning_phase(0) 42 | do_center_crop = True 43 | 44 | file_list = sys.argv[1] 45 | fpath = sys.argv[2] 46 | 47 | all_paths = [] 48 | with open(file_list) as f: 49 | for i,line in enumerate(f): 50 | all_paths.append(line.strip()) 51 | 52 | print("Extracting from {} files".format(len(all_paths))) 53 | 54 | print('Saving to {}'.format(fpath)) 55 | base_model = DenseNet169(include_top=False, input_shape=(224,224,3)) 56 | base_model.trainable = False 57 | base_model.summary() 58 | 59 | m_image = keras.layers.Input((256, 256, 3), dtype='float32') 60 | 61 | if do_center_crop: 62 | crop = keras.layers.Lambda(lambda x: tf.image.central_crop(x, .875), 63 | output_shape=(224, 224, 3))(m_image) 64 | else: 65 | crop = keras.layers.Lambda(lambda x: tf.map_fn(lambda y: tf.random_crop(y, [224, 224, 3]), x), 66 | output_shape=(224, 224, 3))(m_image) 67 | 68 | crop = keras.layers.Lambda(lambda x: 
preprocess_input_dn(x))(crop) 69 | trans = base_model(crop) 70 | trans = keras.layers.GlobalAveragePooling2D()(trans) 71 | model = Model(inputs=m_image, 72 | outputs=trans) 73 | model.summary() 74 | 75 | batch_size = 100 76 | print('Getting started...') 77 | gen = image_generator(all_paths, batch_size) 78 | 79 | feats = model.predict_generator(gen, 80 | np.ceil(len(all_paths)/batch_size), 81 | use_multiprocessing=False, 82 | workers=1, 83 | verbose=1) 84 | print(feats.shape) 85 | print("Saving to {}".format(fpath)) 86 | with open(fpath, 'w') as f: 87 | for i in range(feats.shape[0]): 88 | f.write(" ".join(["{:.10f}".format(x) for x in feats[i,:]]) + "\n") 89 | -------------------------------------------------------------------------------- /image_feature_extract/make_python_image_info.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import numpy as np 4 | import tqdm 5 | 6 | 7 | if len(sys.argv) != 3: 8 | print('usage: [text file of features] [filenames text file]') 9 | quit() 10 | 11 | 12 | def load_matrix(fname, n_rows): 13 | vecs = [] 14 | with open(fname) as f: 15 | for idx in tqdm.tqdm(range(n_rows)): 16 | vecs.append(np.array([float(x) for x in f.readline().split()])) 17 | return np.vstack(vecs) 18 | 19 | 20 | def load_ordered_ids(fname): 21 | ids = [] 22 | with open(fname) as f: 23 | for line in f: 24 | if line.strip(): 25 | ids.append('.'.join(line.split('/')[-1].split('.')[:-1])) 26 | return ids 27 | 28 | 29 | ordered_ids = load_ordered_ids(sys.argv[2]) 30 | id2row = {str(idx): row for row, idx in enumerate(ordered_ids)} 31 | print('loading {} rows of matrix...'.format(len(ordered_ids))) 32 | m_mat = load_matrix(sys.argv[1], len(ordered_ids)) 33 | 34 | with open('id2row.json', 'w') as f: 35 | f.write(json.dumps(id2row)) 36 | 37 | np.save(sys.argv[1].split('.')[0], m_mat) 38 | 39 | -------------------------------------------------------------------------------- /image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras.preprocessing import image 4 | from tensorflow.keras.applications.nasnet import preprocess_input 5 | import scipy.misc 6 | 7 | data_augmentor = image.ImageDataGenerator(rotation_range=20, 8 | horizontal_flip=True, 9 | fill_mode='reflect') 10 | 11 | def crop(image_in, targ, center=False): 12 | x, y = image_in.shape[:2] 13 | if center: 14 | x_off, y_off = (x-targ) // 2, (y-targ) // 2 15 | else: 16 | x_off, y_off = np.random.randint(0, (x-targ)), np.random.randint(0, (y-targ)) 17 | res = image_in[x_off:x_off+targ, y_off:y_off+targ, ...] 
18 | return res 19 | 20 | 21 | def augment_image(image_in): 22 | image_out = data_augmentor.random_transform(image_in) 23 | return np.expand_dims(crop(image_out, 224), axis=0) 24 | 25 | 26 | def images_to_matrix(image_list, image_matrix, image_idx2row): 27 | rows = [] 28 | for img in image_list: 29 | rows.append(image_idx2row[str(img)]) 30 | 31 | if type(image_matrix) is list: 32 | image_matrix_choices = np.random.choice(len(image_matrix), size=len(rows)) 33 | all_features = [image_matrix[image_matrix_choices[idx]][r,:] for idx, r in enumerate(rows)] 34 | return np.vstack(all_features) 35 | else: 36 | return image_matrix[np.array(rows), :] 37 | 38 | 39 | def load_images(image_list, preprocess=True, target_size=(224, 224)): 40 | images = [] 41 | for i in image_list: 42 | c_img = np.expand_dims(image.img_to_array( 43 | image.load_img(i, target_size=target_size)), axis=0) 44 | images.append(c_img) 45 | images = np.vstack(images) 46 | if preprocess: 47 | images = preprocess_input(images) 48 | return images 49 | 50 | 51 | def images_to_images(image_list, augment, args): 52 | if args.full_image_paths: 53 | images = load_images([args.image_dir + '/' + img for img in image_list], target_size=(256, 256)) 54 | else: 55 | images = load_images([args.image_dir + '/' + img + '.jpg' for img in image_list], target_size=(256, 256)) 56 | if augment: 57 | images = np.vstack(list(map(augment_image, images))) 58 | else: 59 | images = np.vstack(list(map(lambda x: np.expand_dims(crop(x, 224, center=True), axis=0), images))) 60 | return images 61 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Model utility functions 3 | ''' 4 | import tensorflow as tf 5 | import collections 6 | import sys 7 | 8 | 9 | def make_get_pos_neg_sims(args, sim_fn): 10 | 11 | def get_pos_neg_sims(inp): 12 | ''' 13 | Applies the similarity function between all text_idx, img_idx pairs. 14 | 15 | inp is a list of three arguments: 16 | - sims: the stacked similarity matrix 17 | - text_n_inp: how many sentences are in each document 18 | - img_n_inp: how many images are in each document 19 | ''' 20 | 21 | sims, text_n_inp, img_n_inp = inp 22 | text_index_borders = tf.dtypes.cast(tf.cumsum(text_n_inp), tf.int32) 23 | img_index_borders = tf.dtypes.cast(tf.cumsum(img_n_inp), tf.int32) 24 | zero = tf.expand_dims(tf.expand_dims(tf.constant(0, dtype=tf.int32), axis=-1), axis=-1) 25 | 26 | # these give the indices of the borders between documents in our big sim matrix... 27 | text_index_borders = tf.concat([zero, text_index_borders], axis=0) 28 | img_index_borders = tf.concat([zero, img_index_borders], axis=0) 29 | 30 | doc2pos_sim = {} 31 | doc2neg_img_sims = collections.defaultdict(list) 32 | doc2neg_text_sims = collections.defaultdict(list) 33 | 34 | # for each pair of text set and image set... 
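# sims[text_start:text_end, img_start:img_end] below is the sentence-by-image
# similarity block for the pair (text of document text_idx, images of document
# img_idx); sim_fn collapses each block to a single document-level score.
# Diagonal blocks (text_idx == img_idx) become the positive scores, and the
# off-diagonal blocks become the negative scores used by the training loss.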
35 | for text_idx in range(args.docs_per_batch): 36 | for img_idx in range(args.docs_per_batch): 37 | text_start = tf.squeeze(text_index_borders[text_idx]) 38 | text_end = tf.squeeze(text_index_borders[text_idx+1]) 39 | img_start = tf.squeeze(img_index_borders[img_idx]) 40 | img_end = tf.squeeze(img_index_borders[img_idx+1]) 41 | cur_sims = sims[text_start:text_end, img_start:img_end] 42 | sim = sim_fn(cur_sims) 43 | if text_idx == img_idx: 44 | doc2pos_sim[text_idx] = sim 45 | else: # negative cases 46 | doc2neg_img_sims[text_idx].append(sim) 47 | doc2neg_text_sims[img_idx].append(sim) 48 | 49 | pos_sims, neg_img_sims, neg_text_sims = [], [], [] 50 | for idx in range(args.docs_per_batch): 51 | pos_sims.append(doc2pos_sim[idx]) 52 | neg_img_sims.append(tf.stack(doc2neg_img_sims[idx])) 53 | neg_text_sims.append(tf.stack(doc2neg_text_sims[idx])) 54 | 55 | pos_sims = tf.expand_dims(tf.stack(pos_sims), -1) 56 | neg_img_sims = tf.stack(neg_img_sims) 57 | neg_text_sims = tf.stack(neg_text_sims) 58 | 59 | return [pos_sims, neg_img_sims, neg_text_sims] 60 | 61 | return get_pos_neg_sims 62 | -------------------------------------------------------------------------------- /paper_commands/generate_training_commands_paper.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script generates the training commands for the results reported in the paper 3 | ''' 4 | import collections 5 | import numpy as np 6 | import os 7 | 8 | np.random.seed(1) 9 | if not os.path.exists('paper_res'): 10 | os.makedirs('paper_res') 11 | if not os.path.exists('paper_checkpoints'): 12 | os.makedirs('paper_checkpoints') 13 | 14 | n_parallel = 4 15 | 16 | all_commands = [] 17 | 18 | for batch_size in [11]:#21,31]: 19 | for dset in ['diy', 'mscoco', 'rqa', 'sis', 'dii-r', 'wiki']: 20 | for alg in ['DC','TK','AP','NoStruct']: 21 | for k in [-1] if alg in ['DC', 'NoStruct'] else [-1, 2]: 22 | for neg_mining in ['negative_sample', 'hard_negative']: 23 | 24 | # for wiki, just do limited settings 25 | if dset == 'wiki' and (alg not in ['AP', 'NoStruct'] or k != -1 or neg_mining != 'hard_negative' or batch_size != 11): continue 26 | 27 | identifier_str = '{}+{}+{}+{}+{}'.format(dset, batch_size, alg, k, neg_mining) 28 | output_f = 'paper_res/{}.pkl'.format(identifier_str) 29 | 30 | cmd = 'python3 train_doc.py data/{}/docs.json --image_id2row data/{}/id2row.json --image_features data/{}/features.npy '.format( 31 | dset, dset, dset) 32 | if alg == 'NoStruct': 33 | cmd += '--subsample_image 1 --subsample_text 1 ' 34 | 35 | if dset in ['rqa','diy','wiki']: 36 | cmd += '--seq_len 50 ' 37 | 38 | cmd += '--gpu_memory_frac .45 --sim_mode {} --neg_mining {} --sim_mode_k {} '.format(alg if alg != 'NoStruct' else 'DC', neg_mining, k) 39 | cmd += '--cached_vocab {}_vocab.json --word2vec_binary GoogleNews-vectors-negative300.bin '.format(dset) 40 | cmd += '--cached_word_embeddings {} --output {} '.format(dset, output_f) 41 | cmd += '--checkpoint_dir {}/{} '.format('paper_checkpoints', identifier_str) 42 | 43 | if os.path.exists(output_f): continue 44 | all_commands.append(cmd) 45 | 46 | files = [open('{}_commands.txt'.format(idx+1), 'w') for idx in range(n_parallel)] 47 | idx = 0 48 | for cmd in all_commands: 49 | files[idx].write(cmd + '\n') 50 | idx += 1 51 | idx = idx % n_parallel 52 | 53 | for f in files: 54 | f.close() 55 | -------------------------------------------------------------------------------- /paper_commands/generate_training_commands_paper_finetune.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | This script generates the training commands for the results reported in the paper 3 | ''' 4 | import collections 5 | import numpy as np 6 | import os 7 | 8 | np.random.seed(1) 9 | if not os.path.exists('paper_res'): 10 | os.makedirs('paper_res') 11 | if not os.path.exists('paper_checkpoints'): 12 | os.makedirs('paper_checkpoints') 13 | 14 | n_parallel = 1 15 | 16 | all_commands = [] 17 | 18 | for batch_size, dset, alg, subsample_img, subsample_txt, neg_mining in [ 19 | (11, 'rqa', 'AP', 10, 10, 'hard_negative'), 20 | (11, 'diy', 'AP', 10, 10, 'hard_negative'), 21 | (11, 'wiki', 'AP', 10, 10, 'hard_negative')]: 22 | identifier_str = '{}+{}+{}+{}+{}+FT'.format(dset, batch_size, alg, -1, neg_mining) 23 | output_f = 'paper_res/{}.pkl'.format(identifier_str) 24 | #if os.path.exists(output_f): continue 25 | if 'wiki' != dset: 26 | cmd = 'python3 train_doc.py data/{}/docs.json --image_dir data/{}/images '.format(dset, dset) 27 | else: 28 | cmd = 'python3 train_doc.py data/{}/docs.json --image_dir data/{} '.format(dset, dset) 29 | cmd += '--sim_mode {} --neg_mining {} --sim_mode_k {} '.format(alg, neg_mining, -1) 30 | cmd += '--cached_vocab {}_vocab.json --word2vec_binary GoogleNews-vectors-negative300.bin '.format(dset) 31 | cmd += '--cached_word_embeddings {} --output {} '.format(dset, output_f) 32 | cmd += '--checkpoint_dir {}/{} '.format('paper_checkpoints', identifier_str) 33 | cmd += '--subsample_text {} '.format(subsample_txt) 34 | cmd += '--subsample_image {} '.format(subsample_img) 35 | cmd += '--end2end 1 ' 36 | cmd += '--seq_len 50 ' 37 | cmd += '--force 1 ' 38 | if dset == 'wiki': 39 | cmd += '--full_image_paths 1' 40 | all_commands.append(cmd) 41 | 42 | files = [open('{}_commands_FT.txt'.format(idx+1), 'w') for idx in range(n_parallel)] 43 | idx = 0 44 | for cmd in all_commands: 45 | files[idx].write(cmd + '\n') 46 | idx += 1 47 | idx = idx % n_parallel 48 | 49 | for f in files: 50 | f.close() 51 | -------------------------------------------------------------------------------- /paper_commands/generate_training_commands_training_dynamics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script generates the training commands for the results reported in the paper 3 | ''' 4 | import collections 5 | import numpy as np 6 | import os 7 | 8 | np.random.seed(1) 9 | if not os.path.exists('paper_res'): 10 | os.makedirs('paper_res') 11 | if not os.path.exists('paper_checkpoints'): 12 | os.makedirs('paper_checkpoints') 13 | 14 | n_parallel = 4 15 | 16 | all_commands = [] 17 | 18 | for batch_size in [11]: 19 | for dset in ['dii', 'diy', 'mscoco', 'rqa', 'sis', 'dii-r']: 20 | for alg in ['AP']: 21 | for neg_mining in ['hard_negative']: 22 | k = -1 23 | identifier_str = '{}+{}+{}+{}+{}+DYNAMICS'.format(dset, batch_size, alg, k, neg_mining) 24 | output_f = 'paper_res/{}.pkl'.format(identifier_str) 25 | cmd = 'python3 train_doc.py data/{}/docs.json --image_id2row data/{}/id2row.json --image_features data/{}/features.npy '.format( 26 | dset, dset, dset) 27 | if dset in ['rqa','diy',]: 28 | cmd += '--seq_len 50 ' 29 | cmd += '--print_metrics 1 ' 30 | cmd += '--gpu_memory_frac .45 --sim_mode {} --neg_mining {} --sim_mode_k {} '.format(alg if alg != 'NoStruct' else 'DC', neg_mining, k) 31 | cmd += '--cached_vocab {}_vocab.json --word2vec_binary GoogleNews-vectors-negative300.bin '.format(dset) 32 | cmd += '--cached_word_embeddings {} --output {} '.format(dset, 
output_f) 33 | cmd += '--checkpoint_dir {}/{} '.format('paper_checkpoints', identifier_str) 34 | 35 | if os.path.exists(output_f): continue 36 | all_commands.append(cmd) 37 | 38 | files = [open('{}_commands_training_dynamics.txt'.format(idx+1), 'w') for idx in range(n_parallel)] 39 | idx = 0 40 | for cmd in all_commands: 41 | files[idx].write(cmd + '\n') 42 | idx += 1 43 | idx = idx % n_parallel 44 | 45 | for f in files: 46 | f.close() 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/src-d/lapjv 2 | numpy 3 | scipy 4 | scikit-learn 5 | tqdm 6 | spacy 7 | gensim 8 | nltk 9 | sacrebleu 10 | tensorflow>=2.1 11 | git+https://github.com/jmhessel/pycocoevalcap 12 | -------------------------------------------------------------------------------- /summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmhessel/multi-retrieval/20178dfda368ed4117f3aaee7b0bd2946fcdf9eb/summary.png -------------------------------------------------------------------------------- /text_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import collections 4 | import tensorflow as tf 5 | import tqdm 6 | import numpy as np 7 | import pprint 8 | import warnings 9 | 10 | from gensim.models.keyedvectors import KeyedVectors 11 | 12 | from nltk.tokenize import TweetTokenizer 13 | global _TOKENIZER 14 | _TOKENIZER = TweetTokenizer() 15 | 16 | 17 | def preprocess_caption(cap_in): 18 | return ' '.join(_TOKENIZER.tokenize(cap_in)).lower() 19 | 20 | 21 | def get_vocab(data, min_count=5, cached=None): 22 | if cached is None or not os.path.exists(cached): 23 | voc_counter = collections.Counter() 24 | for c in tqdm.tqdm(data): 25 | voc_counter.update(preprocess_caption(c).split()) 26 | word2idx = {'': 0, '': 1} 27 | idx = len(word2idx) 28 | for v, c in sorted(voc_counter.items(), key=lambda x: x[1], reverse=True): 29 | if c < min_count: 30 | break 31 | word2idx[v] = idx 32 | idx += 1 33 | if cached is not None: 34 | with open(cached, 'w') as f: 35 | f.write(json.dumps(word2idx)) 36 | else: 37 | with open(cached) as f: 38 | word2idx = json.loads(f.read()) 39 | return word2idx 40 | 41 | 42 | def get_word2vec_matrix(vocab, cache_file, word2vec_binary): 43 | if cache_file is None and word2vec_binary is None: 44 | return None 45 | if cache_file is None or not os.path.exists(cache_file): 46 | print('Loading word2vec binary...') 47 | word2vec = KeyedVectors.load_word2vec_format(word2vec_binary, binary=True) 48 | word2vec_cachable = {} 49 | for w, idx in vocab.items(): 50 | if w in word2vec: 51 | word2vec_cachable[w] = list([float(x) for x in word2vec[w]]) 52 | if cache_file is not None: 53 | with open(cache_file, 'w') as f: 54 | f.write(json.dumps(word2vec_cachable)) 55 | else: 56 | with open(cache_file) as f: 57 | word2vec_cachable = json.loads(f.read()) 58 | word2vec = {w:np.array(v) for w, v in word2vec_cachable.items()} 59 | m_matrix = np.random.uniform(-.2, .2, size=(len(vocab), 300)) 60 | for w, idx in vocab.items(): 61 | if w in word2vec: 62 | m_matrix[idx, :] = word2vec[w] 63 | return m_matrix 64 | 65 | 66 | def text_to_matrix(captions, vocab, max_len=15, padding='post'): 67 | seqs = [] 68 | for c in captions: 69 | tokens = preprocess_caption(c).split() 70 | 71 | # for reasons I dont understand, the new version of CUDNN 72 | # 
doesn't play nice with padding, etc. After painstakingly 73 | # narrowing down why this happens, CUDNN errors when: 74 | 75 | # 1) you're using RNN 76 | # 2) your batch consists of fully-padded sequences and non-padded sequences only 77 | 78 | # I filed a tensorflow issue: 79 | # see https://github.com/tensorflow/tensorflow/issues/36139 80 | 81 | # Upon a tensorflower's reply, CuDNN currently doesn't like empty sequences, 82 | # and the CuDNN kernel only gets called when things are right-padded 83 | 84 | # so, for now, until this bug is fixed, so we can still use CuDNN: 85 | # 1) we will right/post-pad 86 | # 2) in the data iterator, for padding sequences, we will prepend with 87 | # an unk. These sentences don't affect the gradient, and we are 88 | # expecting CuDNN to return junk anyway in those cases, so this 89 | # should be fine, but I will experimentally verify 90 | idxs = [vocab[v] if v in vocab else vocab[''] for v in tokens] 91 | 92 | if len(idxs) == 0: 93 | warnings.warn( 94 | 'Wanring: detected at least one zero-length sentence. ' 95 | 'Running will continue, but check your inputs.') 96 | idxs = [vocab['']] 97 | 98 | seqs.append(idxs) 99 | m_mat = tf.keras.preprocessing.sequence.pad_sequences(seqs, maxlen=max_len, 100 | padding=padding, truncating='post', 101 | value=0) 102 | return m_mat 103 | -------------------------------------------------------------------------------- /train_doc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code to accompany 3 | "Unsupervised Discovery of Multimodal Links in Multi-Sentence/Multi-Image Documents." 4 | https://github.com/jmhessel/multi-retrieval 5 | 6 | This is a work-in-progress TF2.0 port. 7 | ''' 8 | import argparse 9 | import collections 10 | import json 11 | import tensorflow as tf 12 | import numpy as np 13 | import os 14 | import sys 15 | import tqdm 16 | import text_utils 17 | import image_utils 18 | import eval_utils 19 | import model_utils 20 | import training_utils 21 | import bipartite_utils 22 | import pickle 23 | import sklearn.preprocessing 24 | from pprint import pprint 25 | 26 | 27 | def load_data(fname): 28 | with open(fname) as f: 29 | return json.loads(f.read()) 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('documents', 35 | help='json of train/val/test documents.') 36 | parser.add_argument('--image_features', 37 | help='path to pre-extracted image-feature numpy array.') 38 | parser.add_argument('--image_id2row', 39 | help='path to mapping from image id --> numpy row for image features.') 40 | parser.add_argument('--joint_emb_dim', 41 | type=int, 42 | help='Embedding dimension of the shared, multimodal space.', 43 | default=1024) 44 | parser.add_argument('--margin', 45 | type=float, 46 | help='Margin for computing hinge loss.', 47 | default=.2) 48 | parser.add_argument('--seq_len', 49 | type=int, 50 | help='Maximum token sequence length for each sentence before truncation.', 51 | default=20) 52 | parser.add_argument('--docs_per_batch', 53 | type=int, 54 | help='How many docs per batch? 
11 docs = 10 negative samples per doc.', 55 | default=11) 56 | parser.add_argument('--neg_mining', 57 | help='What type of negative mining?', 58 | default='hard_negative', 59 | choices=['negative_sample', 'hard_negative'], 60 | type=str) 61 | parser.add_argument('--sim_mode', 62 | help='What similarity function should we use?', 63 | default='AP', 64 | choices=['DC','TK','AP'], 65 | type=str) 66 | parser.add_argument('--sim_mode_k', 67 | help='If --sim_mode=TK/AP, what should the k be? ' 68 | 'k=-1 for dynamic = min(n_images, n_sentences))? ' 69 | 'if k > 0, then k=ceil(1./k * min(n_images, n_sentences))', 70 | default=-1, 71 | type=float) 72 | parser.add_argument('--lr', 73 | type=float, 74 | help='Starting learning rate', 75 | default=.0002) 76 | parser.add_argument('--n_epochs', 77 | type=int, 78 | help='How many epochs to run for?', 79 | default=60) 80 | parser.add_argument('--checkpoint_dir', 81 | type=str, 82 | help='What directory to save checkpoints in?', 83 | default='checkpoints') 84 | parser.add_argument('--word2vec_binary', 85 | type=str, 86 | help='If cached word embeddings have not been generated, ' 87 | 'what is the location of the word2vec binary?', 88 | default=None) 89 | parser.add_argument('--cached_word_embeddings', 90 | type=str, 91 | help='Where are/will the cached word embeddings saved?', 92 | default='cached_word2vec.json') 93 | parser.add_argument('--print_metrics', 94 | type=int, 95 | help='Should we print the metrics if there are ground-truth ' 96 | 'labels, or no?', 97 | default=0) 98 | parser.add_argument('--cached_vocab', 99 | type=str, 100 | help='Where should we cache the vocab, if anywhere ' 101 | '(None means no caching)', 102 | default=None) 103 | parser.add_argument('--output', 104 | type=str, 105 | default=None, 106 | help='If output is set, we will save a pkl file' 107 | 'with the validation/test metrics.') 108 | parser.add_argument('--seed', 109 | type=int, 110 | help='Random seed', 111 | default=1) 112 | parser.add_argument('--dropout', 113 | type=float, 114 | default=0.5, 115 | help='How much dropout should we apply?') 116 | parser.add_argument('--subsample_image', 117 | type=int, 118 | default=-1, 119 | help='Should we subsample images to constant lengths? ' 120 | 'This option is useful if the model is being trained end2end ' 121 | 'and there are memory issues.') 122 | parser.add_argument('--subsample_text', 123 | type=int, 124 | default=-1, 125 | help='Should we subsample sentences to constant lengths? 
' 126 | 'This option is useful if the model is being trained end2end ' 127 | 'and there are memory issues.') 128 | parser.add_argument('--rnn_type', 129 | type=str, 130 | default='GRU', 131 | help='What RNN should we use') 132 | parser.add_argument('--end2end', 133 | type=int, 134 | default=0, 135 | help='Should we backprop through the whole vision network?') 136 | parser.add_argument('--image_dir', 137 | type=str, 138 | default=None, 139 | help='What image dir should we use, if end2end?') 140 | parser.add_argument('--lr_patience', 141 | type=int, 142 | default=3, 143 | help='What learning rate patience should we use?') 144 | parser.add_argument('--lr_decay', 145 | type=float, 146 | default=.2, 147 | help='What learning rate decay factor should we use?') 148 | parser.add_argument('--min_lr', 149 | type=float, 150 | default=.0000001, 151 | help='What learning rate decay factor should we use?') 152 | parser.add_argument('--full_image_paths', 153 | type=int, 154 | default=0, 155 | help='For end2end training, should we use full image paths ' 156 | '(i.e., is the file extention already on images?)?') 157 | parser.add_argument('--test_eval', 158 | type=int, 159 | help='(DEBUG OPTION) If test_eval >= 1, then training ' 160 | 'only happens over this many batches', 161 | default=-1) 162 | parser.add_argument('--force', 163 | type=int, 164 | default=0, 165 | help='Should we force the run if the output exists?') 166 | parser.add_argument('--save_predictions', 167 | type=str, 168 | default=None, 169 | help='Should we save the train/val/test predictions? ' 170 | 'If so --- they will be saved in this directory.') 171 | parser.add_argument('--image_model_checkpoint', 172 | type=str, 173 | default=None, 174 | help='If set, the image model will be initialized from ' 175 | 'this model checkpoint.') 176 | parser.add_argument('--text_model_checkpoint', 177 | type=str, 178 | default=None, 179 | help='If set, the text model will be initialized from ' 180 | 'this model checkpoint.') 181 | parser.add_argument('--loss_mode', 182 | help='What loss function should we use?', 183 | default='hinge', 184 | choices=['hinge', 'logistic', 'softmax'], 185 | type=str) 186 | parser.add_argument('--compute_mscoco_eval_metrics', 187 | help='Should we compute the mscoco MT metrics?', 188 | default=0, 189 | type=int) 190 | parser.add_argument('--compute_metrics_train', 191 | help='Should we also compute metrics over the training set?', 192 | default=1, 193 | type=int) 194 | parser.add_argument('--lr_warmup_steps', 195 | help='If positive value, we will warmup the learning rate linearly ' 196 | 'over this many steps.', 197 | default=-1, 198 | type=int) 199 | parser.add_argument('--l2_norm', 200 | help='If 1, we will l2 normalize extracted features, else, no normalization.', 201 | default=1, 202 | type=int) 203 | parser.add_argument('--n_layers', 204 | help='How many layers in the encoders?', 205 | default=1, 206 | type=int, 207 | choices=[1,2,3]) 208 | parser.add_argument('--scale_image_features', 209 | help='Should we standard scale image features?', 210 | default=0, 211 | type=int) 212 | args = parser.parse_args() 213 | 214 | # check to make sure that various flags are set correctly 215 | if args.end2end: 216 | assert args.image_dir is not None 217 | if not args.end2end: 218 | assert args.image_features is not None and args.image_id2row is not None 219 | 220 | # print out some info about the run's inputs/outputs 221 | if args.output and '.pkl' not in args.output: 222 | args.output += '.pkl' 223 | 224 | if args.output: 225 | 
print('Output will be saved to {}'.format(args.output)) 226 | print('Model checkpoints will be saved in {}'.format(args.checkpoint_dir)) 227 | 228 | if args.output and os.path.exists(args.output) and not args.force: 229 | print('{} already done! If you want to force it, set --force 1'.format(args.output)) 230 | quit() 231 | 232 | if not os.path.exists(args.checkpoint_dir): 233 | os.makedirs(args.checkpoint_dir) 234 | 235 | if args.save_predictions: 236 | if not os.path.exists(args.checkpoint_dir): 237 | os.makedirs(args.checkpoint_dir) 238 | os.makedirs(args.checkpoint_dir + '/train') 239 | os.makedirs(args.checkpoint_dir + '/val') 240 | os.makedirs(args.checkpoint_dir + '/test') 241 | 242 | return args 243 | 244 | 245 | def main(): 246 | args = parse_args() 247 | np.random.seed(args.seed) 248 | 249 | data = load_data(args.documents) 250 | train, val, test = data['train'], data['val'], data['test'] 251 | 252 | np.random.shuffle(train); np.random.shuffle(val); np.random.shuffle(test) 253 | max_n_sentence, max_n_image = -1, -1 254 | for d in train + val + test: 255 | imgs, sents, meta = d 256 | max_n_sentence = max(max_n_sentence, len(sents)) 257 | max_n_image = max(max_n_image, len(imgs)) 258 | 259 | # remove zero image/zero sentence cases: 260 | before_lens = list(map(len, [train, val, test])) 261 | 262 | train = [t for t in train if len(t[0]) > 0 and len(t[1]) > 0] 263 | val = [t for t in val if len(t[0]) > 0 and len(t[1]) > 0] 264 | test = [t for t in test if len(t[0]) > 0 and len(t[1]) > 0] 265 | 266 | after_lens = list(map(len, [train, val, test])) 267 | for bl, al, split in zip(before_lens, after_lens, ['train', 'val', 'test']): 268 | if bl == al: continue 269 | print('Removed {} documents from {} split that had zero images and/or sentences'.format( 270 | bl-al, split)) 271 | 272 | print('Max n sentence={}, max n image={}'.format(max_n_sentence, max_n_image)) 273 | if args.cached_vocab: 274 | print('Saving/loading vocab from {}'.format(args.cached_vocab)) 275 | 276 | # create vocab from training documents: 277 | flattened_train_sents = [] 278 | for _, sents, _ in train: 279 | flattened_train_sents.extend([s[0] for s in sents]) 280 | word2idx = text_utils.get_vocab(flattened_train_sents, cached=args.cached_vocab) 281 | print('Vocab size was {}'.format(len(word2idx))) 282 | 283 | if args.word2vec_binary: 284 | we_init = text_utils.get_word2vec_matrix( 285 | word2idx, args.cached_word_embeddings, args.word2vec_binary) 286 | else: 287 | we_init = np.random.uniform(low=-.02, high=.02, size=(len(word2idx), 300)) 288 | 289 | if args.end2end: 290 | image_features = None 291 | image_idx2row = None 292 | else: 293 | image_features = np.load(args.image_features) 294 | image_idx2row = load_data(args.image_id2row) 295 | 296 | if args.scale_image_features: 297 | ss = sklearn.preprocessing.StandardScaler() 298 | all_train_images = [] 299 | for img, txt, cid in train: 300 | all_train_images.extend([x[0] for x in img]) 301 | print('standard scaling with {} images total'.format(len(all_train_images))) 302 | all_train_rows = [image_idx2row[cid] for cid in all_train_images] 303 | ss.fit(image_features[np.array(all_train_rows)]) 304 | image_features = ss.transform(image_features) 305 | 306 | word_emb_dim = 300 307 | 308 | if val[0][0][0][1] is not None: 309 | ground_truth = True 310 | print('The input has ground truth, so AUC will be computed.') 311 | else: 312 | ground_truth = False 313 | 314 | # Step 1: Specify model inputs/outputs: 315 | 316 | # (n docs, n sent, max n words,) 317 | text_inp = 
tf.keras.layers.Input((None, args.seq_len)) 318 | 319 | # this input tells you how many sentences are really in each doc 320 | text_n_inp = tf.keras.layers.Input((1,), dtype='int32') 321 | if args.end2end: 322 | # (n docs, n image, x, y, color) 323 | img_inp = tf.keras.layers.Input((None, 224, 224, 3)) 324 | else: 325 | # (n docs, n image, feature dim) 326 | img_inp = tf.keras.layers.Input((None, image_features.shape[1])) 327 | # this input tells you how many images are really in each doc 328 | img_n_inp = tf.keras.layers.Input((1,), dtype='int32') 329 | 330 | # Step 2: Define transformations to shared multimodal space. 331 | 332 | # Step 2.1: The text model: 333 | if args.text_model_checkpoint: 334 | print('Loading pretrained text model from {}'.format( 335 | args.text_model_checkpoint)) 336 | single_text_doc_model = tf.keras.models.load_model(args.text_model_checkpoint) 337 | extracted_text_features = single_text_doc_model(text_inp) 338 | else: 339 | word_embedding = tf.keras.layers.Embedding( 340 | len(word2idx), 341 | word_emb_dim, 342 | weights=[we_init] if we_init is not None else None, 343 | mask_zero=True) 344 | element_dropout = tf.keras.layers.SpatialDropout1D(args.dropout) 345 | 346 | if args.rnn_type == 'GRU': 347 | rnn_maker = tf.keras.layers.GRU 348 | else: 349 | rnn_maker = tf.keras.layers.LSTM 350 | 351 | enc_layers = [] 352 | for idx in range(args.n_layers): 353 | if idx == args.n_layers-1: 354 | enc_layers.append(rnn_maker(args.joint_emb_dim)) 355 | else: 356 | enc_layers.append(rnn_maker(args.joint_emb_dim, return_sequences=True)) 357 | 358 | embedded_text_inp = word_embedding(text_inp) 359 | extracted_text_features = tf.keras.layers.TimeDistributed(element_dropout)(embedded_text_inp) 360 | 361 | for l in enc_layers: 362 | extracted_text_features = tf.keras.layers.TimeDistributed(l)(extracted_text_features) 363 | 364 | # extracted_text_features is now (n docs, max n setnences, multimodal dim) 365 | if args.l2_norm: 366 | l2_norm_layer = tf.keras.layers.Lambda(lambda x: tf.nn.l2_normalize(x, axis=-1)) 367 | extracted_text_features = l2_norm_layer(extracted_text_features) 368 | 369 | single_text_doc_model = tf.keras.models.Model( 370 | inputs=text_inp, 371 | outputs=extracted_text_features) 372 | 373 | # Step 2.2: The image model: 374 | if args.image_model_checkpoint: 375 | print('Loading pretrained image model from {}'.format( 376 | args.image_model_checkpoint)) 377 | single_img_doc_model = tf.keras.models.load_model(args.image_model_checkpoint) 378 | extracted_img_features = single_img_doc_model(img_inp) 379 | else: 380 | if args.end2end: 381 | img_projection = tf.keras.layers.Dense(args.joint_emb_dim) 382 | from tf.keras.applications.nasnet import NASNetMobile 383 | cnn = tf.keras.applications.nasnet.NASNetMobile( 384 | include_top=False, input_shape=(224, 224, 3), pooling='avg') 385 | 386 | extracted_img_features = tf.keras.layers.TimeDistributed(cnn)(img_inp) 387 | if args.dropout > 0.0: 388 | extracted_img_features = tf.keras.layers.TimeDistributed( 389 | tf.keras.layers.Dropout(args.dropout))(extracted_img_features) 390 | extracted_img_features = keras.layers.TimeDistributed(img_projection)( 391 | extracted_img_features) 392 | else: 393 | extracted_img_features = tf.keras.layers.Masking()(img_inp) 394 | if args.dropout > 0.0: 395 | extracted_img_features = tf.keras.layers.TimeDistributed( 396 | tf.keras.layers.Dropout(args.dropout))(extracted_img_features) 397 | 398 | enc_layers = [] 399 | for idx in range(args.n_layers): 400 | if idx == args.n_layers-1: 401 | 
402 |             else:
403 |                 enc_layers.append(tf.keras.layers.Dense(args.joint_emb_dim, activation='relu'))
404 |                 enc_layers.append(tf.keras.layers.BatchNormalization())
405 | 
406 |         for l in enc_layers:
407 |             extracted_img_features = tf.keras.layers.TimeDistributed(l)(extracted_img_features)
408 | 
409 |     # extracted_img_features is now (n docs, max n images, multimodal dim)
410 |     if args.l2_norm:
411 |         l2_norm_layer = tf.keras.layers.Lambda(lambda x: tf.nn.l2_normalize(x, axis=-1))
412 |         extracted_img_features = l2_norm_layer(extracted_img_features)
413 | 
414 |     single_img_doc_model = tf.keras.models.Model(
415 |         inputs=img_inp,
416 |         outputs=extracted_img_features)
417 | 
418 |     # Step 3: Extract/stack the non-padding image/sentence representations
419 |     def mask_slice_and_stack(inp):
420 |         stacker = []
421 |         features, n_inputs = inp
422 |         n_inputs = tf.dtypes.cast(n_inputs, tf.int32)
423 |         # for each document, we will extract the portion of input features that are not padding
424 |         # this means, for features[doc_idx], we will take the first n_inputs[doc_idx] rows.
425 |         # we stack them into one big array so we can do a big cosine sim dot product between all
426 |         # sentence image pairs in parallel. We'll slice this array back up later.
427 |         for idx in range(args.docs_per_batch):
428 |             cur_valid_idxs = tf.range(n_inputs[idx,0])
429 |             cur_valid_features = features[idx]
430 |             feats = tf.gather(cur_valid_features, cur_valid_idxs)
431 |             stacker.append(feats)
432 |         return tf.concat(stacker, axis=0)
433 | 
434 |     # extracted text/img features are (n_docs, max_in_seq, dim)
435 |     # we want to compute cosine sims between all (sent, img) pairs quickly
436 |     # so we will stack them into new tensors ...
437 |     # text_enc has shape (total number of sent in batch, dim)
438 |     # img_enc has shape (total number of image in batch, dim)
439 |     text_enc = mask_slice_and_stack([extracted_text_features, text_n_inp])
440 |     img_enc = mask_slice_and_stack([extracted_img_features, img_n_inp])
441 | 
442 |     def DC_sim(sim_matrix):
443 |         text2im_S = tf.reduce_mean(tf.reduce_max(sim_matrix, 1))
444 |         im2text_S = tf.reduce_mean(tf.reduce_max(sim_matrix, 0))
445 |         return text2im_S + im2text_S
446 | 
447 |     def get_k(sim_matrix):
448 |         k = tf.minimum(tf.shape(sim_matrix)[0], tf.shape(sim_matrix)[1])
449 |         if args.sim_mode_k > 0:
450 |             k = tf.dtypes.cast(k, tf.float32)
451 |             k = tf.math.ceil(tf.math.divide(k, args.sim_mode_k))
452 |             k = tf.dtypes.cast(k, tf.int32)
453 |         return k
454 | 
455 |     def TK_sim(sim_matrix):
456 |         k = get_k(sim_matrix)
457 |         im2text_S, text2im_S = tf.reduce_max(sim_matrix, 0), tf.reduce_max(sim_matrix, 1)
458 |         text2im_S = tf.reduce_mean(tf.math.top_k(text2im_S, k=k)[0], axis=-1)
459 |         im2text_S = tf.reduce_mean(tf.math.top_k(im2text_S, k=k)[0], axis=-1)
460 |         return text2im_S + im2text_S
461 | 
462 |     bipartite_match_fn = bipartite_utils.generate_fast_hungarian_solving_function()
463 |     def AP_sim(sim_matrix):
464 |         k = get_k(sim_matrix)
465 |         sol = tf.numpy_function(bipartite_match_fn, [sim_matrix, k], tf.int32)
466 |         return tf.reduce_mean(tf.gather_nd(sim_matrix, sol))
467 | 
468 |     if args.sim_mode == 'DC':
469 |         sim_fn = DC_sim
470 |     elif args.sim_mode == 'TK':
471 |         sim_fn = TK_sim
472 |     elif args.sim_mode == 'AP':
473 |         sim_fn = AP_sim
474 |     else:
475 |         raise NotImplementedError('{} is not an implemented sim function'.format(args.sim_mode))
476 | 
477 |     def make_sims(inp):
478 |         sims = tf.keras.backend.dot(inp[0], tf.keras.backend.transpose(inp[1]))
479 |         return sims
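    # For intuition, a small worked example of the three args.sim_mode options defined
    # above (illustrative only; the 3x2 matrix below is hypothetical). Suppose one
    # document contributes similarities between 3 sentences and 2 images:
    #     sims = [[0.9, 0.1],
    #             [0.2, 0.8],
    #             [0.3, 0.4]]
    # DC: mean of per-sentence maxes (0.9, 0.8, 0.4 -> 0.70) plus mean of per-image
    #     maxes (0.9, 0.8 -> 0.85), so DC_sim = 1.55.
    # TK: assuming get_k returns min(3, 2) = 2 here, average only the top-2 entries of
    #     each of those max-vectors: (0.9 + 0.8)/2 + (0.9 + 0.8)/2 = 1.70.
    # AP: solve a size-k assignment problem; the best pairing is (sent 0, img 0) = 0.9
    #     and (sent 1, img 1) = 0.8, so AP_sim is their mean, 0.85.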
480 | 
481 |     all_sims = make_sims([text_enc, img_enc])
482 |     get_pos_neg_sims = model_utils.make_get_pos_neg_sims(
483 |         args,
484 |         sim_fn)
485 | 
486 |     pos_sims, neg_img_sims, neg_text_sims = tf.keras.layers.Lambda(
487 |         get_pos_neg_sims)([all_sims, text_n_inp, img_n_inp])
488 | 
489 |     if args.loss_mode == 'hinge':
490 |         def per_neg_loss(inp):
491 |             pos_s, neg_s = inp
492 |             return tf.math.maximum(neg_s - pos_s + args.margin, 0)
493 |     elif args.loss_mode == 'logistic':
494 |         def per_neg_loss(inp):
495 |             pos_s, neg_s = inp
496 |             return tf.nn.sigmoid_cross_entropy_with_logits(
497 |                 labels=tf.ones_like(neg_s),
498 |                 logits=pos_s - neg_s)
499 |     elif args.loss_mode == 'softmax':
500 |         def per_neg_loss(inp):
501 |             pos_s, neg_s = inp
502 |             pos_s -= args.margin
503 |             pos_l, neg_l = tf.ones_like(pos_s), tf.zeros_like(neg_s)
504 |             return tf.nn.softmax_cross_entropy_with_logits(
505 |                 labels=tf.concat([pos_l, neg_l], axis=1),
506 |                 logits=tf.concat([pos_s, neg_s], axis=1))
507 | 
508 |     neg_img_losses = per_neg_loss([pos_sims, neg_img_sims])
509 |     neg_text_losses = per_neg_loss([pos_sims, neg_text_sims])
510 | 
511 |     if args.loss_mode != 'softmax':
512 |         if args.neg_mining == 'negative_sample':
513 |             pool_fn = lambda x: tf.reduce_mean(x, axis=1, keepdims=True)
514 |         elif args.neg_mining == 'hard_negative':
515 |             pool_fn = lambda x: tf.reduce_max(x, axis=1, keepdims=True)
516 |         else:
517 |             raise NotImplementedError('{} is not a valid value for args.neg_mining'.format(
518 |                 args.neg_mining))
519 | 
520 |         neg_img_loss = tf.keras.layers.Lambda(pool_fn, name='neg_img')(neg_img_losses)
521 |         neg_text_loss = tf.keras.layers.Lambda(pool_fn, name='neg_text')(neg_text_losses)
522 |     else:
523 |         neg_img_loss = neg_img_losses
524 |         neg_text_loss = neg_text_losses
525 | 
526 |     inputs = [text_inp,
527 |               img_inp,
528 |               text_n_inp,
529 |               img_n_inp]
530 | 
531 |     model = tf.keras.models.Model(inputs=inputs,
532 |                                   outputs=[neg_img_loss, neg_text_loss])
533 | 
534 |     opt = tf.keras.optimizers.Adam(args.lr)
535 | 
536 |     def identity(y_true, y_pred):
537 |         del y_true
538 |         return tf.reduce_mean(y_pred, axis=-1)
539 | 
540 |     model.compile(opt, loss=identity)
541 | 
542 |     if args.test_eval > 0:
543 |         train = train[:args.test_eval * args.docs_per_batch]
544 |         val = val[:args.test_eval * args.docs_per_batch]
545 |         test = test[:args.test_eval * args.docs_per_batch]
546 | 
547 |     train_seq = training_utils.DocumentSequence(
548 |         train,
549 |         image_features,
550 |         image_idx2row,
551 |         max_n_sentence,
552 |         max_n_image,
553 |         word2idx,
554 |         args=args,
555 |         shuffle_docs=True,
556 |         shuffle_sentences=False,
557 |         shuffle_images=True)
558 | 
559 |     val_seq = training_utils.DocumentSequence(
560 |         val,
561 |         image_features,
562 |         image_idx2row,
563 |         max_n_sentence,
564 |         max_n_image,
565 |         word2idx,
566 |         args=args,
567 |         augment=False,
568 |         shuffle_sentences=False,
569 |         shuffle_docs=False,
570 |         shuffle_images=False)
571 | 
572 |     sdm = training_utils.SaveDocModels(
573 |         args.checkpoint_dir,
574 |         single_text_doc_model,
575 |         single_img_doc_model)
576 | 
577 |     if args.loss_mode == 'hinge':
578 |         val_loss_thresh = 2 * args.margin  # constant prediction performance
579 |     else:
580 |         val_loss_thresh = np.inf
581 | 
582 |     reduce_lr = training_utils.ReduceLROnPlateauAfterValLoss(
583 |         activation_val_loss=val_loss_thresh,
584 |         factor=args.lr_decay,
585 |         patience=args.lr_patience,
586 |         min_lr=args.min_lr,
587 |         verbose=True)
588 | 
589 |     callbacks = [reduce_lr, sdm]
590 | 
591 |     if args.print_metrics:
592 |         metrics_printer = training_utils.PrintMetrics(
593 |             val,
594 |             image_features,
595 |             image_idx2row,
596 |             word2idx,
597 |             single_text_doc_model,
598 |             single_img_doc_model,
599 |             args)
600 |         callbacks.append(metrics_printer)
601 | 
602 |     if args.lr_warmup_steps > 0:
603 |         warmup_lr = training_utils.LearningRateLinearIncrease(
604 |             args.lr,
605 |             args.lr_warmup_steps)
606 |         callbacks.append(warmup_lr)
607 | 
608 |     history = model.fit(
609 |         train_seq,
610 |         epochs=args.n_epochs,
611 |         validation_data=val_seq,
612 |         callbacks=callbacks)
613 | 
614 |     if args.output:
615 |         best_image_model_str, best_sentence_model_str, best_logs, best_epoch = sdm.best_checkpoints_and_logs
616 | 
617 |         single_text_doc_model = tf.keras.models.load_model(best_sentence_model_str)
618 |         single_img_doc_model = tf.keras.models.load_model(best_image_model_str)
619 | 
620 |         if args.scale_image_features:
621 |             with open(args.checkpoint_dir + '/image_standardscaler.pkl', 'wb') as f:
622 |                 pickle.dump(ss, f)
623 | 
624 |         if ground_truth and args.compute_metrics_train:
625 |             train_aucs, train_match_metrics, train_mt_metrics = eval_utils.compute_match_metrics_doc(
626 |                 train,
627 |                 image_features,
628 |                 image_idx2row,
629 |                 word2idx,
630 |                 single_text_doc_model,
631 |                 single_img_doc_model,
632 |                 args)
633 |         else:
634 |             train_aucs, train_match_metrics, train_mt_metrics = None, None, None
635 | 
636 | 
637 |         if ground_truth:
638 |             val_aucs, val_match_metrics, val_mt_metrics = eval_utils.compute_match_metrics_doc(
639 |                 val,
640 |                 image_features,
641 |                 image_idx2row,
642 |                 word2idx,
643 |                 single_text_doc_model,
644 |                 single_img_doc_model,
645 |                 args)
646 | 
647 |             test_aucs, test_match_metrics, test_mt_metrics = eval_utils.compute_match_metrics_doc(
648 |                 test,
649 |                 image_features,
650 |                 image_idx2row,
651 |                 word2idx,
652 |                 single_text_doc_model,
653 |                 single_img_doc_model,
654 |                 args)
655 | 
656 |         else:
657 |             train_aucs, val_aucs, test_aucs = None, None, None
658 |             train_match_metrics, val_match_metrics, test_match_metrics = None, None, None
659 |             train_mt_metrics, val_mt_metrics, test_mt_metrics = None, None, None
660 | 
661 | 
662 |         output = {'logs':best_logs,
663 | 
664 |                   'best_sentence_model_str':best_sentence_model_str,
665 |                   'best_image_model_str':best_image_model_str,
666 | 
667 |                   'train_aucs':train_aucs,
668 |                   'train_match_metrics':train_match_metrics,
669 |                   'train_mt_metrics':train_mt_metrics,
670 | 
671 |                   'val_aucs':val_aucs,
672 |                   'val_match_metrics':val_match_metrics,
673 |                   'val_mt_metrics':val_mt_metrics,
674 | 
675 |                   'test_aucs':test_aucs,
676 |                   'test_match_metrics':test_match_metrics,
677 |                   'test_mt_metrics':test_mt_metrics,
678 | 
679 |                   'args':args,
680 |                   'epoch':best_epoch}
681 | 
682 |         if args.scale_image_features:
683 |             output['image_standard_scaler_str'] = args.checkpoint_dir + '/image_standardscaler.pkl'
684 | 
685 |         for k, v in history.history.items():
686 |             output['history_{}'.format(k)] = v
687 |         if args.print_metrics:
688 |             for k, v in metrics_printer.history.items():
689 |                 output['metrics_history_{}'.format(k)] = v
690 | 
691 |         with open(args.output, 'wb') as f:
692 |             pickle.dump(output, f, protocol=pickle.HIGHEST_PROTOCOL)
693 |         print('saved output to {}'.format(args.output))
694 | 
695 |     if args.save_predictions:
696 |         for d, name in zip([train, val, test], ['train', 'val', 'test']):
697 |             out_dir = args.save_predictions + '/' + name
698 |             if not os.path.exists(out_dir):
699 |                 os.makedirs(out_dir)
700 |             eval_utils.save_predictions(
701 |                 d,
702 |                 image_features,
703 |                 image_idx2row,
704 |                 word2idx,
705 |                 single_text_doc_model,
706 |                 single_img_doc_model,
707 |                 out_dir,
708 |                 args)
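    # For reference, a sketch of how the results pickle written to args.output above
    # might be inspected after training (the path below is hypothetical):
    #     import pickle
    #     with open('output/run0.pkl', 'rb') as f:
    #         results = pickle.load(f)
    #     print(results['epoch'], results['best_image_model_str'])
    #     print(results['val_aucs'], results['test_match_metrics'])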
709 | 
710 | 
711 | if __name__ == '__main__':
712 |     main()
713 | 
-------------------------------------------------------------------------------- /training_utils.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import eval_utils
4 | import text_utils
5 | import image_utils
6 | 
7 | 
8 | class DocumentSequence(tf.keras.utils.Sequence):
9 |     def __init__(self,
10 |                  data_in,
11 |                  image_matrix,
12 |                  image_idx2row,
13 |                  max_sentences_per_doc,
14 |                  max_images_per_doc,
15 |                  vocab,
16 |                  args=None,
17 |                  augment=True,
18 |                  shuffle_sentences=False,
19 |                  shuffle_images=True,
20 |                  shuffle_docs=True):
21 |         self.data_in = data_in
22 |         self.image_matrix = image_matrix
23 |         self.image_idx2row = image_idx2row
24 |         self.max_sentences_per_doc = max_sentences_per_doc
25 |         self.max_images_per_doc = max_images_per_doc
26 |         self.vocab = vocab
27 |         self.args = args
28 |         self.augment = augment
29 |         self.shuffle_sentences = shuffle_sentences
30 |         self.shuffle_images = shuffle_images
31 |         self.shuffle_docs = shuffle_docs
32 | 
33 |     def __len__(self):
34 |         return int(np.ceil(len(self.data_in) / self.args.docs_per_batch))
35 | 
36 |     def __getitem__(self, idx):
37 |         start = idx * self.args.docs_per_batch
38 |         end = (idx + 1) * self.args.docs_per_batch
39 |         cur_doc_b = self.data_in[start: end]
40 | 
41 |         if idx == len(self) - 1:  # final batch may have wrong number of docs
42 |             docs_to_add = self.args.docs_per_batch - len(cur_doc_b)
43 |             cur_doc_b += self.data_in[:docs_to_add]
44 | 
45 |         images, texts = [], []
46 |         image_n_docs, text_n_docs = [], []
47 |         for idx, vers in enumerate(cur_doc_b):
48 |             cur_images = [img[0] for img in vers[0]]
49 |             cur_text = [text[0] for text in vers[1]]
50 | 
51 |             if self.shuffle_sentences and not (self.args and self.args.subsample_text > 0):
52 |                 np.random.shuffle(cur_text)
53 | 
54 |             if self.shuffle_images and not (self.args and self.args.subsample_image > 0):
55 |                 np.random.shuffle(cur_images)
56 | 
57 |             if self.args and self.args.subsample_image > 0:
58 |                 np.random.shuffle(cur_images)
59 |                 cur_images = cur_images[:self.args.subsample_image]
60 | 
61 |             if self.args and self.args.subsample_text > 0:
62 |                 np.random.shuffle(cur_text)
63 |                 cur_text = cur_text[:self.args.subsample_text]
64 | 
65 |             if self.args.end2end:
66 |                 cur_images = image_utils.images_to_images(cur_images, self.augment, self.args)
67 |                 if self.args and self.args.subsample_image > 0:
68 |                     image_padding = np.zeros(
69 |                         (self.args.subsample_image - cur_images.shape[0], 224, 224, 3))
70 |                 else:
71 |                     image_padding = np.zeros(
72 |                         (self.max_images_per_doc - cur_images.shape[0], 224, 224, 3))
73 |             else:
74 |                 cur_images = image_utils.images_to_matrix(
75 |                     cur_images, self.image_matrix, self.image_idx2row)
76 |                 if self.args and self.args.subsample_image > 0:
77 |                     image_padding = np.zeros(
78 |                         (self.args.subsample_image - cur_images.shape[0], cur_images.shape[-1]))
79 |                 else:
80 |                     image_padding = np.zeros(
81 |                         (self.max_images_per_doc - cur_images.shape[0], cur_images.shape[-1]))
82 | 
83 |             cur_text = text_utils.text_to_matrix(cur_text, self.vocab, max_len=self.args.seq_len)
84 |             image_n_docs.append(cur_images.shape[0])
85 |             text_n_docs.append(cur_text.shape[0])
86 | 
87 |             if self.args and self.args.subsample_text > 0:
88 |                 text_padding = np.zeros(
89 |                     (self.args.subsample_text - cur_text.shape[0], cur_text.shape[-1]))
90 |             else:
91 |                 text_padding = np.zeros(
92 |                     (self.max_sentences_per_doc - cur_text.shape[0], cur_text.shape[-1]))
93 | 
94 |             # cudnn can't
do empty sequences, so for now, I will just put an UNK in front of all sequences. 95 | # see comment in text_utils.py for more information. 96 | text_padding[:, 0] = 1 97 | 98 | cur_images = np.vstack([cur_images, image_padding]) 99 | cur_text = np.vstack([cur_text, text_padding]) 100 | 101 | cur_images = np.expand_dims(cur_images, 0) 102 | cur_text = np.expand_dims(cur_text, 0) 103 | 104 | images.append(cur_images) 105 | texts.append(cur_text) 106 | 107 | images = np.vstack(images) 108 | texts = np.vstack(texts) 109 | 110 | image_n_docs = np.expand_dims(np.array(image_n_docs), -1) 111 | text_n_docs = np.expand_dims(np.array(text_n_docs), -1) 112 | 113 | y = [np.zeros(len(text_n_docs)), np.zeros(len(image_n_docs))] 114 | 115 | return ([texts, 116 | images, 117 | text_n_docs, 118 | image_n_docs], y) 119 | 120 | 121 | def on_epoch_end(self): 122 | if self.shuffle_docs: 123 | np.random.shuffle(self.data_in) 124 | 125 | 126 | class SaveDocModels(tf.keras.callbacks.Callback): 127 | 128 | def __init__(self, 129 | checkpoint_dir, 130 | single_text_doc_model, 131 | single_image_doc_model): 132 | super(SaveDocModels, self).__init__() 133 | self.checkpoint_dir = checkpoint_dir 134 | self.single_text_doc_model = single_text_doc_model 135 | self.single_image_doc_model = single_image_doc_model 136 | 137 | 138 | def on_train_begin(self, logs={}): 139 | self.best_val_loss = np.inf 140 | self.best_checkpoints_and_logs = None 141 | 142 | def on_epoch_end(self, epoch, logs): 143 | if logs['val_loss'] < self.best_val_loss: 144 | print('New best val loss: {:.5f}'.format(logs['val_loss'])) 145 | self.best_val_loss = logs['val_loss'] 146 | else: 147 | return 148 | image_model_str = self.checkpoint_dir + '/image_model_epoch_{}_val={:.5f}.model'.format(epoch, logs['val_loss']) 149 | sentence_model_str = self.checkpoint_dir + '/text_model_epoch_{}_val={:.5f}.model'.format(epoch, logs['val_loss']) 150 | self.best_checkpoints_and_logs = (image_model_str, sentence_model_str, logs, epoch) 151 | 152 | self.single_text_doc_model.save(sentence_model_str, overwrite=True, save_format='h5') 153 | self.single_image_doc_model.save(image_model_str, overwrite=True, save_format='h5') 154 | 155 | 156 | 157 | class ReduceLROnPlateauAfterValLoss(tf.keras.callbacks.ReduceLROnPlateau): 158 | ''' 159 | Delays the normal operation of ReduceLROnPlateau until the validation 160 | loss reaches a given value. 161 | ''' 162 | def __init__(self, activation_val_loss=np.inf, *args, **kwargs): 163 | super(ReduceLROnPlateauAfterValLoss, self).__init__(*args, **kwargs) 164 | self.activation_val_loss = activation_val_loss 165 | self.val_threshold_activated = False 166 | 167 | def in_cooldown(self): 168 | if not self.val_threshold_activated: # check to see if we should activate 169 | if self.current_logs['val_loss'] < self.activation_val_loss: 170 | print('Current validation loss ({}) less than activation val loss ({})'. 
171 | format(self.current_logs['val_loss'], 172 | self.activation_val_loss)) 173 | print('Normal operation of val LR reduction started.') 174 | self.val_threshold_activated = True 175 | self._reset() 176 | 177 | return self.cooldown_counter > 0 or not self.val_threshold_activated 178 | 179 | def on_epoch_end(self, epoch, logs=None): 180 | self.current_logs = logs 181 | super(ReduceLROnPlateauAfterValLoss, self).on_epoch_end(epoch, logs=logs) 182 | 183 | 184 | class PrintMetrics(tf.keras.callbacks.Callback): 185 | def __init__(self, 186 | val, 187 | image_features, 188 | image_idx2row, 189 | word2idx, 190 | single_text_doc_model, 191 | single_img_doc_model, 192 | args): 193 | super(PrintMetrics, self).__init__() 194 | self.val = val 195 | self.image_features = image_features 196 | self.image_idx2row = image_idx2row 197 | self.word2idx = word2idx 198 | self.single_text_doc_model = single_text_doc_model 199 | self.single_img_doc_model = single_img_doc_model 200 | self.args = args 201 | 202 | def on_train_begin(self, logs=None): 203 | self.epoch = [] 204 | self.history = {} 205 | 206 | def on_epoch_end(self, epoch, logs): 207 | metrics = eval_utils.print_all_metrics( 208 | self.val, 209 | self.image_features, 210 | self.image_idx2row, 211 | self.word2idx, 212 | self.single_text_doc_model, 213 | self.single_img_doc_model, 214 | self.args) 215 | self.epoch.append(epoch) 216 | for k, v in metrics.items(): 217 | self.history.setdefault(k, []).append(v) 218 | 219 | 220 | class LearningRateLinearIncrease(tf.keras.callbacks.Callback): 221 | def __init__(self, max_lr, warmup_steps, verbose=0): 222 | super(LearningRateLinearIncrease, self).__init__() 223 | self.max_lr = max_lr 224 | self.warmup_steps = warmup_steps 225 | self.verbose = verbose 226 | self.cur_step_count = 0 227 | 228 | def on_train_begin(self, logs=None): 229 | tf.keras.backend.set_value(self.model.optimizer.lr, 0.0) 230 | 231 | def on_batch_begin(self, batch, logs=None): 232 | if self.cur_step_count >= self.warmup_steps: 233 | return 234 | lr = float(tf.keras.backend.get_value(self.model.optimizer.lr)) 235 | lr += 1./self.warmup_steps * self.max_lr 236 | if self.verbose and self.cur_step_count % 50 == 0: 237 | print('\n new LR = {}\n'.format(lr)) 238 | tf.keras.backend.set_value(self.model.optimizer.lr, lr) 239 | self.cur_step_count += 1 240 | -------------------------------------------------------------------------------- /visualize_predictions_graph.py: -------------------------------------------------------------------------------- 1 | ''' 2 | for i in predictions/test/*; do python visualize_predictions.py $i\/doc.json $i/pred_weights.npy prediction_dir/$i ; done; 3 | ''' 4 | import argparse 5 | import numpy as np 6 | import bipartite_utils 7 | import json 8 | import os 9 | import subprocess 10 | import matplotlib.pyplot as plt 11 | 12 | from sklearn.metrics import roc_auc_score 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('document') 17 | parser.add_argument('predictions') 18 | parser.add_argument('output') 19 | parser.add_argument('--n_to_show', default=5) 20 | return parser.parse_args() 21 | 22 | 23 | def call(x): 24 | subprocess.call(x, shell=True) 25 | 26 | 27 | def main(): 28 | args = parse_args() 29 | pred_adj = np.load(args.predictions) 30 | with open(args.document) as f: 31 | data = json.loads(f.read()) 32 | 33 | images, text = data[0], data[1] 34 | 35 | solve_fn = bipartite_utils.generate_fast_hungarian_solving_function() 36 | sol = solve_fn(pred_adj, args.n_to_show) 37 | 
scores = pred_adj[sol[:,0], sol[:,1]] 38 | 39 | true_adj = np.zeros((len(text), len(images))) 40 | for text_idx, t in enumerate(text): 41 | if t[1] == -1: continue 42 | true_adj[text_idx, t[1]] = 1 43 | for image_idx, t in enumerate(images): 44 | if t[1] == -1: continue 45 | true_adj[t[1], image_idx] = 1 46 | 47 | auc = 100 * roc_auc_score(true_adj.flatten(), 48 | pred_adj.flatten()) 49 | print('AUC: {:.2f} {}'.format(auc, 50 | data[-1])) 51 | 52 | ordered_images, ordered_sentences = [], [] 53 | for img_idx, sent_idx, sc in sorted( 54 | zip(sol[:,1], sol[:,0], scores), key=lambda x:-x[-1])[:args.n_to_show]: 55 | ordered_images.append(img_idx) 56 | ordered_sentences.append(sent_idx) 57 | print(sc) 58 | 59 | pred_adj_subgraph = pred_adj[np.array(ordered_sentences),:][:,np.array(ordered_images)] 60 | true_adj_subgraph = true_adj[np.array(ordered_sentences),:][:,np.array(ordered_images)] 61 | selected_images = [images[img_idx][0] for img_idx in ordered_images] 62 | selected_sentences = [text[sent_idx][0] for sent_idx in ordered_sentences] 63 | 64 | # normalize predicted sims to have max 1 and min 0 65 | # first, clip out negative values 66 | pred_adj_subgraph = np.clip(pred_adj_subgraph, 0, 1.0) 67 | pred_adj_subgraph -= np.min(pred_adj_subgraph.flatten()) 68 | pred_adj_subgraph /= np.max(pred_adj_subgraph.flatten()) 69 | assert np.min(pred_adj_subgraph.flatten()) == 0.0 70 | assert np.max(pred_adj_subgraph.flatten()) == 1.0 71 | 72 | print(pred_adj_subgraph.shape) 73 | print(ordered_images) 74 | print(ordered_sentences) 75 | print(selected_images) 76 | print(selected_sentences) 77 | 78 | # each line has ((x1, y1, x2, y2), strength, correctness) 79 | # images go above text 80 | lines_to_plot = [] 81 | image_text_gap = 2 82 | same_mode_gap = 2 83 | offdiag_alpha_mul = .5 84 | 85 | def cosine_to_width(cos, exp=2.0, maxwidth=8.0): 86 | return cos**exp * maxwidth 87 | def cosine_to_alpha(cos, exp=1/2., maxalpha=1.0): 88 | return cos**exp * maxalpha 89 | 90 | correct_color, incorrect_color = '#1b7837', '#762a83' 91 | lines_to_plot = [] 92 | for text_idx in range(args.n_to_show): 93 | for image_idx in range(args.n_to_show): 94 | coords = (text_idx*same_mode_gap, 0, image_idx*same_mode_gap, image_text_gap) 95 | strength = max(pred_adj_subgraph[text_idx, image_idx], 0) 96 | correctness = true_adj_subgraph[text_idx, image_idx] == 1 97 | lines_to_plot.append((coords, strength, correctness)) 98 | 99 | plt.figure(figsize=(args.n_to_show*same_mode_gap, image_text_gap)) 100 | for (x1, y1, x2, y2), strength, correct in sorted(lines_to_plot, 101 | key=lambda x: x[1]): 102 | if x1 == x2: continue 103 | plt.plot([x1, x2], [y1, y2], 104 | linewidth=cosine_to_width(strength), 105 | alpha=cosine_to_alpha(strength) * offdiag_alpha_mul, 106 | color=correct_color if correct else incorrect_color) 107 | for (x1, y1, x2, y2), strength, correct in sorted(lines_to_plot, 108 | key=lambda x: x[1]): 109 | if x1 != x2: continue 110 | plt.plot([x1, x2], [y1, y2], 111 | linewidth=cosine_to_width(strength), 112 | color=correct_color if correct else incorrect_color) 113 | plt.axis('off') 114 | plt.tight_layout() 115 | 116 | if not os.path.exists(args.output): 117 | os.makedirs(args.output) 118 | with open(args.output + '/sentences.txt', 'w') as f: 119 | f.write('\n'.join([' '.join(s.split()) for s in selected_sentences])) 120 | with open(args.output + '/images.txt', 'w') as f: 121 | f.write('\n'.join(selected_images)) 122 | with open(args.output + '/all_sentences.txt', 'w') as f: 123 | f.write('\n'.join([' '.join(s[0].split()) 
for s in text]))
124 |     with open(args.output + '/all_images.txt', 'w') as f:
125 |         f.write('\n'.join([x[0] for x in images]))
126 |     with open(args.output + '/auc.txt', 'w') as f:
127 |         f.write('{:.4f}'.format(auc))
128 |     plt.savefig(args.output + '/graph.png', dpi=300)
129 |     call('convert {} -trim {}'.format(args.output + '/graph.png',
130 |                                       args.output + '/graph_cropped.png'))
131 | 
132 | 
133 | if __name__ == '__main__':
134 |     main()
135 | 
-------------------------------------------------------------------------------- /visualize_predictions_html.py: --------------------------------------------------------------------------------
1 | '''
2 | for i in predictions/train/*; do python visualize_predictions_html.py $i/doc.json $i/pred_weights.npy $PWD/data/wiki/ prediction_dir/$i ; done;
3 | '''
4 | import argparse
5 | import numpy as np
6 | import bipartite_utils
7 | import json
8 | import os
9 | import subprocess
10 | 
11 | 
12 | def parse_args():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument('document')
15 |     parser.add_argument('predictions')
16 |     parser.add_argument('image_dir')
17 |     parser.add_argument('output')
18 |     return parser.parse_args()
19 | 
20 | 
21 | def call(x):
22 |     subprocess.call(x, shell=True)
23 | 
24 | 
25 | def main():
26 |     args = parse_args()
27 |     pred_adj = np.load(args.predictions)
28 |     with open(args.document) as f:
29 |         data = json.loads(f.read())
30 | 
31 |     cur_images = [d[0] for d in data[0]]
32 |     cur_text = [d[0] for d in data[1]]
33 | 
34 | 
35 |     solve_fn = bipartite_utils.generate_fast_hungarian_solving_function()
36 |     sol = solve_fn(pred_adj, max(*pred_adj.shape))
37 |     scores = pred_adj[sol[:,0], sol[:,1]]
38 |     lines = []
39 |     images_to_copy = []
40 |     for img_idx, sent_idx, sc in sorted(
41 |             zip(sol[:,1], sol[:,0], scores), key=lambda x:-x[-1]):
42 |         # matched sentence with its score, followed by the matched image
43 |         lines.append('<p>{} ({:.1f})</p>'.format(cur_text[sent_idx],
44 |                                                  sc*100))
45 |         images_to_copy.append('{}/{}'.format(args.image_dir, cur_images[img_idx]))
46 |         lines.append('<img src="{}"><br>'.format(
47 |             cur_images[img_idx].split('/')[-1]))
48 |         print(cur_text[sent_idx])
49 |         print(cur_images[img_idx])
50 |         print(sc)
51 | 
52 |     lines.append('<br><h3>Article Text:</h3>')
53 |     for sent in cur_text:
54 |         lines.append('<p>{}</p>'.format(sent))
55 | 
56 |     if not os.path.exists(args.output):
57 |         os.makedirs(args.output)
58 | 
59 |     with open('{}/view.html'.format(args.output), 'w') as f:
60 |         f.write('\n'.join(lines))
61 | 
62 |     for im in images_to_copy:
63 |         call('cp {} {}'.format(im, args.output))
64 |     print()
65 | 
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     main()
70 | 
--------------------------------------------------------------------------------