├── .gitignore ├── test.py ├── preprocess_datainfo.py ├── convert_datainfo2cocofmt.py ├── compute_scores.py ├── README.md ├── create_sequencelabel.py ├── compute_ciderdf.py ├── standalize_format.py ├── Makefile ├── dataloader.py ├── utils.py ├── opts.py ├── train.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Files 3 | *.pyc 4 | *.sh 5 | 6 | # Directories 7 | input 8 | output 9 | log 10 | png 11 | cider 12 | coco-caption 13 | old 14 | .ipynb_checkpoints 15 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | import os 5 | import sys 6 | import time 7 | import math 8 | import json 9 | 10 | import logging 11 | from datetime import datetime 12 | 13 | from dataloader import DataLoader 14 | from model import CaptionModel, CrossEntropyCriterion 15 | from train import test 16 | 17 | import utils 18 | import opts 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | if __name__ == '__main__': 23 | 24 | opt = opts.parse_opts() 25 | 26 | logging.basicConfig(level=getattr(logging, opt.loglevel.upper()), 27 | format='%(asctime)s:%(levelname)s: %(message)s') 28 | 29 | logger.info( 30 | 'Input arguments: %s', 31 | json.dumps( 32 | vars(opt), 33 | sort_keys=True, 34 | indent=4)) 35 | 36 | start = datetime.now() 37 | 38 | test_opt = {'label_h5': opt.test_label_h5, 39 | 'batch_size': opt.test_batch_size, 40 | 'feat_h5': opt.test_feat_h5, 41 | 'cocofmt_file': opt.test_cocofmt_file, 42 | 'seq_per_img': opt.test_seq_per_img, 43 | 'num_chunks': opt.num_chunks, 44 | 'mode': 'test' 45 | } 46 | 47 | test_loader = DataLoader(test_opt) 48 | 49 | logger.info('Loading model: %s', opt.model_file) 50 | checkpoint = torch.load(opt.model_file) 51 | checkpoint_opt = checkpoint['opt'] 52 | 53 | opt.model_type = checkpoint_opt.model_type 54 | opt.vocab = checkpoint_opt.vocab 55 | opt.vocab_size = checkpoint_opt.vocab_size 56 | opt.seq_length = checkpoint_opt.seq_length 57 | opt.feat_dims = checkpoint_opt.feat_dims 58 | 59 | assert opt.vocab_size == test_loader.get_vocab_size() 60 | assert opt.seq_length == test_loader.get_seq_length() 61 | assert opt.feat_dims == test_loader.get_feat_dims() 62 | 63 | logger.info('Building model...') 64 | model = CaptionModel(opt) 65 | logger.info('Loading state from the checkpoint...') 66 | model.load_state_dict(checkpoint['model']) 67 | 68 | xe_criterion = CrossEntropyCriterion() 69 | 70 | if torch.cuda.is_available(): 71 | model.cuda() 72 | xe_criterion.cuda() 73 | 74 | logger.info('Start testing...') 75 | test(model, xe_criterion, test_loader, opt) 76 | logger.info('Time: %s', datetime.now() - start) 77 | -------------------------------------------------------------------------------- /preprocess_datainfo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessing datainfo 3 | tokenize, lowercase, etc. 
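The output (written by main below) is a JSON list of video records, each keeping the
original category, video_id and raw captions, plus an added processed_tokens field
holding the tokenized captions.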
4 | 5 | """ 6 | 7 | import os 8 | import json 9 | import argparse 10 | import h5py 11 | import numpy as np 12 | import string 13 | 14 | import logging 15 | from datetime import datetime 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def prepro_captions(videos): 21 | 22 | logger.info("Preprocessing %d videos", len(videos)) 23 | for i, v in enumerate(videos): 24 | v['processed_tokens'] = [] 25 | if i % 100 == 0: 26 | logger.info("%d/%d video processed", i, len(videos)) 27 | 28 | for caption in v['captions']: 29 | caption_ascii = ''.join( 30 | [ii if ord(ii) < 128 else '' for ii in caption]) 31 | tokens = str(caption_ascii).lower().translate( 32 | None, string.punctuation).strip().split() 33 | v['processed_tokens'].append(tokens) 34 | 35 | 36 | def main(input_json, output_json): 37 | 38 | infos = json.load(open(input_json, 'r')) 39 | annots = infos['captions'] 40 | 41 | logger.info('group annotations by video') 42 | vtoa = {} 43 | for ann in annots: 44 | vtoa.setdefault(ann['video_id'], []).append(ann) 45 | 46 | logger.info('create the json blob') 47 | videos = [] 48 | for i, v in enumerate(infos['videos']): 49 | 50 | jvid = {} 51 | jvid['category'] = v.get('category', 'unknown') 52 | jvid['video_id'] = v['id'] 53 | 54 | sents = [] 55 | annotsi = vtoa.get(v['id'], []) # at test time, there is no captions 56 | for a in annotsi: 57 | sents.append(a['caption']) 58 | jvid['captions'] = sents 59 | videos.append(jvid) 60 | 61 | logger.info('Tokenizing and preprocessing') 62 | prepro_captions(videos) 63 | 64 | logger.info('Writing to: %s', output_json) 65 | json.dump(videos, open(output_json, 'w')) 66 | 67 | ###################################################################### 68 | 69 | if __name__ == "__main__": 70 | logging.basicConfig(level=logging.DEBUG, 71 | format='%(asctime)s:%(levelname)s: %(message)s') 72 | parser = argparse.ArgumentParser() 73 | 74 | parser.add_argument('input_json', type=str, 75 | help='standalized input json') 76 | parser.add_argument( 77 | 'output_json', type=str, help='output tokenized json file') 78 | 79 | args = parser.parse_args() 80 | logger.info('Input arguments: %s', args) 81 | 82 | start = datetime.now() 83 | main(args.input_json, args.output_json) 84 | 85 | logger.info('Time: %s', datetime.now() - start) 86 | -------------------------------------------------------------------------------- /convert_datainfo2cocofmt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert input json to coco format to use COCO eval toolkit 3 | """ 4 | 5 | from __future__ import print_function 6 | import argparse 7 | from datetime import datetime 8 | import logging 9 | import json 10 | import sys 11 | import os.path 12 | import random 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | # remove non-accii characters 17 | 18 | 19 | def remove_nonaccii(s): 20 | s = ''.join([i if ord(i) < 128 else '' for i in s]) 21 | return s 22 | 23 | if __name__ == '__main__': 24 | start = datetime.now() 25 | 26 | logging.basicConfig( 27 | level=logging.DEBUG, 28 | format='%(asctime)s:%(levelname)s: %(message)s') 29 | 30 | argparser = argparse.ArgumentParser( 31 | description="Prepare image input in Json format for neuraltalk extraction and visualization") 32 | argparser.add_argument( 33 | "input_json", 34 | type=str, 35 | help="Standalized datainfo file") 36 | argparser.add_argument( 37 | "output_json", 38 | type=str, 39 | help="Output json in COCO format") 40 | argparser.add_argument( 41 | "--max_caption", 42 | type=int, 43 | 
help="Max number of caption per video; default: 0 (all captions)", 44 | default=0) 45 | 46 | args = argparser.parse_args() 47 | 48 | logger.info('Loading input file: %s', args.input_json) 49 | infos = json.load(open(args.input_json)) 50 | 51 | logger.info('Converting json data...') 52 | 53 | imgs = [{'id': v['id']} for v in infos['videos']] 54 | 55 | if args.max_caption <= 0: 56 | anns = [{'caption': remove_nonaccii(s['caption']), 57 | 'image_id': s['video_id'], 58 | 'id': s['id']} 59 | for s in infos['captions']] 60 | else: 61 | logger.info('Create dictionary of video captions') 62 | org_dict = {} 63 | for s in infos['captions']: 64 | org_dict.setdefault(s['video_id'], []).append(s['id']) 65 | 66 | sample_dict = {} 67 | logger.info( 68 | 'Randomly sample maximum %d caption(s) per video', 69 | args.max_caption) 70 | for k, v in org_dict.iteritems(): 71 | sample_dict[k] = random.sample(org_dict[k], args.max_caption) 72 | 73 | anns = [{'caption': remove_nonaccii(s['caption']), 74 | 'image_id': s['video_id'], 75 | 'id': s['id']} 76 | for s in infos['captions'] if s['id'] in sample_dict[s['video_id']]] 77 | 78 | out = { 79 | 'images': imgs, 80 | 'annotations': anns, 81 | 'type': 'captions', 82 | 'info': infos['info'], 83 | 'licenses': 'n/a'} 84 | 85 | logger.info('Saving...') 86 | with open(args.output_json, 'w') as f: 87 | json.dump(out, f) 88 | 89 | logger.info('done') 90 | logger.info('Time: %s', datetime.now() - start) 91 | -------------------------------------------------------------------------------- /compute_scores.py: -------------------------------------------------------------------------------- 1 | """ 2 | Computer all metrics (BMCR) of all splits 3 | It is used to as the inputs of the compute_dataslice function 4 | and as pre-computed cider scores at training time 5 | 6 | """ 7 | 8 | import os 9 | import sys 10 | import json 11 | import argparse 12 | import string 13 | import itertools 14 | import numpy as np 15 | from collections import OrderedDict 16 | 17 | #sys.path.append("cider") 18 | #from pyciderevalcap.cider.cider import Cider 19 | #from pyciderevalcap.ciderD.ciderD import CiderD 20 | 21 | sys.path.append('coco-caption') 22 | #from pycocotools.coco import COCO 23 | #from pycocoevalcap.eval import COCOEvalCap 24 | 25 | from pycocoevalcap.bleu.bleu import Bleu 26 | from pycocoevalcap.meteor.meteor import Meteor 27 | from pycocoevalcap.rouge.rouge import Rouge 28 | from pycocoevalcap.cider.cider import Cider 29 | 30 | import logging 31 | from datetime import datetime 32 | 33 | import utils 34 | logger = logging.getLogger(__name__) 35 | from six.moves import cPickle 36 | 37 | if __name__ == "__main__": 38 | logging.basicConfig(level=logging.DEBUG, 39 | format='%(asctime)s:%(levelname)s: %(message)s') 40 | parser = argparse.ArgumentParser() 41 | 42 | #parser.add_argument('pred_file', type=str, help='') 43 | parser.add_argument('cocofmt_file', type=str, help='') 44 | parser.add_argument('output_pkl', type=str, help='') 45 | parser.add_argument('--seq_per_img', type=int, default=20, help='Number of caption per image/video') 46 | parser.add_argument('--remove_in_ref', default=False, action='store_true', 47 | help='Remove current caption in the ref set') 48 | 49 | args = parser.parse_args() 50 | logger.info('Input arguments: %s', args) 51 | 52 | start = datetime.now() 53 | 54 | #logger.info('Loading prediction: %s', args.pred_file) 55 | #preds = json.load(open(args.pred_file))['predictions'] 56 | #preds = {p['image_id']: [p['caption']] for p in preds} 57 | #scorer = 
CiderD(df=args.cached_tokens_file) 58 | 59 | logger.info('Setting up scorers...') 60 | scorers = [ 61 | (Bleu(), "Bleu_4"), 62 | (Meteor(), "METEOR"), 63 | (Rouge(), "ROUGE_L"), 64 | (Cider(), "CIDEr") 65 | ] 66 | 67 | logger.info('loading gt refs: %s', args.cocofmt_file) 68 | gt_refs = utils.load_gt_refs(args.cocofmt_file) 69 | 70 | logger.info('Computing score...') 71 | #score, scores = computer_score(preds, gt_refs, scorer) 72 | videos = sorted(gt_refs.keys()) 73 | 74 | gt_scores = {} 75 | for scorer, method in scorers: 76 | gt_scores[method] = np.zeros((len(gt_refs), args.seq_per_img)) 77 | 78 | for i in range(args.seq_per_img): 79 | logger.info('taking caption: %d', i) 80 | preds_i = {v:[gt_refs[v][i]] for v in videos} 81 | 82 | # removing the refs at i 83 | if args.remove_in_ref: 84 | gt_refs_i = {v: gt_refs[v][:i] + gt_refs[v][i+1:] for v in videos} 85 | else: 86 | gt_refs_i = gt_refs 87 | 88 | for scorer, method in scorers: 89 | score_i, scores_i = scorer.compute_score(gt_refs_i, preds_i) 90 | 91 | if method == 'Bleu_4': 92 | score_i = score_i[-1] 93 | scores_i = scores_i[-1] 94 | 95 | # happens for BLeu and METEOR 96 | if type(scores_i) == list: 97 | scores_i = np.array(scores_i) 98 | 99 | gt_scores[method][:,i] = scores_i 100 | logger.info('%s: %f', method, score_i) 101 | 102 | cPickle.dump(gt_scores, open( 103 | args.output_pkl, 'w'), protocol=cPickle.HIGHEST_PROTOCOL) 104 | 105 | logger.info('Time: %s', datetime.now() - start) 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Consensus-based Sequence Training for Video Captioning # 2 | 3 | Code for the video captioning methods from ["Consensus-based Sequence Training for Video Captioning" (Phan, Henter, Miyao, Satoh. 2017)](https://arxiv.org/abs/1712.09532). 4 | 5 | ## Dependencies ### 6 | 7 | * Python 2.7 8 | * Pytorch 0.2 9 | * [Microsoft COCO Caption Evaluation](https://github.com/tylin/coco-caption) 10 | * [CIDEr](https://github.com/plsang/cider) 11 | 12 | (Check out the `coco-caption` and `cider` projects into your working directory) 13 | 14 | ## Data 15 | 16 | Data can be downloaded [here](https://drive.google.com/drive/folders/1t65uYsDck6VV045GIaJXPIqL86vSGtyQ?usp=sharing) (643 MB). 
This folder contains: 17 | * input/msrvtt: annotatated captions (note that `val_videodatainfo.json` is a symbolic link to `train_videodatainfo.json`) 18 | * output/feature: extracted features 19 | * output/model/cst_best: model file and generated captions on test videos of our best run (CIDEr 54.2) 20 | 21 | ## Getting started ### 22 | 23 | Extract video features 24 | - Extracted features of ResNet, C3D, MFCC and Category embeddings are shared in the above link 25 | 26 | Generate metadata 27 | 28 | ```bash 29 | make pre_process 30 | ``` 31 | 32 | Pre-compute document frequency for CIDEr computation 33 | ```bash 34 | make compute_ciderdf 35 | ``` 36 | 37 | Pre-compute evaluation scores (BLEU_4, CIDEr, METEOR, ROUGE_L) for each caption 38 | ```bash 39 | make compute_evalscores 40 | ``` 41 | 42 | ## Train/Test ### 43 | 44 | ```bash 45 | make train [options] 46 | make test [options] 47 | ``` 48 | 49 | Please refer to the Makefile (and opts.py file) for the set of available train/test options 50 | 51 | ## Examples 52 | 53 | Train XE model 54 | ```bash 55 | make train GID=0 EXP_NAME=xe FEATS="resnet c3d mfcc category" USE_RL=0 USE_CST=0 USE_MIXER=0 SCB_CAPTIONS=0 LOGLEVEL=DEBUG MAX_EPOCHS=50 56 | ``` 57 | 58 | Train CST_GT_None/WXE model 59 | 60 | ```bash 61 | make train GID=0 EXP_NAME=WXE FEATS="resnet c3d mfcc category" USE_RL=1 USE_CST=1 USE_MIXER=0 SCB_CAPTIONS=0 LOGLEVEL=DEBUG MAX_EPOCHS=50 62 | ``` 63 | 64 | Train CST_MS_Greedy model (using greedy baseline) 65 | 66 | ```bash 67 | make train GID=0 EXP_NAME=CST_MS_Greedy FEATS="resnet c3d mfcc category" USE_RL=1 USE_CST=0 SCB_CAPTIONS=0 USE_MIXER=1 MIXER_FROM=1 USE_EOS=1 LOGLEVEL=DEBUG MAX_EPOCHS=200 START_FROM=output/model/WXE 68 | ``` 69 | 70 | Train CST_MS_SCB model (using SCB baseline, where SCB is computed from GT captions) 71 | 72 | ``` 73 | make train GID=0 EXP_NAME=CST_MS_SCB FEATS="resnet c3d mfcc category" USE_RL=1 USE_CST=1 USE_MIXER=1 MIXER_FROM=1 SCB_BASELINE=1 SCB_CAPTIONS=20 USE_EOS=1 LOGLEVEL=DEBUG MAX_EPOCHS=200 START_FROM=output/model/WXE 74 | ``` 75 | 76 | Train CST_MS_SCB(*) model (using SCB baseline, where SCB is computed from model sampled captions) 77 | 78 | ``` 79 | make train GID=0 MODEL_TYPE=concat EXP_NAME=CST_MS_SCBSTAR FEATS="resnet c3d mfcc category" USE_RL=1 USE_CST=1 USE_MIXER=1 MIXER_FROM=1 SCB_BASELINE=2 SCB_CAPTIONS=20 USE_EOS=1 LOGLEVEL=DEBUG MAX_EPOCHS=200 START_FROM=output/model/WXE 80 | ``` 81 | 82 | If you want to change the input features, modify the `FEATS` variable in above commands. 
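To sanity-check a generated result file outside of the Makefile targets, you can run the coco-caption evaluation directly. The snippet below is only a minimal sketch (it mirrors what `utils.language_eval` does at test time); the file paths and the top-level `predictions` key of the result file are assumptions, so adjust them to your own output layout.

```python
# Minimal sketch: score a prediction file with the coco-caption toolkit.
# The result-file layout ('predictions' key) and the paths are assumed, not guaranteed.
import sys, json, tempfile

sys.path.append('coco-caption')
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

gold_file = 'output/metadata/msrvtt_test_cocofmt.json'    # produced by `make pre_process`
pred_file = 'output/model/cst_best/test_results.json'     # hypothetical result file

preds = json.load(open(pred_file)).get('predictions', [])  # [{'image_id': ..., 'caption': ...}, ...]
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
    json.dump(preds, f)   # loadRes expects a flat list of annotations
    res_file = f.name

coco = COCO(gold_file)
coco_res = coco.loadRes(res_file)
coco_eval = COCOEvalCap(coco, coco_res)
coco_eval.params['image_id'] = coco_res.getImgIds()
coco_eval.evaluate()
print({metric: round(score, 4) for metric, score in coco_eval.eval.items()})
```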
83 | 84 | ## Reference 85 | 86 | @article{cst_phan2017, 87 | author = {Sang Phan and Gustav Eje Henter and Yusuke Miyao and Shin'ichi Satoh}, 88 | title = {Consensus-based Sequence Training for Video Captioning}, 89 | journal = {ArXiv e-prints}, 90 | archivePrefix = "arXiv", 91 | eprint = {1712.09532}, 92 | year = {2017}, 93 | } 94 | 95 | ## Todo 96 | 97 | * Test on Youtube2Text dataset (different number of captions per video) 98 | 99 | ### Acknowledgements ### 100 | 101 | * Torch implementation of [NeuralTalk2](https://github.com/karpathy/neuraltalk2) 102 | * PyTorch implementation of Self-critical Sequence Training for Image Captioning [(SCST)](https://github.com/ruotianluo/self-critical.pytorch) 103 | * PyTorch Team 104 | -------------------------------------------------------------------------------- /create_sequencelabel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Read caption from a json file 3 | Save as h5 format 4 | """ 5 | 6 | import os 7 | import json 8 | import argparse 9 | import h5py 10 | import numpy as np 11 | import string 12 | from random import shuffle, seed 13 | 14 | import logging 15 | from datetime import datetime 16 | from build_vocab import __PAD_TOKEN, __UNK_TOKEN, __BOS_TOKEN, __EOS_TOKEN 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def encode_captions(videos, max_length, wtoi): 22 | """ 23 | encode all captions into one large array 24 | """ 25 | 26 | N = len(videos) 27 | M = sum(len(v['final_captions']) 28 | for v in videos) # total number of captions 29 | 30 | label_arrays = [] 31 | # note: these will be one-indexed 32 | label_start_ix = np.zeros(N, dtype=int) 33 | label_end_ix = np.zeros(N, dtype=int) 34 | label_length = np.zeros(M, dtype=int) 35 | label_to_video = np.zeros(M, dtype=int) 36 | counter = 0 37 | for i, v in enumerate(videos): 38 | n = len(v['final_captions']) 39 | assert n > 0, 'error: some image has no captions' 40 | 41 | # 0 is __PAD_TOKEN, implicitly 42 | Li = np.zeros((n, max_length), dtype=int) 43 | for j, s in enumerate(v['final_captions']): 44 | 45 | label_length[counter + j] = min(max_length, len(s)) 46 | label_to_video[counter + j] = i 47 | 48 | # truncated at max_length, thus the last token might be not the . 49 | # any problem with this? 50 | for k, w in enumerate(s): 51 | if k < max_length: 52 | Li[j, k] = wtoi[w] 53 | 54 | label_arrays.append(Li) 55 | label_start_ix[i] = counter 56 | label_end_ix[i] = counter + n 57 | 58 | counter += n 59 | 60 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 61 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 62 | # assert np.all(label_length > 0), 'error: some caption had no words?' 
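    # note: as written, counter starts at 0, so label_start_ix / label_end_ix end up as
    # 0-based [start:end) slice bounds into L; dataloader.get_batch uses them directly
    # as labels[ix1:ix2]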
63 | 64 | logger.info('encoded captions to array of size %s', repr(L.shape)) 65 | return L, label_start_ix, label_end_ix, label_length, label_to_video 66 | 67 | 68 | def main(vocab_json, captions_json, output_h5, max_length): 69 | 70 | # create the vocab 71 | vocab = json.load(open(vocab_json)) 72 | 73 | # inverse table 74 | wtoi = {w: i for i, w in enumerate(vocab)} 75 | 76 | videos = json.load(open(captions_json)) 77 | 78 | logger.info('Select tokens in the vocab only') 79 | for v in videos: 80 | v['final_captions'] = [] 81 | for txt in v['processed_tokens']: 82 | caption = [__BOS_TOKEN] 83 | caption += [w if w in wtoi else __UNK_TOKEN for w in txt] 84 | caption += [__EOS_TOKEN] 85 | v['final_captions'].append(caption) 86 | 87 | with h5py.File(output_h5, 'w') as of: 88 | if len(videos[0]['captions']) > 0: 89 | logger.info('Encoding captions...') 90 | L, label_start_ix, label_end_ix, label_length, label_to_video = encode_captions( 91 | videos, max_length, wtoi) 92 | 93 | of.create_dataset('labels', dtype=int, data=L) 94 | of.create_dataset('label_start_ix', dtype=int, data=label_start_ix) 95 | of.create_dataset('label_end_ix', dtype=int, data=label_end_ix) 96 | of.create_dataset('label_length', dtype=int, data=label_length) 97 | of.create_dataset('label_to_video', dtype=int, data=label_to_video) 98 | else: 99 | logger.info('Caption not found! Skipped encoding captions.') 100 | 101 | video_ids = [v['video_id'] for v in videos] 102 | of['videos'] = np.array(video_ids, dtype=np.string_) 103 | of['vocab'] = np.array(vocab, dtype=np.string_) 104 | 105 | logger.info('Wrote to %s', output_h5) 106 | 107 | ###################################################################### 108 | 109 | if __name__ == "__main__": 110 | logging.basicConfig(level=logging.DEBUG, 111 | format='%(asctime)s:%(levelname)s: %(message)s') 112 | parser = argparse.ArgumentParser() 113 | 114 | parser.add_argument('vocab_json', default='_vocab.json', 115 | help='vocab json file') 116 | parser.add_argument('captions_json', default='_proprocessedtokens', 117 | help='_proprocessedtokens json file') 118 | parser.add_argument( 119 | 'output_h5', 120 | default='_sequencelabel.h5', 121 | help='output h5 file') 122 | 123 | parser.add_argument( 124 | '--max_length', 125 | default=30, 126 | type=int, 127 | help='max length of a caption, in number of words. captions longer than this get clipped.') 128 | 129 | args = parser.parse_args() 130 | logger.info('Input parameters: %s', args) 131 | 132 | start = datetime.now() 133 | 134 | main(args.vocab_json, args.captions_json, args.output_h5, args.max_length) 135 | 136 | logger.info('Time: %s', datetime.now() - start) 137 | -------------------------------------------------------------------------------- /compute_ciderdf.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | https://github.com/ruotianluo/self-critical.pytorch/blob/master/scripts/prepro_ngrams.py 4 | """ 5 | 6 | import os 7 | import json 8 | import argparse 9 | from six.moves import cPickle 10 | from collections import defaultdict 11 | 12 | import logging 13 | from datetime import datetime 14 | from build_vocab import __PAD_TOKEN, __UNK_TOKEN, __BOS_TOKEN, __EOS_TOKEN, build_vocab 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def precook(s, n=4, out=False): 20 | """ 21 | Takes a string as input and returns an object that can be given to 22 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 23 | can take string arguments as well. 
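    (The out flag is accepted for interface compatibility but is not used in this function.)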
24 | :param s: string : sentence to be converted into ngrams 25 | :param n: int : number of ngrams for which representation is calculated 26 | :return: term frequency vector for occuring ngrams 27 | """ 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in xrange(1, n + 1): 31 | for i in xrange(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return counts 35 | 36 | 37 | def cook_refs(refs, n=4): # lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them. 41 | :param refs: list of string : reference sentences for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (list of dict) 44 | ''' 45 | return [precook(ref, n) for ref in refs] 46 | 47 | 48 | def create_crefs(refs): 49 | crefs = [] 50 | for ref in refs: 51 | # ref is a list of 5 captions 52 | crefs.append(cook_refs(ref)) 53 | return crefs 54 | 55 | 56 | def compute_doc_freq(crefs): 57 | ''' 58 | Compute term frequency for reference data. 59 | This will be used to compute idf (inverse document frequency later) 60 | The term frequency is stored in the object 61 | :return: None 62 | ''' 63 | document_frequency = defaultdict(float) 64 | for refs in crefs: 65 | # refs, k ref captions of one image 66 | for ngram in set([ngram for ref in refs for ( 67 | ngram, count) in ref.iteritems()]): 68 | document_frequency[ngram] += 1 69 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 70 | return document_frequency 71 | 72 | 73 | def build_dict(videos, wtoi): 74 | 75 | count_videos = 0 76 | 77 | refs_words = [] 78 | refs_idxs = [] 79 | for v in videos: 80 | ref_words = [] 81 | ref_idxs = [] 82 | for sent in v['final_captions']: 83 | ref_words.append(' '.join(sent)) 84 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in sent])) 85 | refs_words.append(ref_words) 86 | refs_idxs.append(ref_idxs) 87 | count_videos += 1 88 | 89 | logger.info('total videos: %d', count_videos) 90 | 91 | ngram_words = compute_doc_freq(create_crefs(refs_words)) 92 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs)) 93 | return ngram_words, ngram_idxs, count_videos 94 | 95 | 96 | def main(vocab_json, captions_json, output_pkl, save_words=False): 97 | 98 | logger.info('Loading: %s', captions_json) 99 | videos = json.load(open(captions_json)) 100 | 101 | if vocab_json and os.path.isfile(vocab_json): 102 | logger.info('Loading vocab: %s', vocab_json) 103 | vocab = json.load(open(vocab_json)) 104 | else: 105 | logger.info('Selecting all word to form the vocab') 106 | vocab = build_vocab(videos, 0) 107 | 108 | # inverse table 109 | wtoi = {w: i for i, w in enumerate(vocab)} 110 | 111 | logger.info('Select tokens in the vocab only') 112 | for v in videos: 113 | v['final_captions'] = [] 114 | for txt in v['processed_tokens']: 115 | caption = [__BOS_TOKEN] 116 | caption = [w if w in wtoi else __UNK_TOKEN for w in txt] 117 | caption += [__EOS_TOKEN] 118 | v['final_captions'].append(caption) 119 | 120 | ngram_words, ngram_idxs, ref_len = build_dict(videos, wtoi) 121 | 122 | logger.info('Saving index to: %s', output_pkl) 123 | cPickle.dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open( 124 | output_pkl, 'w'), protocol=cPickle.HIGHEST_PROTOCOL) 125 | 126 | if save_words: 127 | output_file = output_pkl.replace('.pkl', '_words.pkl', 1) 128 | logger.info('Saving word to: %s', output_file) 129 | 
cPickle.dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open( 130 | output_file, 'w'), protocol=cPickle.HIGHEST_PROTOCOL) 131 | 132 | if __name__ == "__main__": 133 | logging.basicConfig(level=logging.DEBUG, 134 | format='%(asctime)s:%(levelname)s: %(message)s') 135 | parser = argparse.ArgumentParser() 136 | 137 | # input json 138 | parser.add_argument('captions_json', default='_proprocessedtokens', 139 | help='_proprocessedtokens json file') 140 | parser.add_argument( 141 | 'output_pkl', 142 | default='_pkl', 143 | help='save idx frequencies') 144 | 145 | parser.add_argument( 146 | '--output_words', 147 | action='store_true', 148 | help='optionally saving word frequencies') 149 | 150 | parser.add_argument('--vocab_json', default=None, 151 | help='vocab json file') 152 | 153 | args = parser.parse_args() 154 | 155 | start = datetime.now() 156 | 157 | main( 158 | args.vocab_json, 159 | args.captions_json, 160 | args.output_pkl, 161 | save_words=args.output_words) 162 | 163 | logger.info('Time: %s', datetime.now() - start) 164 | -------------------------------------------------------------------------------- /standalize_format.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert MSRVTT format to standard JSON 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | import string 9 | import itertools 10 | 11 | import logging 12 | from datetime import datetime 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def standalize_yt2t(input_file): 18 | """ 19 | Use data splits provided by the NAACL2015 paper 20 | Ref: 21 | """ 22 | logger.info('Reading file: %s', input_file) 23 | lines = [line.rstrip('\n') for line in open(input_file)] 24 | lines = [line.split('\t') for line in lines] 25 | 26 | logger.info('Building caption dictionary for each video key') 27 | video_ids = [] 28 | capdict = {} 29 | for line in lines: 30 | video_id = line[0] 31 | if video_id in capdict: 32 | capdict[video_id].append(line[1]) 33 | else: 34 | capdict[video_id] = [line[1]] 35 | video_ids.append(video_id) 36 | 37 | # create the json blob 38 | videos = [] 39 | captions = [] 40 | counter = itertools.count() 41 | for video_id in video_ids: 42 | 43 | vid = int(video_id[3:]) 44 | jvid = {} 45 | jvid['category'] = 'unknown' 46 | jvid['video_id'] = video_id 47 | jvid['id'] = vid 48 | jvid['start_time'] = -1 49 | jvid['end_time'] = -1 50 | jvid['url'] = '' 51 | videos.append(jvid) 52 | 53 | for caption in capdict[video_id]: 54 | jcap = {} 55 | jcap['id'] = next(counter) 56 | jcap['video_id'] = vid 57 | jcap['caption'] = unicode(caption, errors='ignore') 58 | captions.append(jcap) 59 | 60 | out = {} 61 | out['info'] = {} 62 | out['videos'] = videos 63 | out['captions'] = captions 64 | 65 | return out 66 | 67 | 68 | def standalize_msrvtt( 69 | input_file, 70 | dataset='msrvtt2016', 71 | split='train', 72 | val2016_json=None): 73 | """ 74 | Supports both msrvtt2016 and msrvtt2017 75 | There is no official train/val set in MSRVTT2017: 76 | -> train2017 = train2016 + test2016 77 | -> val2017 = val2016 78 | """ 79 | info = json.load(open(args.input_file)) 80 | 81 | if split == 'val': 82 | split = 'validate' 83 | 84 | out = {} 85 | out['info'] = info['info'] 86 | 87 | if args.dataset == 'msrvtt2017' and split == 'train': 88 | # loading all training videos and removing those that are in the 89 | # val2016 set 90 | logger.info('Loading val2016 info: %s', val2016_json) 91 | info2016 = json.load(open(val2016_json)) 92 | val2016_videos = [v for v in 
info2016[ 93 | 'videos'] if v['split'] == 'validate'] 94 | 95 | val2016_video_dict = {v['video_id']: v['id'] for v in val2016_videos} 96 | out['videos'] = [v for v in info['videos'] 97 | if v['video_id'] not in val2016_video_dict] 98 | 99 | else: 100 | out['videos'] = [v for v in info['videos'] if v['split'] == split] 101 | 102 | tmp_dict = {v['video_id']: v['id'] for v in out['videos']} 103 | out['captions'] = [{'id': c['sen_id'], 'video_id': tmp_dict[c['video_id']], 'caption': c[ 104 | 'caption']} for c in info['sentences'] if c['video_id'] in tmp_dict] 105 | 106 | return out 107 | 108 | 109 | def standalize_tvvtt(input_file, split='train2016'): 110 | """ 111 | Standalize TRECVID V2T task 112 | Read from a metadata file generated in the v2t2017 project 113 | Basically there is no split in the v2t dataset, 114 | Just consider each provided set as an independent dataset 115 | """ 116 | split_mapping = { 117 | 'train': 'train2016', 118 | 'val': 'test2016', 119 | 'test': 'test2017' 120 | } 121 | 122 | split = split_mapping[split] 123 | logger.info('Loading file: %s, split: %s', input_file, split) 124 | info = json.load(open(input_file))[split] 125 | 126 | out = {} 127 | out['info'] = {} 128 | videos = [] 129 | for v in info['videos']: 130 | jvid = {} 131 | jvid['category'] = 'unknown' 132 | jvid['video_id'] = str(v) 133 | jvid['id'] = v 134 | jvid['start_time'] = -1 135 | jvid['end_time'] = -1 136 | jvid['url'] = '' 137 | videos.append(jvid) 138 | out['videos'] = videos 139 | out['captions'] = info['captions'] 140 | 141 | return out 142 | 143 | if __name__ == "__main__": 144 | logging.basicConfig(level=logging.DEBUG, 145 | format='%(asctime)s:%(levelname)s: %(message)s') 146 | parser = argparse.ArgumentParser() 147 | 148 | parser.add_argument('input_file', type=str, help='') 149 | parser.add_argument('output_json', type=str, help='') 150 | parser.add_argument('--split', type=str, help='') 151 | parser.add_argument( 152 | '--dataset', 153 | type=str, 154 | default='yt2t', 155 | choices=[ 156 | 'yt2t', 157 | 'msrvtt2016', 158 | 'msrvtt2017', 159 | 'tvvtt'], 160 | help='Choose dataset') 161 | parser.add_argument( 162 | '--val2016_json', 163 | type=str, 164 | help='use valset from msrvtt2016 contest') 165 | 166 | args = parser.parse_args() 167 | logger.info('Input arguments: %s', args) 168 | 169 | start = datetime.now() 170 | 171 | if args.dataset == 'msrvtt2016': 172 | out = standalize_msrvtt( 173 | args.input_file, 174 | dataset=args.dataset, 175 | split=args.split) 176 | elif args.dataset == 'msrvtt2017': 177 | out = standalize_msrvtt( 178 | args.input_file, 179 | dataset=args.dataset, 180 | split=args.split, 181 | val2016_json=args.val2016_json) 182 | elif args.dataset == 'yt2t': 183 | out = standalize_yt2t(args.input_file) 184 | elif args.dataset == 'tvvtt': 185 | out = standalize_tvvtt(args.input_file, split=args.split) 186 | else: 187 | raise ValueError('Unknow dataset: %s', args.dataset) 188 | 189 | if not os.path.exists(os.path.dirname(args.output_json)): 190 | os.makedirs(os.path.dirname(args.output_json)) 191 | 192 | with open(args.output_json, 'w') as of: 193 | json.dump(out, of) 194 | 195 | logger.info('Time: %s', datetime.now() - start) 196 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | ### Directory Setting 3 | IN_DIR=input 4 | OUT_DIR=output 5 | META_DIR=$(OUT_DIR)/metadata 6 | FEAT_DIR=$(OUT_DIR)/feature 7 | MODEL_DIR=$(OUT_DIR)/model 8 | 9 | 
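# Per-dataset input directories (raw annotation files from the download link in the README)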
MSRVTT2016_DIR=$(IN_DIR)/msrvtt 10 | MSRVTT2017_DIR=$(IN_DIR)/msrvtt2017 11 | YT2T_DIR=$(IN_DIR)/yt2t 12 | 13 | SPLITS=train val test 14 | DATASETS=msrvtt# yt2t msrvtt2017 tvvtt 15 | 16 | WORD_COUNT_THRESHOLD?=3 # in output/metadata this threshold was 0; was 3 in output/metadata2017 17 | MAX_SEQ_LEN?=30 # in output/metadata seqlen was 20; was 30 in output/metadata2017 18 | 19 | GID?=5 20 | 21 | DATASET?=msrvtt 22 | TRAIN_DATASET?=$(DATASET) 23 | VAL_DATASET?=$(DATASET) 24 | TEST_DATASET?=$(DATASET) 25 | TRAIN_SPLIT?=train 26 | VAL_SPLIT?=val 27 | TEST_SPLIT?=test 28 | 29 | LEARNING_RATE?=0.0001 30 | LR_UPDATE?=200 31 | BATCH_SIZE?=64 32 | TRAIN_SEQ_PER_IMG?=20 33 | TEST_SEQ_PER_IMG?=20 34 | RNN_SIZE?=512 35 | 36 | PRINT_INTERVAL?=20 37 | MAX_PATIENCE?=50 # FOR EARLY STOPPING 38 | SAVE_CHECKPOINT_FROM?=1 39 | 40 | MAX_EPOCHS?=200 41 | NUM_CHUNKS?=1 42 | PRINT_ATT_COEF?=0 43 | BEAM_SIZE?=5 44 | 45 | TODAY=20170831 46 | EXP_NAME?=exp_$(DATASET)_$(TODAY) 47 | VAL_LANG_EVAL?=1 48 | TEST_LANG_EVAL?=1 49 | EVAL_METRIC?=CIDEr 50 | START_FROM?=No 51 | MODEL_TYPE?=concat 52 | POOLING?=mp 53 | CAT_TYPE=glove 54 | LOGLEVEL?=INFO 55 | 56 | SS_MAX_PROB?=0.25 57 | USE_CST?=0 58 | SCB_CAPTIONS?=20 59 | SCB_BASELINE?=1 60 | USE_RL?=0 61 | USE_RL_AFTER?=0 62 | USE_EOS?=0 63 | USE_MIXER?=0 64 | MIXER_FROM?=-1 65 | SS_K?=100 66 | 67 | 68 | FEAT1?=resnet 69 | FEAT2?=c3d 70 | FEAT3?=mfcc 71 | FEAT4?=category 72 | 73 | FEATS?=$(FEAT1) $(FEAT2) $(FEAT3) $(FEAT4) 74 | 75 | TRAIN_ID=$(TRAIN_DATASET)_$(MODEL_TYPE)_$(EVAL_METRIC)_$(BATCH_SIZE)_$(LEARNING_RATE) 76 | 77 | ################################################################################################################### 78 | ### 79 | pre_process: standalize_datainfo preprocess_datainfo build_vocab create_sequencelabel convert_datainfo2cocofmt 80 | 81 | ### Standalize data 82 | standalize_datainfo: $(foreach d,$(DATASETS),$(patsubst %,$(META_DIR)/$(d)_%_datainfo.json,$(SPLITS))) 83 | $(META_DIR)/msrvtt_%_datainfo.json: $(MSRVTT2016_DIR)/%_videodatainfo.json 84 | python standalize_format.py $^ $@ --dataset msrvtt2016 --split $* 85 | $(META_DIR)/msrvtt2017_%_datainfo.json: $(MSRVTT2017_DIR)/msrvtt2017_%_videodatainfo.json 86 | python standalize_format.py $^ $@ --dataset msrvtt2017 --split $* \ 87 | --val2016_json $(MSRVTT2016_DIR)/val_videodatainfo.json 88 | $(META_DIR)/yt2t_%_datainfo.json: $(YT2T_DIR)/naacl15/sents_%_lc_nopunc.txt 89 | python standalize_format.py $^ $@ --dataset yt2t 90 | $(META_DIR)/tvvtt_%_datainfo.json: $(META_DIR)/v2t2017_infos.json 91 | python standalize_format.py $^ $@ --dataset tvvtt --split $* 92 | ### 93 | preprocess_datainfo: $(foreach s,$(SPLITS),$(patsubst %,$(META_DIR)/%_$(s)_proprocessedtokens.json,$(DATASETS))) 94 | %_proprocessedtokens.json: %_datainfo.json 95 | python preprocess_datainfo.py $^ $@ 96 | 97 | ### 98 | build_vocab: $(patsubst %,$(META_DIR)/%_train_vocab.json,$(DATASETS)) 99 | %_train_vocab.json: %_train_proprocessedtokens.json 100 | python build_vocab.py $< $@ --word_count_threshold $(WORD_COUNT_THRESHOLD) 101 | ### 102 | create_sequencelabel: $(foreach s,$(SPLITS),$(patsubst %,$(META_DIR)/%_$(s)_sequencelabel.h5,$(DATASETS))) 103 | .SECONDEXPANSION: 104 | %_sequencelabel.h5: $$(firstword $$(subst _, ,$$@))_train_vocab.json %_proprocessedtokens.json 105 | python create_sequencelabel.py $^ $@ --max_length $(MAX_SEQ_LEN) 106 | 107 | ### Convert standalized datainfo to coco format for language evaluation 108 | convert_datainfo2cocofmt: $(foreach s,$(SPLITS),$(patsubst 
%,$(META_DIR)/%_$(s)_cocofmt.json,$(DATASETS))) 109 | %_cocofmt.json: %_datainfo.json 110 | python convert_datainfo2cocofmt.py $< $@ 111 | 112 | ### pre-compute document frequency for computing CIDEr of on model samples 113 | compute_ciderdf: $(foreach s,$(SPLITS),$(patsubst %,$(META_DIR)/%_$(s)_ciderdf.pkl,$(DATASETS))) 114 | %_ciderdf.pkl: %_proprocessedtokens.json 115 | python compute_ciderdf.py $^ $@ --output_words --vocab_json $(firstword $(subst _, ,$@))_train_vocab.json 116 | 117 | ### pre-compute evaluation scores (BLEU_4, CIDEr, METEOR, ROUGE_L) 118 | compute_evalscores: $(patsubst %,$(META_DIR)/$(TRAIN_DATASET)_%_evalscores.pkl,$(SPLITS)) 119 | %_evalscores.pkl: %_cocofmt.json 120 | python compute_scores.py $^ $@ --remove_in_ref 121 | 122 | ##################################################################################################################### 123 | 124 | noop= 125 | space=$(noop) $(noop) 126 | 127 | TRAIN_OPT=--beam_size $(BEAM_SIZE) --max_patience $(MAX_PATIENCE) --eval_metric $(EVAL_METRIC) --print_log_interval $(PRINT_INTERVAL)\ 128 | --language_eval $(VAL_LANG_EVAL) --max_epochs $(MAX_EPOCHS) --rnn_size $(RNN_SIZE) \ 129 | --train_seq_per_img $(TRAIN_SEQ_PER_IMG) --test_seq_per_img $(TEST_SEQ_PER_IMG) \ 130 | --batch_size $(BATCH_SIZE) --test_batch_size $(BATCH_SIZE) --learning_rate $(LEARNING_RATE) --lr_update $(LR_UPDATE) \ 131 | --save_checkpoint_from $(SAVE_CHECKPOINT_FROM) --num_chunks $(NUM_CHUNKS) \ 132 | --train_cached_tokens $(META_DIR)/$(TRAIN_DATASET)_train_ciderdf.pkl \ 133 | --ss_k $(SS_K) --use_rl_after $(USE_RL_AFTER) --ss_max_prob $(SS_MAX_PROB) \ 134 | --use_rl $(USE_RL) --use_mixer $(USE_MIXER) --mixer_from $(MIXER_FROM) \ 135 | --use_cst $(USE_CST) --scb_captions $(SCB_CAPTIONS) --scb_baseline $(SCB_BASELINE) \ 136 | --loglevel $(LOGLEVEL) --model_type $(MODEL_TYPE) --use_eos $(USE_EOS) \ 137 | --model_file $@ --start_from $(START_FROM) --result_file $(basename $@)_test.json \ 138 | 2>&1 | tee $(basename $@).log 139 | 140 | TEST_OPT=--beam_size $(BEAM_SIZE) \ 141 | --language_eval $(VAL_LANG_EVAL) \ 142 | --test_seq_per_img $(TEST_SEQ_PER_IMG) \ 143 | --test_batch_size $(BATCH_SIZE) \ 144 | --loglevel $(LOGLEVEL) \ 145 | --result_file $@ 146 | 147 | train: $(MODEL_DIR)/$(EXP_NAME)/$(subst $(space),$(noop),$(FEATS))_$(TRAIN_ID).pth 148 | $(MODEL_DIR)/$(EXP_NAME)/$(subst $(space),$(noop),$(FEATS))_$(TRAIN_ID).pth: \ 149 | $(META_DIR)/$(TRAIN_DATASET)_$(TRAIN_SPLIT)_sequencelabel.h5 \ 150 | $(META_DIR)/$(VAL_DATASET)_$(VAL_SPLIT)_sequencelabel.h5 \ 151 | $(META_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_sequencelabel.h5 \ 152 | $(META_DIR)/$(TRAIN_DATASET)_$(TRAIN_SPLIT)_cocofmt.json \ 153 | $(META_DIR)/$(VAL_DATASET)_$(VAL_SPLIT)_cocofmt.json \ 154 | $(META_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_cocofmt.json \ 155 | $(META_DIR)/$(TRAIN_DATASET)_$(TRAIN_SPLIT)_evalscores.pkl \ 156 | $(patsubst %,$(FEAT_DIR)/$(TRAIN_DATASET)_$(TRAIN_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS)) \ 157 | $(patsubst %,$(FEAT_DIR)/$(VAL_DATASET)_$(VAL_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS)) \ 158 | $(patsubst %,$(FEAT_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS)) 159 | mkdir -p $(MODEL_DIR)/$(EXP_NAME) 160 | CUDA_VISIBLE_DEVICES=$(GID) python train.py \ 161 | --train_label_h5 $(word 1,$^) \ 162 | --val_label_h5 $(word 2,$^) \ 163 | --test_label_h5 $(word 3,$^) \ 164 | --train_cocofmt_file $(word 4,$^) \ 165 | --val_cocofmt_file $(word 5,$^) \ 166 | --test_cocofmt_file $(word 6,$^) \ 167 | --train_bcmrscores_pkl $(word 7,$^) \ 168 | --train_feat_h5 $(patsubst 
%,$(FEAT_DIR)/$(TRAIN_DATASET)_$(TRAIN_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS))\ 169 | --val_feat_h5 $(patsubst %,$(FEAT_DIR)/$(VAL_DATASET)_$(VAL_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS))\ 170 | --test_feat_h5 $(patsubst %,$(FEAT_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS))\ 171 | $(TRAIN_OPT) 172 | 173 | test: $(MODEL_DIR)/$(EXP_NAME)/$(subst $(space),$(noop),$(FEATS))_$(TRAIN_ID)_test.json 174 | $(MODEL_DIR)/$(EXP_NAME)/$(subst $(space),$(noop),$(FEATS))_$(TRAIN_ID)_test.json: \ 175 | $(MODEL_DIR)/$(EXP_NAME)/$(subst $(space),$(noop),$(FEATS))_$(TRAIN_ID).pth \ 176 | $(META_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_sequencelabel.h5 \ 177 | $(META_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_cocofmt.json \ 178 | $(patsubst %,$(FEAT_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS)) 179 | CUDA_VISIBLE_DEVICES=$(GID) python test.py \ 180 | --model_file $(word 1,$^) \ 181 | --test_label_h5 $(word 2,$^) \ 182 | --test_cocofmt_file $(word 3,$^) \ 183 | --test_feat_h5 $(patsubst %,$(FEAT_DIR)/$(TEST_DATASET)_$(TEST_SPLIT)_%_mp$(NUM_CHUNKS).h5,$(FEATS))\ 184 | $(TEST_OPT) 185 | 186 | 187 | # You can use the wildcard with .PRECIOUS. 188 | .PRECIOUS: %.pth 189 | 190 | # If you want all intermediates to remain 191 | .SECONDARY: 192 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import json 5 | import h5py 6 | import os 7 | import numpy as np 8 | import random 9 | import time 10 | import cPickle 11 | 12 | import logging 13 | from datetime import datetime 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class DataLoader(): 18 | 19 | """Class to load video features and captions""" 20 | 21 | def __init__(self, opt): 22 | 23 | self.iterator = 0 24 | self.epoch = 0 25 | 26 | self.batch_size = opt.get('batch_size', 128) 27 | self.seq_per_img = opt.get('seq_per_img', 1) 28 | self.word_embedding_size = opt.get('word_embedding_size', 512) 29 | self.num_chunks = opt.get('num_chunks', 1) 30 | self.mode = opt.get('mode', 'train') 31 | self.cocofmt_file = opt.get('cocofmt_file', None) 32 | self.bcmrscores_pkl = opt.get('bcmrscores_pkl', None) 33 | 34 | # open the hdf5 info file 35 | logger.info('DataLoader loading h5 file: %s', opt['label_h5']) 36 | self.label_h5 = h5py.File(opt['label_h5'], 'r') 37 | 38 | self.vocab = [i for i in self.label_h5['vocab']] 39 | self.videos = [i for i in self.label_h5['videos']] 40 | 41 | self.ix_to_word = {i: w for i, w in enumerate(self.vocab)} 42 | self.num_videos = len(self.videos) 43 | self.index = range(self.num_videos) 44 | 45 | # load the json file which contains additional information about the 46 | # dataset 47 | feat_h5_files = opt['feat_h5'] 48 | logger.info('DataLoader loading h5 files: %s', feat_h5_files) 49 | self.feat_h5 = [] 50 | self.feat_dims = [] 51 | for ii, feat_h5_file in enumerate(feat_h5_files): 52 | self.feat_h5.append(h5py.File(feat_h5_files[ii], 'r')) 53 | self.feat_dims.append(self.feat_h5[ii][self.videos[0]].shape[0]) 54 | 55 | self.num_feats = len(feat_h5_files) 56 | 57 | # load in the sequence data 58 | if 'labels' in self.label_h5.keys(): 59 | self.seq_length = self.label_h5['labels'].shape[1] 60 | logger.info('max sequence length in data is: %d', self.seq_length) 61 | 62 | # load the pointers in full to RAM (should be small enough) 63 | self.label_start_ix = self.label_h5['label_start_ix'] 64 | self.label_end_ix = 
self.label_h5['label_end_ix'] 65 | assert(self.label_start_ix.shape[0] == self.label_end_ix.shape[0]) 66 | self.has_label = True 67 | else: 68 | self.has_label = False 69 | 70 | if self.bcmrscores_pkl is not None: 71 | eval_metric = opt.get('eval_metric', 'CIDEr') 72 | logger.info('Loading: %s, with metric: %s', self.bcmrscores_pkl, eval_metric) 73 | self.bcmrscores = cPickle.load(open(self.bcmrscores_pkl)) 74 | if eval_metric == 'CIDEr' and eval_metric not in self.bcmrscores: 75 | eval_metric = 'cider' 76 | self.bcmrscores = self.bcmrscores[eval_metric] 77 | 78 | if self.mode == 'train': 79 | self.shuffle_videos() 80 | 81 | def __del__(self): 82 | for f in self.feat_h5: 83 | f.close() 84 | self.label_h5.close() 85 | 86 | def get_batch(self): 87 | 88 | video_batch = [] 89 | for dim in self.feat_dims: 90 | feat = torch.FloatTensor( 91 | self.batch_size, self.num_chunks, dim).zero_() 92 | video_batch.append(feat) 93 | 94 | if self.has_label: 95 | label_batch = torch.LongTensor( 96 | self.batch_size * self.seq_per_img, 97 | self.seq_length).zero_() 98 | mask_batch = torch.FloatTensor( 99 | self.batch_size * self.seq_per_img, 100 | self.seq_length).zero_() 101 | 102 | videoids_batch = [] 103 | gts = [] 104 | bcmrscores = np.zeros((self.batch_size, self.seq_per_img)) if self.bcmrscores_pkl is not None else None 105 | 106 | for ii in range(self.batch_size): 107 | idx = self.index[self.iterator] 108 | video_id = int(self.videos[idx]) 109 | videoids_batch.append(video_id) 110 | 111 | for jj in range(self.num_feats): 112 | video_batch[jj][ii] = torch.from_numpy( 113 | np.array(self.feat_h5[jj][str(video_id)])) 114 | 115 | if self.has_label: 116 | # fetch the sequence labels 117 | ix1 = self.label_start_ix[idx] 118 | ix2 = self.label_end_ix[idx] 119 | ncap = ix2 - ix1 # number of captions available for this image 120 | assert ncap > 0, 'No captions!!' 
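                # sample exactly seq_per_img captions for this video: keep all available
                # captions and pad by re-drawing with replacement when there are fewer than
                # seq_per_img, otherwise take a random subset without replacement (randperm)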
121 | 122 | seq = torch.LongTensor( 123 | self.seq_per_img, self.seq_length).zero_() 124 | seq_all = torch.from_numpy( 125 | np.array(self.label_h5['labels'][ix1:ix2])) 126 | 127 | if ncap <= self.seq_per_img: 128 | seq[:ncap] = seq_all[:ncap] 129 | for q in range(ncap, self.seq_per_img): 130 | ixl = np.random.randint(ncap) 131 | seq[q] = seq_all[ixl] 132 | else: 133 | randpos = torch.randperm(ncap) 134 | for q in range(self.seq_per_img): 135 | ixl = randpos[q] 136 | seq[q] = seq_all[ixl] 137 | 138 | il = ii * self.seq_per_img 139 | label_batch[il:il + self.seq_per_img] = seq 140 | 141 | # Used for reward evaluation 142 | gts.append( 143 | self.label_h5['labels'][ 144 | self.label_start_ix[idx]: self.label_end_ix[idx]]) 145 | 146 | # pre-computed cider scores, 147 | # assuming now that videos order are same (which is the sorted videos order) 148 | if self.bcmrscores_pkl is not None: 149 | bcmrscores[ii] = self.bcmrscores[idx] 150 | 151 | self.iterator += 1 152 | if self.iterator >= self.num_videos: 153 | logger.info('===> Finished loading epoch %d', self.epoch) 154 | self.iterator = 0 155 | self.epoch += 1 156 | if self.mode == 'train': 157 | self.shuffle_videos() 158 | 159 | data = {} 160 | data['feats'] = video_batch 161 | data['ids'] = videoids_batch 162 | 163 | if self.has_label: 164 | # + 1 here to count the token, because the token is set to 0 165 | nonzeros = np.array( 166 | list(map(lambda x: (x != 0).sum() + 1, label_batch))) 167 | for ix, row in enumerate(mask_batch): 168 | row[:nonzeros[ix]] = 1 169 | 170 | data['labels'] = label_batch 171 | data['masks'] = mask_batch 172 | data['gts'] = gts 173 | data['bcmrscores'] = bcmrscores 174 | 175 | return data 176 | 177 | def reset(self): 178 | self.iterator = 0 179 | 180 | def get_current_index(self): 181 | return self.iterator 182 | 183 | def set_current_index(self, index): 184 | self.iterator = index 185 | 186 | def get_vocab(self): 187 | return self.ix_to_word 188 | 189 | def get_vocab_size(self): 190 | return len(self.vocab) 191 | 192 | def get_feat_dims(self): 193 | return self.feat_dims 194 | 195 | def get_feat_size(self): 196 | return sum(self.feat_dims) 197 | 198 | def get_num_feats(self): 199 | return self.num_feats 200 | 201 | def get_seq_length(self): 202 | return self.seq_length 203 | 204 | def get_seq_per_img(self): 205 | return self.seq_per_img 206 | 207 | def get_num_videos(self): 208 | return self.num_videos 209 | 210 | def get_batch_size(self): 211 | return self.batch_size 212 | 213 | def get_current_epoch(self): 214 | return self.epoch 215 | 216 | def set_current_epoch(self, epoch): 217 | self.epoch = epoch 218 | 219 | def shuffle_videos(self): 220 | np.random.shuffle(self.index) 221 | 222 | def get_cocofmt_file(self): 223 | return self.cocofmt_file 224 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | 5 | import numpy as np 6 | from collections import OrderedDict 7 | 8 | sys.path.append("cider") 9 | from pyciderevalcap.ciderD.ciderD import CiderD 10 | 11 | sys.path.append('coco-caption') 12 | from pycocotools.coco import COCO 13 | from pycocoevalcap.eval import COCOEvalCap 14 | 15 | from pycocoevalcap.bleu.bleu import Bleu 16 | from pycocoevalcap.rouge.rouge import Rouge 17 | from pycocoevalcap.meteor.meteor import Meteor 18 | 19 | import cPickle 20 | 21 | 22 | def adjust_learning_rate(opt, optimizer, epoch): 23 | """Sets the learning rate to the 
initial LR 24 | decayed by 10 every [lr_update] epochs""" 25 | lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update)) 26 | for param_group in optimizer.param_groups: 27 | param_group['lr'] = lr 28 | return lr 29 | 30 | 31 | def score(ref, hypo): 32 | """ 33 | ref, dictionary of reference sentences (id, sentence) 34 | hypo, dictionary of hypothesis sentences (id, sentence) 35 | score, dictionary of scores 36 | """ 37 | scorers = [ 38 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 39 | (Meteor(), "METEOR"), 40 | (Rouge(), "ROUGE_L"), 41 | (Cider(), "CIDEr") 42 | ] 43 | final_scores = {} 44 | for scorer, method in scorers: 45 | score, scores = scorer.compute_score(ref, hypo) 46 | if isinstance(score, list): 47 | for m, s in zip(method, score): 48 | final_scores[m] = s 49 | else: 50 | final_scores[method] = score 51 | return final_scores 52 | 53 | 54 | def load_gt_refs(cocofmt_file): 55 | d = json.load(open(cocofmt_file)) 56 | out = {} 57 | for i in d['annotations']: 58 | out.setdefault(i['image_id'], []).append(i['caption']) 59 | return out 60 | 61 | 62 | def compute_score(gt_refs, predictions, scorer): 63 | # use with standard package https://github.com/tylin/coco-caption 64 | # hypo = {p['image_id']: [p['caption']] for p in predictions} 65 | 66 | # use with Cider provided by https://github.com/ruotianluo/cider 67 | hypo = [{'image_id': p['image_id'], 'caption':[p['caption']]} 68 | for p in predictions] 69 | 70 | # standard package requires ref and hypo have same keys, i.e., ref.keys() 71 | # == hypo.keys() 72 | ref = {p['image_id']: gt_refs[p['image_id']] for p in predictions} 73 | 74 | score, scores = scorer.compute_score(ref, hypo) 75 | 76 | return score, scores 77 | 78 | 79 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 80 | def decode_sequence(ix_to_word, seq): 81 | N, D = seq.size() 82 | out = [] 83 | for i in range(N): 84 | txt = '' 85 | for j in range(D): 86 | ix = seq[i, j] 87 | if ix > 0: 88 | if j >= 1: 89 | txt = txt + ' ' 90 | txt = txt + ix_to_word[ix] 91 | else: 92 | break 93 | out.append(txt) 94 | return out 95 | 96 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 
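# compute_avglogp returns, for each sequence, the mean log-probability of its sampled
# tokens up to and including the first EOS token.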
97 | def compute_avglogp(seq, logseq, eos_token=0): 98 | seq = seq.cpu().numpy() 99 | logseq = logseq.cpu().numpy() 100 | 101 | N, D = seq.shape 102 | out_avglogp = [] 103 | for i in range(N): 104 | avglogp = [] 105 | for j in range(D): 106 | ix = seq[i, j] 107 | avglogp.append(logseq[i, j]) 108 | if ix == eos_token: 109 | break 110 | avg = 0 if len(avglogp) == 0 else sum(avglogp)/float(len(avglogp)) 111 | out_avglogp.append(avg) 112 | return out_avglogp 113 | 114 | def language_eval(gold_file, pred_file): 115 | 116 | # save the current stdout 117 | temp = sys.stdout 118 | sys.stdout = open(os.devnull, 'w') 119 | 120 | coco = COCO(gold_file) 121 | cocoRes = coco.loadRes(pred_file) 122 | cocoEval = COCOEvalCap(coco, cocoRes) 123 | cocoEval.params['image_id'] = cocoRes.getImgIds() 124 | cocoEval.evaluate() 125 | 126 | out = {} 127 | for metric, score in cocoEval.eval.items(): 128 | out[metric] = round(score, 5) 129 | 130 | # restore the previous stdout 131 | sys.stdout = temp 132 | return out 133 | 134 | 135 | def array_to_str(arr, use_eos=0): 136 | out = '' 137 | for i in range(len(arr)): 138 | if use_eos == 0 and arr[i] == 0: 139 | break 140 | 141 | # skip the token 142 | if arr[i] == 1: 143 | continue 144 | 145 | out += str(arr[i]) + ' ' 146 | 147 | # return if encouters the token 148 | # this will also guarantees that the first will be rewarded 149 | if arr[i] == 0: 150 | break 151 | 152 | return out.strip() 153 | 154 | 155 | def get_self_critical_reward2(model_res, greedy_res, gt_refs, scorer): 156 | 157 | model_score, model_scores = compute_score(model_res, gt_refs, scorer) 158 | greedy_score, greedy_scores = compute_score(greedy_res, gt_refs, scorer) 159 | scores = model_scores - greedy_scores 160 | 161 | m_score = np.mean(model_scores) 162 | g_score = np.mean(greedy_scores) 163 | 164 | #rewards = np.repeat(scores[:, np.newaxis], model_res.shape[1], 1) 165 | 166 | return m_score, g_score 167 | 168 | 169 | def get_self_critical_reward( 170 | model_res, 171 | greedy_res, 172 | data_gts, 173 | bcmr_scorer, 174 | expand_feat=0, 175 | seq_per_img=20, 176 | use_eos=0): 177 | 178 | batch_size = model_res.size(0) 179 | 180 | model_res = model_res.cpu().numpy() 181 | greedy_res = greedy_res.cpu().numpy() 182 | 183 | res = OrderedDict() 184 | for i in range(batch_size): 185 | res[i] = [array_to_str(model_res[i], use_eos)] 186 | for i in range(batch_size): 187 | res[batch_size + i] = [array_to_str(greedy_res[i], use_eos)] 188 | 189 | gts = OrderedDict() 190 | for i in range(len(data_gts)): 191 | gts[i] = [array_to_str(data_gts[i][j], use_eos) 192 | for j in range(len(data_gts[i]))] 193 | 194 | #_, scores = Bleu(4).compute_score(gts, res) 195 | #scores = np.array(scores[3]) 196 | if isinstance(bcmr_scorer, CiderD): 197 | res = [{'image_id': i, 'caption': res[i]} for i in range(2 * batch_size)] 198 | 199 | if expand_feat == 1: 200 | gts = {i: gts[(i % batch_size) // seq_per_img] 201 | for i in range(2 * batch_size)} 202 | else: 203 | gts = {i: gts[i % batch_size] for i in range(2 * batch_size)} 204 | 205 | score, scores = bcmr_scorer.compute_score(gts, res) 206 | 207 | # if bleu, only use bleu_4 208 | if isinstance(bcmr_scorer, Bleu): 209 | score = score[-1] 210 | scores = scores[-1] 211 | 212 | # happens for BLeu and METEOR 213 | if type(scores) == list: 214 | scores = np.array(scores) 215 | 216 | m_score = np.mean(scores[:batch_size]) 217 | g_score = np.mean(scores[batch_size:]) 218 | 219 | scores = scores[:batch_size] - scores[batch_size:] 220 | 221 | rewards = np.repeat(scores[:, np.newaxis], 
model_res.shape[1], 1) 222 | 223 | return rewards, m_score, g_score 224 | 225 | 226 | def get_cst_reward( 227 | model_res, 228 | data_gts, 229 | bcmr_scorer, 230 | bcmrscores=None, 231 | expand_feat=0, 232 | seq_per_img=20, 233 | scb_captions=20, 234 | scb_baseline=1, 235 | use_eos=0, 236 | use_mixer=0): 237 | 238 | """ 239 | Arguments: 240 | bcmrscores: precomputed scores of GT sequences 241 | scb_baseline: 1 - use GT to compute baseline, 242 | 2 - use MS to compute baseline 243 | """ 244 | 245 | if bcmrscores is None or use_mixer == 1: 246 | batch_size = model_res.size(0) 247 | 248 | model_res = model_res.cpu().numpy() 249 | 250 | res = OrderedDict() 251 | for i in range(batch_size): 252 | res[i] = [array_to_str(model_res[i], use_eos)] 253 | 254 | gts = OrderedDict() 255 | for i in range(len(data_gts)): 256 | gts[i] = [array_to_str(data_gts[i][j], use_eos) 257 | for j in range(len(data_gts[i]))] 258 | 259 | if isinstance(bcmr_scorer, CiderD): 260 | res = [{'image_id': i, 'caption': res[i]} for i in range(batch_size)] 261 | 262 | if expand_feat == 1: 263 | gts = {i: gts[(i % batch_size) // seq_per_img] 264 | for i in range(batch_size)} 265 | else: 266 | gts = {i: gts[i % batch_size] for i in range(batch_size)} 267 | 268 | _, scores = bcmr_scorer.compute_score(gts, res) 269 | 270 | # if bleu, only use bleu_4 271 | if isinstance(bcmr_scorer, Bleu): 272 | scores = scores[-1] 273 | 274 | # happens for BLeu and METEOR 275 | if type(scores) == list: 276 | scores = np.array(scores) 277 | 278 | scores = scores.reshape(-1, seq_per_img) 279 | 280 | elif bcmrscores is not None and use_mixer == 0: 281 | # use pre-computed scores only when mixer is not used 282 | scores = bcmrscores.copy() 283 | else: 284 | raise ValueError('bcmrscores is not set!') 285 | 286 | if scb_captions > 0: 287 | 288 | sorted_scores = np.sort(scores, axis=1) 289 | 290 | if scb_baseline == 1: 291 | # compute baseline from GT scores 292 | sorted_bcmrscores = np.sort(bcmrscores, axis=1) 293 | m_score = np.mean(scores) 294 | b_score = np.mean(bcmrscores) 295 | elif scb_baseline == 2: 296 | # compute baseline from sampled scores 297 | m_score = np.mean(sorted_scores) 298 | b_score = np.mean(sorted_scores[:,:scb_captions]) 299 | else: 300 | raise ValueError('unknown scb_baseline!') 301 | 302 | for ii in range(scores.shape[0]): 303 | if scb_baseline == 1: 304 | b = np.mean(sorted_bcmrscores[ii,:scb_captions]) 305 | elif scb_baseline == 2: 306 | b = np.mean(sorted_scores[ii,:scb_captions]) 307 | else: 308 | b = 0 309 | scores[ii] = scores[ii] - b 310 | 311 | else: 312 | m_score = np.mean(scores) 313 | b_score = 0 314 | 315 | scores = scores.reshape(-1) 316 | rewards = np.repeat(scores[:, np.newaxis], model_res.shape[1], 1) 317 | 318 | return rewards, m_score, b_score 319 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | 5 | def parse_opts(): 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument( 9 | '--train_label_h5', 10 | type=str, 11 | help='path to the h5file containing the preprocessed dataset') 12 | parser.add_argument( 13 | '--val_label_h5', 14 | type=str, 15 | help='path to the h5file containing the preprocessed dataset') 16 | parser.add_argument( 17 | '--test_label_h5', 18 | type=str, 19 | help='path to the h5file containing the preprocessed dataset') 20 | 21 | parser.add_argument( 22 | '--train_feat_h5', 23 | type=str, 24 | nargs='+', 25 | 
help='path to the h5 file containing extracted features') 26 | parser.add_argument( 27 | '--val_feat_h5', 28 | type=str, 29 | nargs='+', 30 | help='path to the h5 file containing extracted features') 31 | parser.add_argument( 32 | '--test_feat_h5', 33 | type=str, 34 | nargs='+', 35 | help='path to the h5 file containing extracted features') 36 | 37 | parser.add_argument( 38 | '--train_cocofmt_file', 39 | type=str, 40 | help='Gold captions in MSCOCO format to cal language metrics') 41 | parser.add_argument( 42 | '--val_cocofmt_file', 43 | type=str, 44 | help='Gold captions in MSCOCO format to cal language metrics') 45 | parser.add_argument( 46 | '--test_cocofmt_file', 47 | type=str, 48 | help='Gold captions in MSCOCO format to cal language metrics') 49 | 50 | parser.add_argument( 51 | '--train_bcmrscores_pkl', 52 | type=str, 53 | help='Pre-computed Cider-D metric for all captions') 54 | 55 | # Optimization: General 56 | parser.add_argument( 57 | '--max_patience', 58 | type=int, 59 | default=5, 60 | help='max number of epoch to run since the minima is detected -- early stopping') 61 | parser.add_argument( 62 | '--batch_size', 63 | type=int, 64 | default=128, 65 | help='Video batch size (there will be x seq_per_img sentences)') 66 | parser.add_argument( 67 | '--test_batch_size', 68 | type=int, 69 | default=32, 70 | help='what is the batch size in number of images per batch? (there will be x seq_per_img sentences)') 71 | parser.add_argument( 72 | '--train_seq_per_img', 73 | type=int, 74 | default=20, 75 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive.') 76 | parser.add_argument( 77 | '--test_seq_per_img', 78 | type=int, 79 | default=20, 80 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive.') 81 | parser.add_argument( 82 | '--learning_rate', 83 | type=float, 84 | default=1e-4, 85 | help='learning rate') 86 | parser.add_argument('--lr_update', default=50, type=int, 87 | help='Number of epochs to update the learning rate.') 88 | 89 | # Model settings 90 | parser.add_argument( 91 | '--rnn_type', 92 | type=str, 93 | default='lstm', 94 | choices=[ 95 | 'lstm', 96 | 'gru', 97 | 'rnn'], 98 | help='type of RNN') 99 | parser.add_argument( 100 | '--rnn_size', 101 | type=int, 102 | default=512, 103 | help='size of the rnn in number of hidden nodes in each layer') 104 | parser.add_argument( 105 | '--num_lm_layer', 106 | type=int, 107 | default=1, 108 | help='size of the rnn in number of hidden nodes in each layer') 109 | parser.add_argument( 110 | '--input_encoding_size', 111 | type=int, 112 | default=512, 113 | help='the encoding size of each frame in the video.') 114 | parser.add_argument( 115 | '--max_epochs', 116 | type=int, 117 | default=sys.maxsize, 118 | help='max number of epochs to run for (-1 = run forever)') 119 | parser.add_argument( 120 | '--grad_clip', 121 | type=float, 122 | default=0.25, 123 | help='clip gradients at this value (note should be lower than usual 5 because we normalize grads by both batch and seq_length)') 124 | parser.add_argument( 125 | '--drop_prob_lm', 126 | type=float, 127 | default=0.5, 128 | help='strength of dropout in the Language Model RNN') 129 | 130 | # Optimization: for the Language Model 131 | parser.add_argument( 132 | '--optim', 133 | type=str, 134 | default='adam', 135 | help='what update to use? 
sgd|sgdmom|adagrad|adam') 136 | parser.add_argument( 137 | '--optim_alpha', 138 | type=float, 139 | default=0.8, 140 | help='alpha for adagrad/rmsprop/momentum/adam') 141 | parser.add_argument( 142 | '--optim_beta', 143 | type=float, 144 | default=0.999, 145 | help='beta used for adam') 146 | parser.add_argument( 147 | '--optim_epsilon', 148 | type=float, 149 | default=1e-8, 150 | help='epsilon that goes into denominator for smoothing') 151 | 152 | # Evaluation/Checkpointing 153 | parser.add_argument( 154 | '--save_checkpoint_from', 155 | type=int, 156 | default=20, 157 | help='Start saving checkpoint from this epoch') 158 | parser.add_argument( 159 | '--save_checkpoint_every', 160 | type=int, 161 | default=1, 162 | help='how often to save a model checkpoint in epochs?') 163 | 164 | parser.add_argument( 165 | '--use_rl', 166 | type=int, 167 | default=0, 168 | help='Use RL training or not') 169 | parser.add_argument( 170 | '--use_rl_after', 171 | type=int, 172 | default=30, 173 | help='Start RL training after this epoch') 174 | parser.add_argument( 175 | '--train_cached_tokens', 176 | type=str, 177 | default=30, 178 | help='Path to idx document frequencies to cal Cider on training data') 179 | parser.add_argument( 180 | '--expand_feat', 181 | type=int, 182 | default=1, 183 | help='To expand features when sampling (to multiple captions)') 184 | 185 | parser.add_argument('--model_file', type=str, help='output model file') 186 | parser.add_argument('--result_file', type=str, help='output result file') 187 | parser.add_argument( 188 | '--start_from', 189 | type=str, 190 | default='', 191 | help='Load state from this file to continue training') 192 | parser.add_argument( 193 | '--language_eval', 194 | type=int, 195 | default=1, 196 | help='Evaluate language evaluation') 197 | parser.add_argument( 198 | '--eval_metric', 199 | default='CIDEr', 200 | choices=[ 201 | 'Loss', 202 | 'Bleu_4', 203 | 'METEOR', 204 | 'ROUGE_L', 205 | 'CIDEr', 206 | 'MSRVTT'], 207 | help='Evaluation metrics') 208 | parser.add_argument( 209 | '--test_language_eval', 210 | type=int, 211 | default=1, 212 | help='Evaluate language evaluation') 213 | 214 | parser.add_argument( 215 | '--print_log_interval', 216 | type=int, 217 | default=20, 218 | help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') 219 | parser.add_argument( 220 | '--loglevel', 221 | type=str, 222 | default='DEBUG', 223 | choices=[ 224 | 'DEBUG', 225 | 'INFO', 226 | 'WARNING', 227 | 'ERROR', 228 | 'CRITICAL']) 229 | 230 | # misc 231 | parser.add_argument( 232 | '--seed', 233 | type=int, 234 | default=123, 235 | help='random number generator seed to use') 236 | parser.add_argument( 237 | '--gpuid', 238 | type=int, 239 | default=7, 240 | help='which gpu to use. 
-1 = use CPU')
241 |     parser.add_argument(
242 |         '--num_chunks',
243 |         type=int,
244 |         default=1,
245 |         help='1: no attention, > 1: attention with num_chunks')
246 |     parser.add_argument(
247 |         '--num_layers',
248 |         type=int,
249 |         default=1,
250 |         help='number of layers in the lstm')
251 | 
252 |     parser.add_argument(
253 |         '--model_type',
254 |         type=str,
255 |         default='concat',
256 |         choices=[
257 |             'standard',
258 |             'concat',
259 |             'manet',
260 |         ],
261 |         help='Type of models')
262 | 
263 |     parser.add_argument(
264 |         '--beam_size',
265 |         type=int,
266 |         default=5,
267 |         help='Beam search size')
268 | 
269 |     parser.add_argument(
270 |         '--use_ss',
271 |         type=int,
272 |         default=0,
273 |         help='Use scheduled sampling')
274 |     parser.add_argument(
275 |         '--use_ss_after',
276 |         type=int,
277 |         default=0,
278 |         help='Start scheduled sampling after this epoch')
279 |     parser.add_argument(
280 |         '--ss_max_prob',
281 |         type=float,
282 |         default=0.25,
283 |         help='Maximum scheduled sampling probability')
284 |     parser.add_argument(
285 |         '--ss_k',
286 |         type=float,
287 |         default=30.0,
288 |         help='Annealing constant k; plot k/(k+exp(x/k)) from x=0 to 400, k=30')
289 | 
290 |     parser.add_argument(
291 |         '--use_mixer',
292 |         type=int,
293 |         default=1,
294 |         help='Use MIXER training (mix XE and RL objectives along the sequence)')
295 |     parser.add_argument(
296 |         '--mixer_from',
297 |         type=int,
298 |         default=-1,
299 |         help='If -1, then an annealing scheme will be used, based on mixer_descrease_every.\
300 |               Initially it will be set to the max_seq_length (30), and will be gradually decreased to 1.\
301 |               If this value is set to 1 from the beginning, then the MIXER approach is not applied')
302 |     parser.add_argument(
303 |         '--mixer_descrease_every',
304 |         type=int,
305 |         default=2,
306 |         help='Epoch interval to decrease the mixing value')
307 |     parser.add_argument(
308 |         '--use_cst',
309 |         type=int,
310 |         default=0,
311 |         help='Use CST training')
312 |     parser.add_argument(
313 |         '--use_cst_after',
314 |         type=int,
315 |         default=0,
316 |         help='Start CST training after this epoch')
317 |     parser.add_argument(
318 |         '--cst_increase_every',
319 |         type=int,
320 |         default=5,
321 |         help='Epoch interval to increase the number of captions in the CST baseline')
322 |     parser.add_argument(
323 |         '--scb_baseline',
324 |         type=int,
325 |         default=1,
326 |         help='which Self-consensus baseline (SCB) to use? 1: GT SCB, 2: Model Sample SCB')
327 |     parser.add_argument(
328 |         '--scb_captions',
329 |         type=int,
330 |         default=20,
331 |         help='-1: annealing; otherwise use this fixed number of captions to compute the SCB')
332 |     parser.add_argument(
333 |         '--use_eos',
334 |         type=int,
335 |         default=0,
336 |         help='If 1, keep the <eos> token in captions of the reference set')
337 |     parser.add_argument(
338 |         '--output_logp',
339 |         type=int,
340 |         default=0,
341 |         help='Output average log likelihood of the test and GT captions. 
Used for robustness analysis at test time.') 342 | 343 | 344 | args = parser.parse_args() 345 | return args 346 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | from torch.nn.utils import clip_grad_norm 7 | import numpy as np 8 | import os 9 | import sys 10 | import time 11 | import math 12 | import json 13 | import uuid 14 | import logging 15 | from datetime import datetime 16 | from six.moves import cPickle 17 | 18 | from dataloader import DataLoader 19 | from model import CaptionModel, CrossEntropyCriterion, RewardCriterion 20 | 21 | import utils 22 | import opts 23 | 24 | import sys 25 | sys.path.append("cider") 26 | from pyciderevalcap.cider.cider import Cider 27 | from pyciderevalcap.ciderD.ciderD import CiderD 28 | 29 | sys.path.append('coco-caption') 30 | from pycocoevalcap.bleu.bleu import Bleu 31 | from pycocoevalcap.meteor.meteor import Meteor 32 | from pycocoevalcap.rouge.rouge import Rouge 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def language_eval(predictions, cocofmt_file, opt): 38 | logger.info('>>> Language evaluating ...') 39 | tmp_checkpoint_json = os.path.join( 40 | opt.model_file + str(uuid.uuid4()) + '.json') 41 | json.dump(predictions, open(tmp_checkpoint_json, 'w')) 42 | lang_stats = utils.language_eval(cocofmt_file, tmp_checkpoint_json) 43 | os.remove(tmp_checkpoint_json) 44 | return lang_stats 45 | 46 | 47 | def train( 48 | model, 49 | criterion, 50 | optimizer, 51 | train_loader, 52 | val_loader, 53 | opt, 54 | rl_criterion=None): 55 | 56 | infos = {'iter': 0, 57 | 'epoch': 0, 58 | 'start_epoch': 0, 59 | 'best_score': float('-inf'), 60 | 'best_iter': 0, 61 | 'best_epoch': opt.max_epochs 62 | } 63 | 64 | checkpoint_checked = False 65 | rl_training = False 66 | seq_per_img = train_loader.get_seq_per_img() 67 | infos_history = {} 68 | 69 | if os.path.exists(opt.start_from): 70 | if os.path.isdir(opt.start_from): 71 | # loading the same model file at a different experiment dir 72 | start_from_file = os.path.join( 73 | opt.start_from, os.path.basename( 74 | opt.model_file)) 75 | else: 76 | start_from_file = opt.start_from 77 | logger.info('Loading state from: %s', start_from_file) 78 | checkpoint = torch.load(start_from_file) 79 | model.load_state_dict(checkpoint['model']) 80 | infos = checkpoint['infos'] 81 | infos['start_epoch'] = infos['epoch'] 82 | checkpoint_checked = True # this epoch is already checked 83 | else: 84 | logger.info('No checkpoint found! 
Training from the scratch') 85 | 86 | if opt.use_rl == 1 and opt.use_rl_after == 0: 87 | opt.use_rl_after = infos['epoch'] 88 | opt.use_cst_after = infos['epoch'] 89 | train_loader.set_current_epoch(infos['epoch']) 90 | 91 | while True: 92 | t_start = time.time() 93 | model.train() 94 | data = train_loader.get_batch() 95 | feats = [Variable(feat, volatile=False) for feat in data['feats']] 96 | labels = Variable(data['labels'], volatile=False) 97 | masks = Variable(data['masks'], volatile=False) 98 | 99 | if torch.cuda.is_available(): 100 | feats = [feat.cuda() for feat in feats] 101 | labels = labels.cuda() 102 | masks = masks.cuda() 103 | 104 | # implement scheduled sampling 105 | opt.ss_prob = 0 106 | if opt.use_ss == 1 and infos['epoch'] >= opt.use_ss_after: 107 | annealing_prob = opt.ss_k / \ 108 | (opt.ss_k + np.exp((infos['epoch'] - opt.use_ss_after) / opt.ss_k)) 109 | opt.ss_prob = min(1 - annealing_prob, opt.ss_max_prob) 110 | model.set_ss_prob(opt.ss_prob) 111 | 112 | if opt.use_rl == 1 and infos[ 113 | 'epoch'] >= opt.use_rl_after and not rl_training: 114 | logger.info('Using RL objective...') 115 | rl_training = True 116 | bcmr_scorer = { 117 | 'Bleu_4': Bleu(), 118 | 'CIDEr': CiderD(df=opt.train_cached_tokens), 119 | 'METEOR': Meteor(), 120 | 'ROUGE_L': Rouge() 121 | }[opt.eval_metric] 122 | 123 | #logger.info('loading gt refs: %s', train_loader.cocofmt_file) 124 | #gt_refs = utils.load_gt_refs(train_loader.cocofmt_file) 125 | 126 | mixer_from = opt.mixer_from 127 | if opt.use_mixer == 1 and rl_training: 128 | #annealing_mixer = opt.ss_k / \ 129 | # (opt.ss_k + np.exp((infos['epoch'] - opt.use_rl_after) / opt.ss_k)) 130 | #annealing_mixer = int(round(annealing_mixer * opt.seq_length)) 131 | 132 | # -1 for annealing 133 | if opt.mixer_from == -1: 134 | annealing_mixer = opt.seq_length - int(np.ceil((infos['epoch']-opt.use_rl_after+1)/float(opt.mixer_descrease_every))) 135 | mixer_from = max(1, annealing_mixer) 136 | 137 | model.set_mixer_from(mixer_from) 138 | 139 | scb_captions = opt.scb_captions 140 | if opt.use_cst == 1 and rl_training: 141 | # if opt.use_cst == 1 and opt.ss_k == 0, 142 | # then do not using annealing, but the fixed scb_captions provided 143 | #annealing_robust = opt.ss_k / \ 144 | # (opt.ss_k + np.exp((infos['epoch'] - opt.use_rl_after) / opt.ss_k)) 145 | #annealing_robust = int(round((1 - annealing_robust) * seq_per_img)) 146 | 147 | # do not use robust before fully mixed 148 | # if opt.use_mixer == 1 and mixer_from > 1: 149 | # opt.use_cst_after = infos['epoch'] 150 | 151 | # if opt.scb_captions is -1, then use the annealing value, 152 | # otherwise, use the set value 153 | if opt.scb_captions == -1: 154 | annealing_robust = int(np.ceil((infos['epoch']-opt.use_cst_after+1)/float(opt.cst_increase_every))) 155 | scb_captions = min(annealing_robust, seq_per_img-1) 156 | 157 | optimizer.zero_grad() 158 | model.set_seq_per_img(seq_per_img) 159 | 160 | if rl_training: 161 | # sampling from model distribution 162 | # model_res, logprobs = model.sample( 163 | # feats, {'sample_max': 0, 'expand_feat': opt.expand_feat, 'temperature': 1}) 164 | 165 | # using mixer 166 | pred, model_res, logprobs = model(feats, labels) 167 | 168 | if opt.use_cst == 0: 169 | # greedy decoding baseline in SCST paper 170 | greedy_baseline, _ = model.sample([Variable(f.data, volatile=True) for f in feats], 171 | {'sample_max': 1, 'expand_feat': opt.expand_feat}) 172 | 173 | """ 174 | if opt.loglevel.upper() == 'DEBUG' and opt.use_cst == 0: 175 | model_sents = 
utils.decode_sequence(opt.vocab, model_res) 176 | baseline_sents = utils.decode_sequence(opt.vocab, greedy_baseline) 177 | for jj, sent in enumerate(zip(model_sents, baseline_sents)): 178 | if opt.expand_feat == 1: 179 | video_id = data['ids'][ 180 | jj // train_loader.get_seq_per_img()] 181 | else: 182 | video_id = data['ids'][jj] 183 | logger.debug( 184 | '[%d] video %s\n\t Model: %s \n\t Greedy: %s' % 185 | (jj, video_id, sent[0], sent[1])) 186 | """ 187 | 188 | if opt.use_cst == 1: 189 | bcmrscores = data['bcmrscores'] 190 | reward, m_score, g_score = utils.get_cst_reward(model_res, data['gts'], bcmr_scorer, 191 | bcmrscores=bcmrscores, 192 | expand_feat=opt.expand_feat, 193 | seq_per_img=train_loader.get_seq_per_img(), 194 | scb_captions=scb_captions, 195 | scb_baseline=opt.scb_baseline, 196 | use_eos=opt.use_eos, 197 | use_mixer=opt.use_mixer 198 | ) 199 | else: 200 | # use greedy baseline by default, compute self-critical reward 201 | reward, m_score, g_score = utils.get_self_critical_reward(model_res, greedy_baseline, data['gts'], bcmr_scorer, 202 | expand_feat=opt.expand_feat, 203 | seq_per_img=train_loader.get_seq_per_img(), 204 | use_eos=opt.use_eos) 205 | 206 | """[[ 207 | #import pdb; pdb.set_trace() 208 | rl_loss = 0 209 | xe_loss = 0 210 | # -1 because we don't count here 211 | 212 | if mixer_from < model_res.size(1)-1: 213 | rl_loss = rl_criterion( 214 | model_res[:,mixer_from:], 215 | logprobs[:,mixer_from:], 216 | Variable( 217 | torch.from_numpy(reward[:,mixer_from:]).float().cuda(), 218 | requires_grad=False)) 219 | 220 | if mixer_from > 0: 221 | xe_loss = criterion(pred[:, :mixer_from], labels[:, 1:mixer_from+1], masks[:, 1:mixer_from+1]) 222 | 223 | loss = rl_loss + xe_loss 224 | """ 225 | 226 | loss = rl_criterion( 227 | model_res, 228 | logprobs, 229 | Variable( 230 | torch.from_numpy(reward).float().cuda(), 231 | requires_grad=False)) 232 | 233 | else: 234 | pred = model(feats, labels)[0] 235 | loss = criterion(pred, labels[:, 1:], masks[:, 1:]) 236 | 237 | loss.backward() 238 | clip_grad_norm(model.parameters(), opt.grad_clip) 239 | optimizer.step() 240 | infos['TrainLoss'] = loss.data[0] 241 | infos['mixer_from'] = mixer_from 242 | infos['scb_captions'] = scb_captions 243 | 244 | if infos['iter'] % opt.print_log_interval == 0: 245 | elapsed_time = time.time() - t_start 246 | 247 | log_info = [('Epoch', infos['epoch']), 248 | ('Iter', infos['iter']), 249 | ('Loss', infos['TrainLoss'])] 250 | 251 | if rl_training: 252 | log_info += [('Reward', np.mean(reward[:, 0])), 253 | ('{} (m)'.format(opt.eval_metric), m_score), 254 | ('{} (b)'.format(opt.eval_metric), g_score)] 255 | 256 | if opt.use_ss == 1: 257 | log_info += [('ss_prob', opt.ss_prob)] 258 | 259 | if opt.use_mixer == 1: 260 | log_info += [('mixer_from', mixer_from)] 261 | 262 | if opt.use_cst == 1: 263 | log_info += [('scb_captions', scb_captions)] 264 | 265 | log_info += [('Time', elapsed_time)] 266 | logger.info('%s', '\t'.join( 267 | ['{}: {}'.format(k, v) for (k, v) in log_info])) 268 | 269 | infos['iter'] += 1 270 | 271 | if infos['epoch'] < train_loader.get_current_epoch(): 272 | infos['epoch'] = train_loader.get_current_epoch() 273 | checkpoint_checked = False 274 | learning_rate = utils.adjust_learning_rate( 275 | opt, optimizer, infos['epoch'] - infos['start_epoch']) 276 | logger.info('===> Learning rate: %f: ', learning_rate) 277 | 278 | if (infos['epoch'] >= opt.save_checkpoint_from and 279 | infos['epoch'] % opt.save_checkpoint_every == 0 and 280 | not checkpoint_checked): 281 | # evaluate the 
validation performance 282 | results = validate(model, criterion, val_loader, opt) 283 | logger.info( 284 | 'Validation output: %s', 285 | json.dumps( 286 | results['scores'], 287 | indent=4, 288 | sort_keys=True)) 289 | infos.update(results['scores']) 290 | 291 | check_model(model, opt, infos, infos_history) 292 | checkpoint_checked = True 293 | 294 | if (infos['epoch'] >= opt.max_epochs or 295 | infos['epoch'] - infos['best_epoch'] > opt.max_patience): 296 | logger.info('>>> Terminating...') 297 | break 298 | 299 | return infos 300 | 301 | 302 | def validate(model, criterion, loader, opt): 303 | 304 | model.eval() 305 | loader.reset() 306 | 307 | num_videos = loader.get_num_videos() 308 | batch_size = loader.get_batch_size() 309 | num_iters = int(math.ceil(num_videos * 1.0 / batch_size)) 310 | last_batch_size = num_videos % batch_size 311 | seq_per_img = loader.get_seq_per_img() 312 | model.set_seq_per_img(seq_per_img) 313 | 314 | loss_sum = 0 315 | logger.info( 316 | '#num_iters: %d, batch_size: %d, seg_per_image: %d', 317 | num_iters, 318 | batch_size, 319 | seq_per_img) 320 | predictions = [] 321 | gt_avglogps = [] 322 | test_avglogps = [] 323 | for ii in range(num_iters): 324 | data = loader.get_batch() 325 | feats = [Variable(feat, volatile=True) for feat in data['feats']] 326 | if loader.has_label: 327 | labels = Variable(data['labels'], volatile=True) 328 | masks = Variable(data['masks'], volatile=True) 329 | 330 | if ii == (num_iters - 1) and last_batch_size > 0: 331 | feats = [f[:last_batch_size] for f in feats] 332 | if loader.has_label: 333 | labels = labels[ 334 | :last_batch_size * 335 | seq_per_img] # labels shape is DxN 336 | masks = masks[:last_batch_size * seq_per_img] 337 | 338 | if torch.cuda.is_available(): 339 | feats = [feat.cuda() for feat in feats] 340 | if loader.has_label: 341 | labels = labels.cuda() 342 | masks = masks.cuda() 343 | 344 | if loader.has_label: 345 | pred, gt_seq, gt_logseq = model(feats, labels) 346 | if opt.output_logp == 1: 347 | gt_avglogp = utils.compute_avglogp(gt_seq, gt_logseq.data) 348 | gt_avglogps.extend(gt_avglogp) 349 | 350 | loss = criterion(pred, labels[:, 1:], masks[:, 1:]) 351 | loss_sum += loss.data[0] 352 | 353 | seq, logseq = model.sample(feats, {'beam_size': opt.beam_size}) 354 | sents = utils.decode_sequence(opt.vocab, seq) 355 | if opt.output_logp == 1: 356 | test_avglogp = utils.compute_avglogp(seq, logseq) 357 | test_avglogps.extend(test_avglogp) 358 | 359 | for jj, sent in enumerate(sents): 360 | if opt.output_logp == 1: 361 | entry = {'image_id': data['ids'][jj], 'caption': sent, 'avglogp': test_avglogp[jj]} 362 | else: 363 | entry = {'image_id': data['ids'][jj], 'caption': sent} 364 | predictions.append(entry) 365 | logger.debug('[%d] video %s: %s' % 366 | (jj, entry['image_id'], entry['caption'])) 367 | 368 | loss = round(loss_sum / num_iters, 3) 369 | results = {} 370 | lang_stats = {} 371 | 372 | if opt.language_eval == 1 and loader.has_label: 373 | logger.info('>>> Language evaluating ...') 374 | tmp_checkpoint_json = os.path.join( 375 | opt.model_file + str(uuid.uuid4()) + '.json') 376 | json.dump(predictions, open(tmp_checkpoint_json, 'w')) 377 | lang_stats = utils.language_eval( 378 | loader.cocofmt_file, tmp_checkpoint_json) 379 | os.remove(tmp_checkpoint_json) 380 | 381 | results['predictions'] = predictions 382 | results['scores'] = {'Loss': -loss} 383 | results['scores'].update(lang_stats) 384 | 385 | if opt.output_logp == 1: 386 | avglogp = sum(test_avglogps)/float(len(test_avglogps)) 387 | 
results['scores'].update({'avglogp': avglogp}) 388 | 389 | gt_avglogps = np.array(gt_avglogps).reshape(-1, seq_per_img) 390 | assert num_videos == gt_avglogps.shape[0] 391 | 392 | gt_avglogps_file = opt.model_file.replace('.pth', '_gt_avglogps.pkl', 1) 393 | cPickle.dump(gt_avglogps, open( 394 | gt_avglogps_file, 'w'), protocol=cPickle.HIGHEST_PROTOCOL) 395 | 396 | logger.info('Wrote GT logp to: %s', gt_avglogps_file) 397 | 398 | return results 399 | 400 | 401 | def test(model, criterion, loader, opt): 402 | 403 | results = validate(model, criterion, loader, opt) 404 | logger.info('Test output: %s', json.dumps(results['scores'], indent=4)) 405 | 406 | json.dump(results, open(opt.result_file, 'w')) 407 | logger.info('Wrote output caption to: %s ', opt.result_file) 408 | 409 | 410 | def check_model(model, opt, infos, infos_history): 411 | 412 | if opt.eval_metric == 'MSRVTT': 413 | current_score = infos['Bleu_4'] + \ 414 | infos['METEOR'] + infos['ROUGE_L'] + infos['CIDEr'] 415 | else: 416 | current_score = infos[opt.eval_metric] 417 | 418 | # write the full model checkpoint as well if we did better than ever 419 | if current_score >= infos['best_score']: 420 | infos['best_score'] = current_score 421 | infos['best_iter'] = infos['iter'] 422 | infos['best_epoch'] = infos['epoch'] 423 | 424 | logger.info( 425 | '>>> Found new best [%s] score: %f, at iter: %d, epoch %d', 426 | opt.eval_metric, 427 | current_score, 428 | infos['iter'], 429 | infos['epoch']) 430 | 431 | torch.save({'model': model.state_dict(), 432 | 'infos': infos, 433 | 'opt': opt 434 | }, opt.model_file) 435 | logger.info('Wrote checkpoint to: %s', opt.model_file) 436 | 437 | else: 438 | logger.info('>>> Current best [%s] score: %f, at iter %d, epoch %d', 439 | opt.eval_metric, infos['best_score'], 440 | infos['best_iter'], 441 | infos['best_epoch']) 442 | 443 | infos_history[infos['epoch']] = infos.copy() 444 | with open(opt.history_file, 'w') as of: 445 | json.dump(infos_history, of) 446 | logger.info('Updated history to: %s', opt.history_file) 447 | 448 | if __name__ == '__main__': 449 | 450 | opt = opts.parse_opts() 451 | 452 | logging.basicConfig(level=getattr(logging, opt.loglevel.upper()), 453 | format='%(asctime)s:%(levelname)s: %(message)s') 454 | 455 | logger.info( 456 | 'Input arguments: %s', 457 | json.dumps( 458 | vars(opt), 459 | sort_keys=True, 460 | indent=4)) 461 | 462 | # Set the random seed manually for reproducibility. 
463 | np.random.seed(opt.seed) 464 | torch.manual_seed(opt.seed) 465 | if torch.cuda.is_available(): 466 | torch.cuda.manual_seed(opt.seed) 467 | 468 | train_opt = {'label_h5': opt.train_label_h5, 469 | 'batch_size': opt.batch_size, 470 | 'feat_h5': opt.train_feat_h5, 471 | 'cocofmt_file': opt.train_cocofmt_file, 472 | 'bcmrscores_pkl': opt.train_bcmrscores_pkl, 473 | 'eval_metric': opt.eval_metric, 474 | 'seq_per_img': opt.train_seq_per_img, 475 | 'num_chunks': opt.num_chunks, 476 | 'mode': 'train' 477 | } 478 | 479 | val_opt = {'label_h5': opt.val_label_h5, 480 | 'batch_size': opt.test_batch_size, 481 | 'feat_h5': opt.val_feat_h5, 482 | 'cocofmt_file': opt.val_cocofmt_file, 483 | 'seq_per_img': opt.test_seq_per_img, 484 | 'num_chunks': opt.num_chunks, 485 | 'mode': 'test' 486 | } 487 | 488 | test_opt = {'label_h5': opt.test_label_h5, 489 | 'batch_size': opt.test_batch_size, 490 | 'feat_h5': opt.test_feat_h5, 491 | 'cocofmt_file': opt.test_cocofmt_file, 492 | 'seq_per_img': opt.test_seq_per_img, 493 | 'num_chunks': opt.num_chunks, 494 | 'mode': 'test' 495 | } 496 | 497 | train_loader = DataLoader(train_opt) 498 | val_loader = DataLoader(val_opt) 499 | test_loader = DataLoader(test_opt) 500 | 501 | opt.vocab = train_loader.get_vocab() 502 | opt.vocab_size = train_loader.get_vocab_size() 503 | opt.seq_length = train_loader.get_seq_length() 504 | opt.feat_dims = train_loader.get_feat_dims() 505 | opt.history_file = opt.model_file.replace('.pth', '_history.json', 1) 506 | 507 | logger.info('Building model...') 508 | model = CaptionModel(opt) 509 | 510 | xe_criterion = CrossEntropyCriterion() 511 | rl_criterion = RewardCriterion() 512 | 513 | if torch.cuda.is_available(): 514 | model.cuda() 515 | xe_criterion.cuda() 516 | rl_criterion.cuda() 517 | 518 | logger.info('Start training...') 519 | start = datetime.now() 520 | 521 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate) 522 | infos = train( 523 | model, 524 | xe_criterion, 525 | optimizer, 526 | train_loader, 527 | val_loader, 528 | opt, 529 | rl_criterion=rl_criterion) 530 | logger.info( 531 | 'Best val %s score: %f. Best iter: %d. 
Best epoch: %d', 532 | opt.eval_metric, 533 | infos['best_score'], 534 | infos['best_iter'], 535 | infos['best_epoch']) 536 | 537 | logger.info('Training time: %s', datetime.now() - start) 538 | 539 | if opt.result_file: 540 | logger.info('Start testing...') 541 | start = datetime.now() 542 | 543 | logger.info('Loading model: %s', opt.model_file) 544 | checkpoint = torch.load(opt.model_file) 545 | model.load_state_dict(checkpoint['model']) 546 | 547 | test(model, xe_criterion, test_loader, opt) 548 | logger.info('Testing time: %s', datetime.now() - start) 549 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | 8 | def to_contiguous(tensor): 9 | if tensor.is_contiguous(): 10 | return tensor 11 | else: 12 | return tensor.contiguous() 13 | 14 | 15 | class RewardCriterion(nn.Module): 16 | 17 | def __init__(self): 18 | super(RewardCriterion, self).__init__() 19 | 20 | def forward(self, seq, logprobs, reward): 21 | 22 | # import pdb; pdb.set_trace() 23 | logprobs = to_contiguous(logprobs).view(-1) 24 | reward = to_contiguous(reward).view(-1) 25 | mask = (seq > 0).float() 26 | # add one to the right to count for the token 27 | mask = to_contiguous(torch.cat( 28 | [mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 29 | #import pdb; pdb.set_trace() 30 | output = - logprobs * reward * Variable(mask) 31 | output = torch.sum(output) / torch.sum(mask) 32 | 33 | return output 34 | 35 | 36 | class CrossEntropyCriterion(nn.Module): 37 | 38 | def __init__(self): 39 | super(CrossEntropyCriterion, self).__init__() 40 | 41 | def forward(self, pred, target, mask): 42 | # truncate to the same size 43 | target = target[:, :pred.size(1)] 44 | mask = mask[:, :pred.size(1)] 45 | 46 | pred = to_contiguous(pred).view(-1, pred.size(2)) 47 | target = to_contiguous(target).view(-1, 1) 48 | mask = to_contiguous(mask).view(-1, 1) 49 | 50 | output = -pred.gather(1, target) * mask 51 | output = torch.sum(output) / torch.sum(mask) 52 | 53 | return output 54 | 55 | 56 | class FeatPool(nn.Module): 57 | 58 | def __init__(self, feat_dims, out_size, dropout): 59 | super(FeatPool, self).__init__() 60 | 61 | module_list = [] 62 | for dim in feat_dims: 63 | module = nn.Sequential( 64 | nn.Linear( 65 | dim, 66 | out_size), 67 | nn.ReLU(), 68 | nn.Dropout(dropout)) 69 | module_list += [module] 70 | self.feat_list = nn.ModuleList(module_list) 71 | 72 | # self.embed = nn.Sequential(nn.Linear(sum(feat_dims), out_size), nn.ReLU(), nn.Dropout(dropout)) 73 | 74 | def forward(self, feats): 75 | """ 76 | feats is a list, each element is a tensor that have size (N x C x F) 77 | at the moment assuming that C == 1 78 | """ 79 | out = torch.cat([m(feats[i].squeeze(1)) 80 | for i, m in enumerate(self.feat_list)], 1) 81 | # pdb.set_trace() 82 | # out = self.embed(torch.cat(feats, 2).squeeze(1)) 83 | return out 84 | 85 | 86 | class FeatExpander(nn.Module): 87 | 88 | def __init__(self, n=1): 89 | super(FeatExpander, self).__init__() 90 | self.n = n 91 | 92 | def forward(self, x): 93 | if self.n == 1: 94 | out = x 95 | else: 96 | out = Variable( 97 | x.data.new( 98 | self.n * x.size(0), 99 | x.size(1)), 100 | volatile=x.volatile) 101 | for i in range(x.size(0)): 102 | out[i * self.n:(i + 1) * 103 | self.n] = x[i].expand(self.n, x.size(1)) 104 | return out 105 | 106 | 
def set_n(self, x): 107 | self.n = x 108 | 109 | 110 | class RNNUnit(nn.Module): 111 | 112 | def __init__(self, opt): 113 | super(RNNUnit, self).__init__() 114 | self.rnn_type = opt.rnn_type 115 | self.rnn_size = opt.rnn_size 116 | self.num_layers = opt.num_layers 117 | self.drop_prob_lm = opt.drop_prob_lm 118 | 119 | if opt.model_type == 'standard': 120 | self.input_size = opt.input_encoding_size 121 | elif opt.model_type in ['concat', 'manet']: 122 | self.input_size = opt.input_encoding_size + opt.video_encoding_size 123 | 124 | self.rnn = getattr( 125 | nn, 126 | self.rnn_type.upper())( 127 | self.input_size, 128 | self.rnn_size, 129 | self.num_layers, 130 | bias=False, 131 | dropout=self.drop_prob_lm) 132 | 133 | def forward(self, xt, state): 134 | output, state = self.rnn(xt.unsqueeze(0), state) 135 | return output.squeeze(0), state 136 | 137 | 138 | class MANet(nn.Module): 139 | """ 140 | MANet: Modal Attention 141 | """ 142 | 143 | def __init__(self, video_encoding_size, rnn_size, num_feats): 144 | super(MANet, self).__init__() 145 | self.video_encoding_size = video_encoding_size 146 | self.rnn_size = rnn_size 147 | self.num_feats = num_feats 148 | 149 | self.f_feat_m = nn.Linear(self.video_encoding_size, self.num_feats) 150 | self.f_h_m = nn.Linear(self.rnn_size, self.num_feats) 151 | self.align_m = nn.Linear(self.num_feats, self.num_feats) 152 | 153 | def forward(self, x, h): 154 | f_feat = self.f_feat_m(x) 155 | f_h = self.f_h_m(h.squeeze(0)) # assuming now num_layers is 1 156 | att_weight = nn.Softmax()(self.align_m(nn.Tanh()(f_feat + f_h))) 157 | att_weight = att_weight.unsqueeze(2).expand( 158 | x.size(0), self.num_feats, self.video_encoding_size / self.num_feats) 159 | att_weight = att_weight.contiguous().view(x.size(0), x.size(1)) 160 | return x * att_weight 161 | 162 | 163 | class CaptionModel(nn.Module): 164 | """ 165 | A baseline captioning model 166 | """ 167 | 168 | def __init__(self, opt): 169 | super(CaptionModel, self).__init__() 170 | self.vocab_size = opt.vocab_size 171 | self.input_encoding_size = opt.input_encoding_size 172 | self.rnn_type = opt.rnn_type 173 | self.rnn_size = opt.rnn_size 174 | self.num_layers = opt.num_layers 175 | self.drop_prob_lm = opt.drop_prob_lm 176 | self.seq_length = opt.seq_length 177 | self.feat_dims = opt.feat_dims 178 | self.num_feats = len(self.feat_dims) 179 | self.seq_per_img = opt.train_seq_per_img 180 | self.model_type = opt.model_type 181 | self.bos_index = 1 # index of the token 182 | self.ss_prob = 0 183 | self.mixer_from = 0 184 | 185 | self.embed = nn.Embedding(self.vocab_size, self.input_encoding_size) 186 | self.logit = nn.Linear(self.rnn_size, self.vocab_size) 187 | self.dropout = nn.Dropout(self.drop_prob_lm) 188 | 189 | self.init_weights() 190 | self.feat_pool = FeatPool( 191 | self.feat_dims, 192 | self.num_layers * 193 | self.rnn_size, 194 | self.drop_prob_lm) 195 | self.feat_expander = FeatExpander(self.seq_per_img) 196 | 197 | self.video_encoding_size = self.num_feats * self.num_layers * self.rnn_size 198 | opt.video_encoding_size = self.video_encoding_size 199 | self.core = RNNUnit(opt) 200 | 201 | if self.model_type == 'manet': 202 | self.manet = MANet( 203 | self.video_encoding_size, 204 | self.rnn_size, 205 | self.num_feats) 206 | 207 | def set_ss_prob(self, p): 208 | self.ss_prob = p 209 | 210 | def set_mixer_from(self, t): 211 | """Set values of mixer_from 212 | if mixer_from > 0 then start MIXER training 213 | i.e: 214 | from t = 0 -> t = mixer_from -1: use XE training 215 | from t = mixer_from -> end: use 
RL training 216 | """ 217 | self.mixer_from = t 218 | 219 | def set_seq_per_img(self, x): 220 | self.seq_per_img = x 221 | self.feat_expander.set_n(x) 222 | 223 | def init_weights(self): 224 | initrange = 0.1 225 | self.embed.weight.data.uniform_(-initrange, initrange) 226 | self.logit.bias.data.fill_(0) 227 | self.logit.weight.data.uniform_(-initrange, initrange) 228 | 229 | def init_hidden(self, batch_size): 230 | weight = next(self.parameters()).data 231 | 232 | if self.rnn_type == 'lstm': 233 | return ( 234 | Variable( 235 | weight.new( 236 | self.num_layers, 237 | batch_size, 238 | self.rnn_size).zero_()), 239 | Variable( 240 | weight.new( 241 | self.num_layers, 242 | batch_size, 243 | self.rnn_size).zero_())) 244 | else: 245 | return Variable( 246 | weight.new( 247 | self.num_layers, 248 | batch_size, 249 | self.rnn_size).zero_()) 250 | 251 | def forward(self, feats, seq): 252 | 253 | fc_feats = self.feat_pool(feats) 254 | fc_feats = self.feat_expander(fc_feats) 255 | 256 | batch_size = fc_feats.size(0) 257 | state = self.init_hidden(batch_size) 258 | outputs = [] 259 | sample_seq = [] 260 | sample_logprobs = [] 261 | 262 | # -- if is input at the first step, use index -1 263 | # -- the token is not used for training 264 | start_i = -1 if self.model_type == 'standard' else 0 265 | end_i = seq.size(1) - 1 266 | 267 | for token_idx in range(start_i, end_i): 268 | if token_idx == -1: 269 | xt = fc_feats 270 | else: 271 | # token_idx = 0 corresponding to the token 272 | # (already encoded in seq) 273 | 274 | if self.training and token_idx >= 1 and self.ss_prob > 0.0: 275 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 276 | sample_mask = sample_prob < self.ss_prob 277 | if sample_mask.sum() == 0: 278 | it = seq[:, token_idx].clone() 279 | else: 280 | sample_ind = sample_mask.nonzero().view(-1) 281 | it = seq[:, token_idx].data.clone() 282 | # fetch prev distribution: shape Nx(M+1) 283 | prob_prev = torch.exp(outputs[-1].data) 284 | sample_ind_tokens = torch.multinomial( 285 | prob_prev, 1).view(-1).index_select(0, sample_ind) 286 | it.index_copy_(0, sample_ind, sample_ind_tokens) 287 | it = Variable(it, requires_grad=False) 288 | elif self.training and self.mixer_from > 0 and token_idx >= self.mixer_from: 289 | prob_prev = torch.exp(outputs[-1].data) 290 | it = torch.multinomial(prob_prev, 1).view(-1) 291 | it = Variable(it, requires_grad=False) 292 | else: 293 | it = seq[:, token_idx].clone() 294 | 295 | if token_idx >= 1: 296 | # store the seq and its logprobs 297 | sample_seq.append(it.data) 298 | logprobs = outputs[-1].gather(1, it.unsqueeze(1)) 299 | sample_logprobs.append(logprobs.view(-1)) 300 | 301 | # break if all the sequences end, which requires EOS token = 0 302 | if it.data.sum() == 0: 303 | break 304 | xt = self.embed(it) 305 | 306 | if self.model_type == 'standard': 307 | output, state = self.core(xt, state) 308 | else: 309 | if self.model_type == 'manet': 310 | fc_feats = self.manet(fc_feats, state[0]) 311 | output, state = self.core(torch.cat([xt, fc_feats], 1), state) 312 | 313 | if token_idx >= 0: 314 | output = F.log_softmax(self.logit(self.dropout(output))) 315 | outputs.append(output) 316 | 317 | # only returns outputs of seq input 318 | # output size is: B x L x V (where L is truncated lengths 319 | # which are different for different batch) 320 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1), \ 321 | torch.cat([_.unsqueeze(1) for _ in sample_seq], 1), \ 322 | torch.cat([_.unsqueeze(1) for _ in sample_logprobs], 1) \ 323 | 324 | def 
sample(self, feats, opt={}): 325 | sample_max = opt.get('sample_max', 1) 326 | beam_size = opt.get('beam_size', 1) 327 | temperature = opt.get('temperature', 1.0) 328 | expand_feat = opt.get('expand_feat', 0) 329 | 330 | if beam_size > 1: 331 | return self.sample_beam(feats, opt) 332 | 333 | fc_feats = self.feat_pool(feats) 334 | if expand_feat == 1: 335 | fc_feats = self.feat_expander(fc_feats) 336 | batch_size = fc_feats.size(0) 337 | state = self.init_hidden(batch_size) 338 | 339 | seq = [] 340 | seqLogprobs = [] 341 | 342 | unfinished = fc_feats.data.new(batch_size).fill_(1).byte() 343 | 344 | # -- if is input at the first step, use index -1 345 | start_i = -1 if self.model_type == 'standard' else 0 346 | end_i = self.seq_length - 1 347 | 348 | for token_idx in range(start_i, end_i): 349 | if token_idx == -1: 350 | xt = fc_feats 351 | else: 352 | if token_idx == 0: # input 353 | it = fc_feats.data.new( 354 | batch_size).long().fill_(self.bos_index) 355 | elif sample_max == 1: 356 | # output here is a Tensor, because we don't use backprop 357 | sampleLogprobs, it = torch.max(logprobs.data, 1) 358 | it = it.view(-1).long() 359 | else: 360 | if temperature == 1.0: 361 | # fetch prev distribution: shape Nx(M+1) 362 | prob_prev = torch.exp(logprobs.data).cpu() 363 | else: 364 | # scale logprobs by temperature 365 | prob_prev = torch.exp( 366 | torch.div( 367 | logprobs.data, 368 | temperature)).cpu() 369 | #import pdb; pdb.set_trace() 370 | it = torch.multinomial(prob_prev, 1).cuda() 371 | # gather the logprobs at sampled positions 372 | sampleLogprobs = logprobs.gather( 373 | 1, Variable(it, requires_grad=False)) 374 | # and flatten indices for downstream processing 375 | it = it.view(-1).long() 376 | 377 | xt = self.embed(Variable(it, requires_grad=False)) 378 | 379 | if token_idx >= 1: 380 | unfinished = unfinished * (it > 0) 381 | 382 | # 383 | it = it * unfinished.type_as(it) 384 | seq.append(it) 385 | seqLogprobs.append(sampleLogprobs.view(-1)) 386 | 387 | # requires EOS token = 0 388 | if unfinished.sum() == 0: 389 | break 390 | 391 | if self.model_type == 'standard': 392 | output, state = self.core(xt, state) 393 | else: 394 | if self.model_type == 'manet': 395 | fc_feats = self.manet(fc_feats, state[0]) 396 | output, state = self.core(torch.cat([xt, fc_feats], 1), state) 397 | 398 | logprobs = F.log_softmax(self.logit(output)) 399 | 400 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat( 401 | [_.unsqueeze(1) for _ in seqLogprobs], 1) 402 | 403 | def sample_beam(self, feats, opt={}): 404 | """ 405 | modified from https://github.com/ruotianluo/self-critical.pytorch 406 | """ 407 | beam_size = opt.get('beam_size', 5) 408 | fc_feats = self.feat_pool(feats) 409 | batch_size = fc_feats.size(0) 410 | 411 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 412 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 413 | # lets process every image independently for now, for simplicity 414 | 415 | self.done_beams = [[] for _ in range(batch_size)] 416 | for k in range(batch_size): 417 | state = self.init_hidden(beam_size) 418 | fc_feats_k = fc_feats[k].expand( 419 | beam_size, self.video_encoding_size) 420 | 421 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 422 | beam_seq_logprobs = torch.FloatTensor( 423 | self.seq_length, beam_size).zero_() 424 | # running sum of logprobs for each beam 425 | beam_logprobs_sum = torch.zeros(beam_size) 426 | 427 | # -- if is input at the first step, use index -1 428 | start_i = -1 if self.model_type == 
'standard' else 0 429 | end_i = self.seq_length - 1 430 | 431 | for token_idx in range(start_i, end_i): 432 | if token_idx == -1: 433 | xt = fc_feats_k 434 | elif token_idx == 0: # input 435 | it = fc_feats.data.new( 436 | beam_size).long().fill_(self.bos_index) 437 | xt = self.embed(Variable(it, requires_grad=False)) 438 | else: 439 | """perform a beam merge. that is, 440 | for every previous beam we now many new possibilities to branch out 441 | we need to resort our beams to maintain the loop invariant of keeping 442 | the top beam_size most likely sequences.""" 443 | logprobsf = logprobs.float() # lets go to CPU for more efficiency in indexing operations 444 | # sorted array of logprobs along each previous beam (last 445 | # true = descending) 446 | ys, ix = torch.sort(logprobsf, 1, True) 447 | candidates = [] 448 | cols = min(beam_size, ys.size(1)) 449 | rows = beam_size 450 | if token_idx == 1: # at first time step only the first beam is active 451 | rows = 1 452 | for c in range(cols): 453 | for q in range(rows): 454 | # compute logprob of expanding beam q with word in 455 | # (sorted) position c 456 | local_logprob = ys[q, c] 457 | candidate_logprob = beam_logprobs_sum[ 458 | q] + local_logprob 459 | candidates.append({'c': ix.data[q, c], 'q': q, 'p': candidate_logprob.data[ 460 | 0], 'r': local_logprob.data[0]}) 461 | candidates = sorted(candidates, key=lambda x: -x['p']) 462 | 463 | # construct new beams 464 | new_state = [_.clone() for _ in state] 465 | if token_idx > 1: 466 | # well need these as reference when we fork beams 467 | # around 468 | beam_seq_prev = beam_seq[:token_idx - 1].clone() 469 | beam_seq_logprobs_prev = beam_seq_logprobs[ 470 | :token_idx - 1].clone() 471 | 472 | for vix in range(beam_size): 473 | v = candidates[vix] 474 | # fork beam index q into index vix 475 | if token_idx > 1: 476 | beam_seq[ 477 | :token_idx - 1, 478 | vix] = beam_seq_prev[ 479 | :, 480 | v['q']] 481 | beam_seq_logprobs[ 482 | :token_idx - 1, 483 | vix] = beam_seq_logprobs_prev[ 484 | :, 485 | v['q']] 486 | 487 | # rearrange recurrent states 488 | for state_ix in range(len(new_state)): 489 | # copy over state in previous beam q to new beam at 490 | # vix 491 | new_state[state_ix][ 492 | 0, vix] = state[state_ix][ 493 | 0, v['q']] # dimension one is time step 494 | 495 | # append new end terminal at the end of this beam 496 | # c'th word is the continuation 497 | beam_seq[token_idx - 1, vix] = v['c'] 498 | beam_seq_logprobs[ 499 | token_idx - 1, vix] = v['r'] # the raw logprob here 500 | # the new (sum) logprob along this beam 501 | beam_logprobs_sum[vix] = v['p'] 502 | 503 | if v['c'] == 0 or token_idx == self.seq_length - 2: 504 | # END token special case here, or we reached the end. 
505 | # add the beam to a set of done beams 506 | if token_idx > 1: 507 | ppl = np.exp(-beam_logprobs_sum[vix] / (token_idx - 1)) 508 | else: 509 | ppl = 10000 510 | self.done_beams[k].append({'seq': beam_seq[:, vix].clone(), 511 | 'logps': beam_seq_logprobs[:, vix].clone(), 512 | 'p': beam_logprobs_sum[vix], 513 | 'ppl': ppl 514 | }) 515 | 516 | # encode as vectors 517 | it = beam_seq[token_idx - 1] 518 | xt = self.embed(Variable(it.cuda())) 519 | 520 | if token_idx >= 1: 521 | state = new_state 522 | 523 | if self.model_type == 'standard': 524 | output, state = self.core(xt, state) 525 | else: 526 | if self.model_type == 'manet': 527 | fc_feats_k = self.manet(fc_feats_k, state[0]) 528 | output, state = self.core( 529 | torch.cat([xt, fc_feats_k], 1), state) 530 | 531 | logprobs = F.log_softmax(self.logit(output)) 532 | 533 | 534 | #self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p']) 535 | self.done_beams[k] = sorted( 536 | self.done_beams[k], key=lambda x: x['ppl']) 537 | 538 | # the first beam has highest cumulative score 539 | seq[:, k] = self.done_beams[k][0]['seq'] 540 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 541 | 542 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 543 | --------------------------------------------------------------------------------
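A quick illustration of the reward shaping implemented by get_self_critical_reward in utils.py above: the sentence-level metric difference (sampled caption score minus greedy-decoded baseline score) becomes the per-timestep reward by broadcasting it across the generated sequence. The sketch below is illustrative only; the scores and sequence length are made-up values, not outputs of the model.

    import numpy as np

    # Toy sentence-level metric scores (e.g. CIDEr-D) for 3 sampled captions
    # and for their greedy-decoded baselines.
    sampled_scores = np.array([0.82, 0.40, 0.65])
    greedy_scores = np.array([0.70, 0.45, 0.65])

    # Per-caption advantage, then broadcast over all time steps,
    # mirroring the final np.repeat in get_self_critical_reward.
    advantage = sampled_scores - greedy_scores      # approx. [0.12, -0.05, 0.0]
    seq_length = 5
    rewards = np.repeat(advantage[:, np.newaxis], seq_length, 1)   # shape (3, 5)

RewardCriterion in model.py then multiplies the sampled tokens' negative log-probabilities by these per-step rewards, using a mask that keeps one extra position so the first end-of-sequence token is also rewarded.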
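get_cst_reward with scb_baseline == 1 implements the consensus baseline: for each video it subtracts the mean of the lowest scb_captions ground-truth scores (taken from the pre-computed bcmrscores matrix) from every caption's score. A minimal sketch of that baseline subtraction, assuming two videos with four ground-truth captions each and made-up scores:

    import numpy as np

    # Toy pre-computed GT caption scores, shape (num_videos, captions_per_video);
    # stands in for the bcmrscores pickle loaded by the training DataLoader.
    bcmrscores = np.array([[0.2, 0.8, 0.5, 0.3],
                           [0.6, 0.4, 0.9, 0.7]])
    scores = bcmrscores.copy()        # with use_mixer == 0 the GT scores are reused directly
    scb_captions = 2

    sorted_gt = np.sort(bcmrscores, axis=1)               # ascending per video
    baseline = sorted_gt[:, :scb_captions].mean(axis=1)   # mean of the 2 lowest GT scores: 0.25 and 0.5
    advantages = scores - baseline[:, np.newaxis]         # e.g. first video: [-0.05, 0.55, 0.25, 0.05]

In get_cst_reward these per-caption advantages are then flattened and repeated across time steps, exactly as in the self-critical case above.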
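train.py drives three annealing schedules off the epoch counter: the scheduled-sampling probability (when use_ss is on), the MIXER boundary (when mixer_from is -1), and the number of captions used for the consensus baseline (when scb_captions is -1). The sketch below only evaluates those formulas for a few epochs, assuming the default ss_k=30, ss_max_prob=0.25, mixer_descrease_every=2 and cst_increase_every=5, plus an assumed seq_length of 30 and seq_per_img of 20, with all the *_after epochs set to 0.

    import numpy as np

    ss_k, ss_max_prob = 30.0, 0.25
    seq_length, seq_per_img = 30, 20
    mixer_descrease_every, cst_increase_every = 2, 5
    use_ss_after = use_rl_after = use_cst_after = 0

    for epoch in (0, 10, 30, 60):
        # scheduled sampling: probability of feeding back the model's own sample
        anneal = ss_k / (ss_k + np.exp((epoch - use_ss_after) / ss_k))
        ss_prob = min(1 - anneal, ss_max_prob)

        # MIXER: first time step trained with the RL objective
        mixer_from = max(1, seq_length - int(np.ceil((epoch - use_rl_after + 1) / float(mixer_descrease_every))))

        # CST: number of captions contributing to the consensus baseline
        scb_captions = min(int(np.ceil((epoch - use_cst_after + 1) / float(cst_increase_every))), seq_per_img - 1)

        print(epoch, round(ss_prob, 3), mixer_from, scb_captions)

With these defaults the sampling probability creeps up toward the ss_max_prob cap, mixer_from shrinks from near seq_length down to 1 (fully RL) by around epoch 60, and the baseline grows by one caption every cst_increase_every epochs until it reaches seq_per_img - 1.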