├── .gitignore
├── emscore
│   ├── __init__.py
│   ├── images
│   │   └── EMScore.png
│   ├── example
│   │   └── Bfyp0C02sko_000248_000258.mp4
│   ├── scorer.py
│   └── utils.py
├── LICENSE
├── demo.py
├── README.md
├── extract_video_embeddings.py
├── VATEX-EVAL-demo.py
└── ActivityNet-FOIL_demo.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
--------------------------------------------------------------------------------
/emscore/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 | from .scorer import *
--------------------------------------------------------------------------------
/emscore/images/EMScore.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiYaya/emscore/HEAD/emscore/images/EMScore.png
--------------------------------------------------------------------------------
/emscore/example/Bfyp0C02sko_000248_000258.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiYaya/emscore/HEAD/emscore/example/Bfyp0C02sko_000248_000258.mp4
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yaya Shi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from emscore import EMScorer
2 |
3 |
4 | if __name__ == '__main__':
5 |     # your video path list
6 |     vids = ['emscore/example/Bfyp0C02sko_000248_000258.mp4']
7 |     metric = EMScorer()
8 |     # the candidate caption
9 |     cands = ['A person is frying a pan of eggs that are colored pink and green.']
10 |     refs = ['Two egg yolks are placed in food flavored white yolk on a pan until they are cooked.']
11 |     # refs = ['Two egg yolks are placed in food flavored white yolk on a pan until they are cooked.']
12 |     results = metric.score(cands=cands, refs=refs, vids=vids, idf=False)
13 |
14 |     """
15 |     only use the video as ground truth
16 |     """
17 |     print('------------------------------------------------------------------------')
18 |     print('fine-grained EMScore(X,V), P: {}, R:{}, F:{}'.format(
19 |         results['EMScore(X,V)']['figr_P'], results['EMScore(X,V)']['figr_R'], results['EMScore(X,V)']['figr_F']))
20 |     print(
21 |         'coarse-grained EMScore(X,V), {}'.format(results['EMScore(X,V)']['cogr']))
22 |     print('fine- and coarse-grained EMScore(X,V), P: {}, R:{}, F:{}'.format(
23 |         results['EMScore(X,V)']['full_P'], results['EMScore(X,V)']['full_R'], results['EMScore(X,V)']['full_F']))
24 |
25 |     """
26 |     only use the references as ground truth
27 |     """
28 |     print('------------------------------------------------------------------------')
29 |     print('fine- and coarse-grained EMScore(X,X*), P: {}, R:{}, F:{}'.format(results['EMScore(X,X*)']['full_P'],
30 |         results['EMScore(X,X*)']['full_R'], results['EMScore(X,X*)']['full_F']))
31 |
32 |     """
33 |     use both the video and the references as ground truth
34 |     """
35 |     print('------------------------------------------------------------------------')
36 |     print('fine- and coarse-grained EMScore(X,V,X*), P: {}, R:{}, F:{}'.format(results['EMScore(X,V,X*)']['full_P'],
37 |         results['EMScore(X,V,X*)']['full_R'], results['EMScore(X,V,X*)']['full_F']))
38 |     print()
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Automatic Video Captioning Evaluation Metric --- EMScore
2 |
3 |
4 | ## Overview
5 |
6 | As illustrated below, EMScore is computed as:
7 |
8 | ![EMScore](./emscore/images/EMScore.png)
9 |
10 |
11 |
12 |
13 | ## Installation
14 |
15 | - Modify the `encode_text()` function in `CLIP/clip/model.py` (in your local clone of the official CLIP repository) as follows, so that token-level (local) embeddings can be returned:
16 |
17 | ```
18 | def encode_text(self, text, local=False):
19 |     x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
20 |
21 |     x = x + self.positional_embedding.type(self.dtype)
22 |     x = x.permute(1, 0, 2)  # NLD -> LND
23 |     x = self.transformer(x)
24 |     x = x.permute(1, 0, 2)  # LND -> NLD
25 |     x = self.ln_final(x).type(self.dtype)
26 |
27 |     if local:
28 |         x = x @ self.text_projection
29 |     else:
30 |         # x.shape = [batch_size, n_ctx, transformer.width]
31 |         # take features from the eot embedding (eot_token is the highest number in each sequence)
32 |         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
33 |
34 |     return x
35 | ```
36 |
37 | - Push the modified CLIP to your own GitHub repository.
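
Once the Install step below is done, a quick sanity check of the modified `encode_text()` looks like this. This is a minimal sketch, assuming your fork is importable as `clip`; the shapes shown are for the ViT-B/32 backbone:

```
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)
tokens = clip.tokenize(["A person is frying a pan of eggs."]).to(device)
with torch.no_grad():
    token_embs = model.encode_text(tokens, local=True)  # per-token embeddings, shape [1, 77, 512]
    sent_emb = model.encode_text(tokens)                # [EOS] sentence embedding, shape [1, 512]
print(token_embs.shape, sent_emb.shape)
```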
38 |
39 | - Install
40 |
41 | ```
42 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
43 | $ pip install ftfy regex tqdm
44 | $ pip install git+https://github.com/$Yours_GitHub_name/CLIP
45 | ```
46 |
47 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU.
48 |
49 |
50 |
51 | ## Usage
52 |
53 | ### A general demo
54 | ```
55 | python demo.py
56 | ```
57 |
58 |
59 | ### VATEX-EVAL
60 | - Download the files from the following link and save them to a storage directory:
61 | ```
62 | https://drive.google.com/drive/folders/1jAfZZKEgkMEYFF2x1mhYo39nH-TNeGm6?usp=sharing
63 | ```
64 |
65 | - Run the code:
66 | ```
67 | python VATEX-EVAL-demo.py --storage_path $storage_path --use_n_refs 1 --use_feat_cache --use_idf
68 | ```
69 |
70 |
71 | ### ActivityNet-FOIL
72 | - Download the files from the following link and save them to a storage directory:
73 | ```
74 | https://drive.google.com/drive/folders/1oY9EJiEi_db_1GH-R33JDqfE8txffKR3?usp=sharing
75 | ```
76 |
77 | - Run the code:
78 | ```
79 | python ActivityNet-FOIL_demo.py --storage_path $storage_path --use_references --use_idf
80 | ```
81 |
82 | ## Others
83 | If you want to extract the video embeddings yourself, run:
84 | ```
85 | python extract_video_embeddings.py --videos_path $your_video_path --save_path $your_storage_path --backbone 'ViT-B/32'
86 | ```
87 |
88 |
89 | ## Citation
90 | If you find this code useful for your research, please consider citing:
91 |
92 | ```
93 | @inproceedings{DBLP:conf/cvpr/ShiYXYLHZ22,
94 |   author    = {Yaya Shi and
95 |                Xu Yang and
96 |                Haiyang Xu and
97 |                Chunfeng Yuan and
98 |                Bing Li and
99 |                Weiming Hu and
100 |                Zheng{-}Jun Zha},
101 |   title     = {EMScore: Evaluating Video Captioning via Coarse-Grained and Fine-Grained
102 |                Embedding Matching},
103 |   booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
104 |                {CVPR} 2022, New Orleans, LA, USA, June 18-24, 2022},
105 |   year      = {2022},
106 | }
107 | ```
--------------------------------------------------------------------------------
/extract_video_embeddings.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['CUDA_VISIBLE_DEVICES'] = "3"
3 |
4 |
5 | import argparse
6 | import torch
7 | import clip
8 | from PIL import Image
9 | import json
10 | import cv2
11 | import glob
12 | import numpy as np
13 | import os
14 | import torch
15 | from tqdm import tqdm
16 | import math
17 |
18 |
19 | def encode_video(video_file, preprocess, model):
20 |     cap = cv2.VideoCapture(video_file)
21 |     frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
22 |     images = []
23 |     count = 0
24 |     ret = True
25 |
26 |     while (count < frameCount and ret):
27 |         ret, frame = cap.read()
28 |         if not ret:  # stop when no more frames can be read
29 |             break
30 |         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
31 |         images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB")))
32 |         count += 1
33 |         # print('{}/{}'.format(count, frameCount))
34 |
35 |     if len(images) == 0:
36 |         return None
37 |
38 |     image_input = torch.tensor(np.stack(images)).cuda()
39 |     image_features_list = []
40 |     bs = 256
41 |     with torch.no_grad():
42 |         n_inter = math.ceil(len(image_input)/bs)
43 |         for i in tqdm(range(n_inter), desc='encoding vid: {}'.format(video_file)):
44 |             image_features = model.encode_image(image_input[i*bs: (i+1)*bs]).float()
45 |             image_features_list.append(image_features)
46 |     image_features = torch.cat(image_features_list, dim=0)
47 |
image_features /= image_features.norm(dim=-1, keepdim=True) 48 | cap.release() 49 | 50 | return image_features 51 | 52 | 53 | def extract_dataset_videos_embeddings(preprocess, model, opt): 54 | save_dir_path = os.path.join(opt.save_path, 'clip_vid_feats') 55 | if not os.path.exists(save_dir_path): 56 | os.makedirs(save_dir_path) 57 | 58 | all_videos_path = glob.glob(opt.videos_path + '/*.mp4') 59 | for vid_path in tqdm(all_videos_path): 60 | vid = vid_path.split('/')[-1][:-4] 61 | save_vid_path = os.path.join(save_dir_path, vid+'.pt') 62 | if os.path.exists(save_vid_path): 63 | print('vid:{} done'.format(vid)) 64 | continue 65 | frames_feature = encode_video(vid_path, preprocess, model) 66 | 67 | if frames_feature == None: 68 | continue 69 | else: 70 | frames_feature = frames_feature.cpu().data 71 | 72 | torch.save(frames_feature, save_vid_path) 73 | # print(vid) 74 | 75 | 76 | 77 | if __name__ == "__main__": 78 | parse = argparse.ArgumentParser() 79 | parse.add_argument('--videos_path', type=str, default='') 80 | parse.add_argument('--save_path', type=str, default='', 81 | help='the path to save reformat files') 82 | parse.add_argument('--backbone', type=str, default='RN50') 83 | opt = parse.parse_args() 84 | 85 | device = "cuda" if torch.cuda.is_available() else "cpu" 86 | # backbone = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/16'] 87 | if 'ViT-B/16' == opt.backbone: 88 | opt.save_path = os.path.join(opt.save_path, 'ViT-B-16') 89 | elif 'ViT-B/32' == opt.backbone: 90 | opt.save_path = os.path.join(opt.save_path, 'ViT-B-32') 91 | else: 92 | opt.save_path = os.path.join(opt.save_path, opt.backbone) 93 | if not os.path.exists(opt.save_path): 94 | os.mkdir(opt.save_path) 95 | model, preprocess = clip.load(opt.backbone, device=device) 96 | 97 | extract_dataset_videos_embeddings(preprocess, model, opt) 98 | -------------------------------------------------------------------------------- /emscore/scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | import json 5 | import cv2 6 | import numpy as np 7 | from tqdm import tqdm 8 | import math 9 | import time 10 | from collections import defaultdict 11 | 12 | 13 | from .utils import em_cos_score, get_idf_dict 14 | 15 | class EMScorer: 16 | """ 17 | EMScore Scorer Object. 18 | """ 19 | 20 | def __init__(self, vid_feat_cache=None, device=None,): 21 | 22 | self.vid_feat_cache = vid_feat_cache 23 | 24 | if device is None: 25 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 26 | else: 27 | self.device = device 28 | 29 | def score(self, cands, refs, vids=None, verbose=True, batch_size=64, nthreads=4, idf=True, return_matched_idx=False): 30 | """ 31 | Args: 32 | - :param: `cands` (list of str): candidate sentences 33 | - :param: `refs` (list of list of str): reference sentences 34 | 35 | Return: 36 | - :param: `(P, R, F)`: each is of shape (N); N = number of input 37 | candidate reference pairs. if returning hashcode, the 38 | output will be ((P, R, F), hashcode). If a candidate have 39 | multiple references, the returned score of this candidate is 40 | the *best* score among all references. 
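
        Example (a minimal sketch; the captions are placeholders and the video
        path points to the clip shipped with this repo):

            >>> metric = EMScorer()
            >>> results = metric.score(
            ...     cands=['a person is frying eggs in a pan'],
            ...     refs=[['two eggs are cooked in a frying pan']],
            ...     vids=['emscore/example/Bfyp0C02sko_000248_000258.mp4'],
            ...     idf=False)
            >>> sorted(results.keys())
            ['EMScore(X,V)', 'EMScore(X,V,X*)', 'EMScore(X,X*)']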
41 | """ 42 | 43 | model, preprocess = clip.load("ViT-B/32", device=self.device) 44 | self._model = model 45 | self._tokenizer = clip.tokenize 46 | self._image_preprocess = preprocess 47 | 48 | ref_group_boundaries = None 49 | ori_cands, ori_refs = cands, refs 50 | # if reference are avaliable, and there are multiple references for each candidata caption 51 | if refs and not isinstance(refs[0], str): 52 | ref_group_boundaries = [] 53 | cands, refs = [], [] 54 | count = 0 55 | for cand, ref_group in zip(ori_cands, ori_refs): 56 | cands += [cand] * len(ref_group) 57 | refs += ref_group 58 | ref_group_boundaries.append((count, count + len(ref_group))) 59 | count += len(ref_group) 60 | 61 | if not idf: 62 | idf_dict = defaultdict(lambda: 1.0) 63 | elif isinstance(idf, dict): 64 | if verbose: 65 | print("using predefined IDF dict...") 66 | idf_dict = idf 67 | else: 68 | if verbose: 69 | print("preparing IDF dict...") 70 | start = time.perf_counter() 71 | idf_corpus = refs if refs else cands 72 | idf_dict = get_idf_dict(idf_corpus, self._tokenizer, nthreads=nthreads) 73 | # max token_id are eos token id 74 | # set idf of eos token are mean idf value 75 | idf_dict[max(list(idf_dict.keys()))] = sum(list(idf_dict.values()))/len(list(idf_dict.values())) 76 | if verbose: 77 | print("done in {:.2f} seconds".format(time.perf_counter() - start)) 78 | 79 | 80 | if verbose: 81 | print("calculating EMScore scores...") 82 | time_start = time.perf_counter() 83 | 84 | results = em_cos_score( 85 | self._model, 86 | refs, 87 | cands, 88 | ori_cands, 89 | ori_refs, 90 | vids, 91 | self.vid_feat_cache, 92 | self._tokenizer, 93 | idf_dict, 94 | self._image_preprocess, 95 | verbose=verbose, 96 | device=self.device, 97 | batch_size=batch_size, 98 | return_matched_idx=return_matched_idx 99 | ) 100 | 101 | final_results = {} 102 | if refs: 103 | refs_all_local_preds = results['refs_result']['figr'] 104 | refs_all_global_preds = results['refs_result']['cogr'] 105 | if ref_group_boundaries is not None: 106 | max_preds_local = [] 107 | for start, end in ref_group_boundaries: 108 | max_preds_local.append(refs_all_local_preds[start:end].max(dim=0)[0]) 109 | refs_all_local_preds = torch.stack(max_preds_local, dim=0) 110 | 111 | max_preds_global = [] 112 | for start, end in ref_group_boundaries: 113 | max_preds_global.append(refs_all_global_preds[start:end].max()) 114 | refs_all_global_preds = torch.stack(max_preds_global, dim=0) 115 | 116 | refs_P, refs_R, refs_F = refs_all_local_preds[..., 0], refs_all_local_preds[..., 1], refs_all_local_preds[..., 2] # P, R, F 117 | 118 | refs_results = {} 119 | refs_results['figr_P'] = refs_P 120 | refs_results['figr_R'] = refs_R 121 | refs_results['figr_F'] = refs_F 122 | refs_results['cogr'] = refs_all_global_preds 123 | refs_results['full_P'] = (refs_results['figr_P'] + refs_results['cogr'])/2 124 | refs_results['full_R'] = (refs_results['figr_R'] + refs_results['cogr'])/2 125 | refs_results['full_F'] = (refs_results['figr_F'] + refs_results['cogr'])/2 126 | # refs_results['refs_matched_indices'] = results['refs_result']['matched_indices'] 127 | final_results['EMScore(X,X*)'] = refs_results 128 | 129 | if vids: 130 | vid_all_local_preds = results['vid_result']['figr'] 131 | vid_all_global_preds = results['vid_result']['cogr'] 132 | vid_P, vid_R, vid_F = vid_all_local_preds[..., 0], vid_all_local_preds[..., 1], vid_all_local_preds[..., 2] # P, R, F 133 | 134 | vid_results = {} 135 | vid_results['figr_P'] = vid_P 136 | vid_results['figr_R'] = vid_R 137 | vid_results['figr_F'] = vid_F 
138 | vid_results['cogr'] = vid_all_global_preds 139 | vid_results['full_P'] = (vid_results['figr_P'] + vid_results['cogr'])/2 140 | vid_results['full_R'] = (vid_results['figr_R'] + vid_results['cogr'])/2 141 | vid_results['full_F'] = (vid_results['figr_F'] + vid_results['cogr'])/2 142 | # vid_results['vid_matched_indices'] = results['vid_result']['matched_indices'] 143 | final_results['EMScore(X,V)'] = vid_results 144 | 145 | if refs and vids: 146 | vid_refs_result = {} 147 | vid_refs_result['figr_P'] = (final_results['EMScore(X,V)']['figr_P'] + final_results['EMScore(X,X*)']['figr_P'])/2 148 | vid_refs_result['figr_R'] = (final_results['EMScore(X,V)']['figr_R'] + final_results['EMScore(X,X*)']['figr_R'])/2 149 | vid_refs_result['figr_F'] = (final_results['EMScore(X,V)']['figr_F'] + final_results['EMScore(X,X*)']['figr_F'])/2 150 | vid_refs_result['cogr'] = (final_results['EMScore(X,V)']['cogr'] + final_results['EMScore(X,X*)']['cogr'])/2 151 | vid_refs_result['full_P'] = (vid_refs_result['figr_P'] + vid_refs_result['cogr'])/2 152 | vid_refs_result['full_R'] = (vid_refs_result['figr_R'] + vid_refs_result['cogr'])/2 153 | vid_refs_result['full_F'] = (vid_refs_result['figr_F'] + vid_refs_result['cogr'])/2 154 | final_results['EMScore(X,V,X*)'] = vid_refs_result 155 | 156 | 157 | if verbose: 158 | time_diff = time.perf_counter() - time_start 159 | print(f"done in {time_diff:.2f} seconds, {len(cands) / time_diff:.2f} sentences/sec") 160 | 161 | return final_results 162 | 163 | 164 | -------------------------------------------------------------------------------- /VATEX-EVAL-demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | import numpy as np 5 | import json 6 | import glob 7 | import torch 8 | import math 9 | from tqdm import tqdm 10 | from emscore import EMScorer 11 | from emscore.utils import get_idf_dict, compute_correlation_uniquehuman 12 | import clip 13 | 14 | def get_feats_dict(feat_dir_path): 15 | print('loding cache feats ........') 16 | file_path_list = glob.glob(feat_dir_path+'/*.pt') 17 | feats_dict = {} 18 | for file_path in tqdm(file_path_list): 19 | vid = file_path.split('/')[-1][:-3] 20 | data = torch.load(file_path) 21 | feats_dict[vid] = data 22 | return feats_dict 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--storage_path', default='', type=str, help='The path you storage VATEX-EVAL dataset') 28 | parser.add_argument('--vid_base_path', default='', type=str, help='The path you storage VATEX-EVAL videos (optinal, if you use prepared video feats, You do not need to consider this)') 29 | parser.add_argument('--use_n_refs', default=1, type=int, help='How many references do you want to use for evaluation (1~9)') 30 | parser.add_argument('--use_feat_cache', default=True, action='store_true', help='Whether to use pre-prepared video features') 31 | parser.add_argument('--use_idf', action='store_true', default=True) 32 | 33 | opt = parser.parse_args() 34 | 35 | """ 36 | Dataset prepare 37 | """ 38 | samples_list = pickle.load(open(os.path.join(opt.storage_path, 'candidates_list.pkl'), 'rb')) 39 | gts_list = pickle.load(open(os.path.join(opt.storage_path, 'gts_list.pkl'), 'rb')) 40 | all_human_scores = pickle.load(open(os.path.join(opt.storage_path, 'human_scores.pkl'), 'rb')) 41 | all_human_scores = np.transpose(all_human_scores.reshape(3, -1), (1, 0)) 42 | video_ids = pickle.load(open(os.path.join(opt.storage_path, 
'video_ids.pkl'), 'rb')) 43 | vid_base_path = 'your path to save vatex val videos' # optional 44 | cands = samples_list.tolist() 45 | refs = gts_list.tolist() 46 | 47 | 48 | """ 49 | Video feats prepare 50 | """ 51 | use_uniform_sample = 10 52 | 53 | if not opt.use_feat_cache: 54 | vids = [vid_base_path+vid+'.mp4' for vid in video_ids] 55 | metric = EMScorer(vid_feat_cache=[]) 56 | else: 57 | vid_clip_feats_dir = os.path.join(opt.storage_path, 'VATEX-EVAL_video_feats') 58 | video_clip_feats_dict = get_feats_dict(vid_clip_feats_dir) 59 | if use_uniform_sample: 60 | for vid in video_clip_feats_dict: 61 | data = video_clip_feats_dict[vid] 62 | select_index = np.linspace(0, len(data)-1, use_uniform_sample) 63 | select_index = [int(index) for index in select_index] 64 | video_clip_feats_dict[vid] = data[select_index] 65 | 66 | vids = video_ids.tolist() 67 | metric = EMScorer(vid_feat_cache=video_clip_feats_dict) 68 | 69 | 70 | """ 71 | Prepare IDF 72 | """ 73 | if opt.use_idf: 74 | vatex_train_corpus_path = os.path.join(opt.storage_path, 'vatex_train_en_annotations.json') 75 | vatex_train_corpus = json.load(open(vatex_train_corpus_path)) 76 | vatex_train_corpus_list = [] 77 | for vid in vatex_train_corpus: 78 | vatex_train_corpus_list.extend(vatex_train_corpus[vid]) 79 | 80 | emscore_idf_dict = get_idf_dict(vatex_train_corpus_list, clip.tokenize, nthreads=4) 81 | # max token_id are eos token id 82 | # set idf of eos token are mean idf value 83 | emscore_idf_dict[max(list(emscore_idf_dict.keys()))] = sum(list(emscore_idf_dict.values()))/len(list(emscore_idf_dict.values())) 84 | else: 85 | emscore_idf_dict = False 86 | 87 | 88 | """ 89 | Metric calculate 90 | """ 91 | refs = np.array(refs)[:, :opt.use_n_refs].tolist() 92 | # results = metric.score(cands, refs, vids=vids) 93 | results = metric.score(cands, refs=refs, vids=vids, idf=emscore_idf_dict) 94 | 95 | 96 | if 'EMScore(X,V)' in results: 97 | print('EMScore(X,V) correlation --------------------------------------') 98 | # vid_figr_res_P = results['EMScore(X,V)']['figr_P'] 99 | # vid_figr_res_R = results['EMScore(X,V)']['figr_R'] 100 | # vid_figr_res_F = results['EMScore(X,V)']['figr_F'] 101 | # vid_cogr_res = results['EMScore(X,V)']['cogr'] 102 | # vid_full_res_P = results['EMScore(X,V)']['full_P'] 103 | # vid_full_res_R = results['EMScore(X,V)']['full_R'] 104 | vid_full_res_F = results['EMScore(X,V)']['full_F'] 105 | # compute_correlation_uniquehuman(vid_figr_res_P.numpy(), all_human_scores) 106 | # compute_correlation_uniquehuman(vid_figr_res_R.numpy(), all_human_scores) 107 | # compute_correlation_uniquehuman(vid_figr_res_F.numpy(), all_human_scores) 108 | # compute_correlation_uniquehuman(vid_cogr_res.numpy(), all_human_scores) 109 | # compute_correlation_uniquehuman(vid_full_res_P.numpy(), all_human_scores) 110 | # compute_correlation_uniquehuman(vid_full_res_R.numpy(), all_human_scores) 111 | compute_correlation_uniquehuman(vid_full_res_F.numpy(), all_human_scores) 112 | 113 | 114 | if 'EMScore(X,X*)' in results: 115 | print('EMScore(X,X*) correlation --------------------------------------') 116 | 117 | # refs_figr_res_P = results['EMScore(X,X*)']['figr_P'] 118 | # refs_figr_res_R = results['EMScore(X,X*)']['figr_R'] 119 | # refs_figr_res_F = results['EMScore(X,X*)']['figr_F'] 120 | # refs_cogr_res = results['EMScore(X,X*)']['cogr'] 121 | # refs_full_res_P = results['EMScore(X,X*)']['full_P'] 122 | # refs_full_res_R = results['EMScore(X,X*)']['full_R'] 123 | refs_full_res_F = results['EMScore(X,X*)']['full_F'] 124 | # 
compute_correlation_uniquehuman(refs_figr_res_P.numpy(), all_human_scores) 125 | # compute_correlation_uniquehuman(refs_figr_res_R.numpy(), all_human_scores) 126 | # compute_correlation_uniquehuman(refs_figr_res_F.numpy(), all_human_scores) 127 | # compute_correlation_uniquehuman(refs_cogr_res.numpy(), all_human_scores) 128 | # compute_correlation_uniquehuman(refs_full_res_P.numpy(), all_human_scores) 129 | # compute_correlation_uniquehuman(refs_full_res_R.numpy(), all_human_scores) 130 | compute_correlation_uniquehuman(refs_full_res_F.numpy(), all_human_scores) 131 | 132 | 133 | if 'EMScore(X,V,X*)' in results: 134 | print('EMScore(X,V,X*) correlation --------------------------------------') 135 | # vid_refs_figr_res_P = results['EMScore(X,V,X*)']['figr_P'] 136 | # vid_refs_figr_res_R = results['EMScore(X,V,X*)']['figr_R'] 137 | # vid_refs_figr_res_F = results['EMScore(X,V,X*)']['figr_F'] 138 | # vid_refs_cogr_res = results['EMScore(X,V,X*)']['cogr'] 139 | # vid_refs_full_res_P = results['EMScore(X,V,X*)']['full_P'] 140 | # vid_refs_full_res_R = results['EMScore(X,V,X*)']['full_R'] 141 | vid_refs_full_res_F = results['EMScore(X,V,X*)']['full_F'] 142 | # compute_correlation_uniquehuman(vid_refs_figr_res_P.numpy(), all_human_scores) 143 | # compute_correlation_uniquehuman(vid_refs_figr_res_R.numpy(), all_human_scores) 144 | # compute_correlation_uniquehuman(vid_refs_figr_res_F.numpy(), all_human_scores) 145 | # compute_correlation_uniquehuman(vid_refs_cogr_res.numpy(), all_human_scores) 146 | # compute_correlation_uniquehuman(vid_refs_full_res_P.numpy(), all_human_scores) 147 | # compute_correlation_uniquehuman(vid_refs_full_res_R.numpy(), all_human_scores) 148 | compute_correlation_uniquehuman(vid_refs_full_res_F.numpy(), all_human_scores) 149 | -------------------------------------------------------------------------------- /ActivityNet-FOIL_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | 4 | import clip 5 | import torch 6 | import glob 7 | from tqdm import tqdm 8 | import numpy as np 9 | import json 10 | from collections import defaultdict, Counter 11 | from multiprocessing import Pool 12 | from functools import partial 13 | from itertools import chain 14 | from math import log 15 | import argparse 16 | from emscore import EMScorer 17 | 18 | 19 | def process(a, tokenizer=None): 20 | if tokenizer is not None: 21 | a = tokenizer(a)[0].tolist() 22 | return set(a) 23 | 24 | 25 | def get_idf_dict(arr, tokenizer, nthreads=4): 26 | """ 27 | Returns mapping from word piece index to its inverse document frequency. 28 | 29 | 30 | Args: 31 | - :param: `arr` (list of str) : sentences to process. 32 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 
33 | - :param: `nthreads` (int) : number of CPU threads to use 34 | """ 35 | idf_count = Counter() 36 | num_docs = len(arr) 37 | 38 | process_partial = partial(process, tokenizer=tokenizer) 39 | 40 | with Pool(nthreads) as p: 41 | idf_count.update(chain.from_iterable(p.map(process_partial, arr))) 42 | 43 | idf_dict = defaultdict(lambda: log((num_docs + 1) / (1))) 44 | idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()}) 45 | return idf_dict 46 | 47 | 48 | class EMScore_ANET_FOIL(object): 49 | 50 | def __init__(self, args, prediction_filename=None, 51 | idf=True, verbose=False,): 52 | # For clip 53 | self.storage_path = args.storage_path 54 | self.verbose = verbose 55 | self.args = args 56 | 57 | self.vid_duration_dict = self.import_vid_duration() 58 | self.clip_prediction = self.import_emscore_prediction(prediction_filename) 59 | if idf: 60 | self.emscore_idf_dict = self.compute_emscore_idf() 61 | else: 62 | self.emscore_idf_dict = False 63 | self.vid_clip_feats = self.import_clip_vid_feats() 64 | self.cands_timestamp = self.get_cands_timestamp() 65 | self.refs_for_eval = self.import_refs_for_eval() 66 | 67 | 68 | def get_gt_vid_ids(self): 69 | pred_vid_ids = set(list(self.clip_prediction.keys())) 70 | exist_videos = [item.split('.')[0] for item in os.listdir(self.args.anet_vid_clip_feats_path)] 71 | gt_exist_videos = set(pred_vid_ids).intersection(set(exist_videos)) 72 | return list(gt_exist_videos) 73 | 74 | def import_clip_vid_feats(self): 75 | def get_feats_dict(feat_dir_path, gt_vid_ids): 76 | file_path_list = glob.glob(feat_dir_path+'/*.pt') 77 | feats_dict = {} 78 | for file_path in file_path_list: 79 | vid = file_path.split('/')[-1][:-3] 80 | if vid not in gt_vid_ids: 81 | continue 82 | data = torch.load(file_path) 83 | feats_dict[vid] = data 84 | return feats_dict 85 | gt_vid_ids = self.get_gt_vid_ids() 86 | vid_feat_dict = get_feats_dict(self.args.anet_vid_clip_feats_path, gt_vid_ids) 87 | assert len(vid_feat_dict.keys()) == len(gt_vid_ids) 88 | return vid_feat_dict 89 | 90 | def import_vid_duration(self): 91 | # the duration for each video 92 | filenames = (os.path.join(self.storage_path, 'anet_entities_test_1.json'), 93 | os.path.join(self.storage_path, 'anet_entities_test_2.json')) 94 | vid_duration_dict = {} 95 | for filename in filenames: 96 | gt = json.load(open(filename)) 97 | for vid in gt: 98 | vid_duration_dict[vid] = gt[vid]['duration'] 99 | 100 | return vid_duration_dict 101 | 102 | def get_cands_timestamp(self): 103 | ref_filename = os.path.join(self.storage_path, 'anet_entities_test_1.json') 104 | clip_refs = json.load(open(ref_filename)) 105 | return clip_refs 106 | 107 | def import_refs_for_eval(self): 108 | ref_filename = os.path.join(self.storage_path, 'anet_entities_test_2.json') 109 | refs_for_eval = json.load(open(ref_filename)) 110 | return refs_for_eval 111 | 112 | def compute_emscore_idf(self): 113 | print('compute emscore idf ..................') 114 | data = json.load(open(self.args.idf_corpus)) 115 | train_corpus = [] 116 | for vid in data: 117 | sents = data[vid]['sentences'] 118 | new_sents = [] 119 | for sent in sents: 120 | if len(sent.split(' ')) > 66: # Filter out too long sentences 121 | continue 122 | else: 123 | new_sents.append(sent) 124 | train_corpus.extend(new_sents) 125 | 126 | idf_dict = get_idf_dict(train_corpus, clip.tokenize, nthreads=4) 127 | idf_dict[max(list(idf_dict.keys()))] = sum(list(idf_dict.values()))/len(list(idf_dict.values())) 128 | return idf_dict 129 | 130 | def 
import_emscore_prediction(self, prediction_filename):
131 |         pred = json.load(open(prediction_filename))
132 |         return pred['results']
133 |
134 |     def use_ref_timestamps(self, timestamp, ref_timestamps):
135 |         use_ref_idxs = []
136 |         # use a reference segment when its overlap with the candidate timestamp covers more than 40% of the reference segment
137 |         a = int(timestamp[0])
138 |         b = int(timestamp[1])
139 |         for ref_i, ref_time in enumerate(ref_timestamps):
140 |             c = int(ref_time[0])
141 |             d = int(ref_time[1])
142 |             result = list(set(range(a, b+1)) &
143 |                           set(range(c, d+1)))
144 |             if len(result)/len(range(c, d+1)) > 0.4:
145 |                 use_ref_idxs.append(ref_i)
146 |         return use_ref_idxs
147 |
148 |
149 |     def evaluate(self):
150 |         gt_vid_ids = self.get_gt_vid_ids()
151 |         self.filter_clip_prediction = {
152 |             vid: self.clip_prediction[vid] for vid in gt_vid_ids}
153 |
154 |         Use_Ref = self.args.use_references
155 |         device = "cuda" if torch.cuda.is_available() else "cpu"
156 |         model, preprocess = clip.load("ViT-B/32", device=device)
157 |
158 |         vid_emscore_with_idf = {}
159 |         count = 0
160 |         vid_group_boundaries = []
161 |         cands_list = []
162 |         refs_list = []
163 |         seg_feat_dict = {}
164 |         seg_list = []
165 |         for vid in tqdm(gt_vid_ids, desc='Computing EMScore'):
166 |             vid_pred = self.filter_clip_prediction[vid]
167 |
168 |             for sent_i, sent_segment in enumerate(vid_pred):
169 |                 cand_sent = sent_segment['sentence']  # the predicted caption
170 |                 cands_list.append(cand_sent)
171 |                 # timestamp = sent_segment['timestamp']
172 |                 ref_idx = sent_i
173 |                 timestamp = self.cands_timestamp[vid]['timestamps'][ref_idx]
174 |                 if Use_Ref:
175 |                     ref = self.refs_for_eval[vid]['sentences']
176 |                     ref_timestamps = self.refs_for_eval[vid]['timestamps']
177 |                     use_ref_idx = self.use_ref_timestamps(timestamp, ref_timestamps)
178 |                     if not use_ref_idx:
179 |                         use_ref_idx = list(
180 |                             range(len(ref_timestamps)))
181 |                     refs_sent = [ref[idx] for idx in use_ref_idx]
182 |                     refs_list.append(refs_sent)
183 |
184 |                 """
185 |                 use the corresponding video segment as the reference
186 |                 """
187 |                 vid_feats = self.vid_clip_feats[vid]
188 |                 vid_frames_len = len(vid_feats)
189 |                 duration = self.vid_duration_dict[vid]
190 |                 start = timestamp[0]*vid_frames_len//duration
191 |                 end = timestamp[1]*vid_frames_len//duration
192 |                 fg_vid_segment_feat = vid_feats[int(start):int(end)]
193 |
194 |                 seg_feat_dict['{}_seg_{}'.format(vid, sent_i)] = fg_vid_segment_feat
195 |                 seg_list.append('{}_seg_{}'.format(vid, sent_i))
196 |
197 |             vid_group_boundaries.append((count, count + len(vid_pred)))
198 |             count += len(vid_pred)
199 |
200 |         emscore_metric = EMScorer(vid_feat_cache=seg_feat_dict)
201 |         vid_emscore_with_idf = emscore_metric.score(cands=cands_list, refs=refs_list, vids=seg_list, idf=self.emscore_idf_dict)
202 |         for key in vid_emscore_with_idf:
203 |             for item in vid_emscore_with_idf[key]:
204 |                 scores = []
205 |                 for beg, end in vid_group_boundaries:
206 |                     scores.append(float(torch.mean(vid_emscore_with_idf[key][item][beg: end])))
207 |                 vid_emscore_with_idf[key][item] = scores
208 |
209 |         final_vid_emscore = {}
210 |         for key in vid_emscore_with_idf:
211 |             for item in vid_emscore_with_idf[key]:
212 |                 final_vid_emscore[key + '_' + item] = vid_emscore_with_idf[key][item]
213 |         return final_vid_emscore
214 |
215 |
216 | def main(args):
217 |
218 |     evaluator = EMScore_ANET_FOIL(args, prediction_filename=args.submission_right, idf=args.use_idf, verbose=args.verbose)
219 |     right_vid_emscores = evaluator.evaluate()
220 |
221 |     evaluator = EMScore_ANET_FOIL(args, prediction_filename=args.submission_foil, idf=args.use_idf,
verbose=args.verbose) 222 | foil_vid_emscores = evaluator.evaluate() 223 | 224 | for key in right_vid_emscores: 225 | res = np.array(right_vid_emscores[key]) > np.array(foil_vid_emscores[key]) 226 | res_sum = np.sum(res) 227 | print(key, res_sum, '{:.2f}'.format(100*res_sum/len(res))) 228 | 229 | if __name__ == '__main__': 230 | parser = argparse.ArgumentParser() 231 | parser.add_argument('--storage_path', type=str, default='', 232 | help='the path you storage ActivityNet-FOIL dataset.') 233 | parser.add_argument('--verbose', default=True, 234 | help='Print intermediate steps.') 235 | parser.add_argument('--use_references', action='store_true', default=True) 236 | parser.add_argument('--use_idf', action='store_true', default=True) 237 | 238 | args = parser.parse_args() 239 | 240 | args.submission_right = os.path.join(args.storage_path, 'final_right_video_sentences.json') 241 | args.submission_foil = os.path.join(args.storage_path, 'final_foil_video_sentences.json') 242 | args.idf_corpus = os.path.join(args.storage_path, 'train.json') 243 | args.anet_vid_clip_feats_path = os.path.join(args.storage_path, 'ActivityNet-FOIL_video_feats') 244 | 245 | main(args) -------------------------------------------------------------------------------- /emscore/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | import json 5 | import cv2 6 | import numpy as np 7 | from tqdm import tqdm 8 | import math 9 | from math import log 10 | from torch.nn.utils.rnn import pad_sequence 11 | import sys 12 | import time 13 | import os 14 | from collections import defaultdict, Counter 15 | from multiprocessing import Pool 16 | from functools import partial 17 | from itertools import chain 18 | 19 | def compute_correlation_uniquehuman(pred, all_human_scores): 20 | num_workers = 3 21 | import scipy.stats 22 | 23 | pred = np.around(pred, decimals=4) 24 | 25 | spearman = 0 26 | for worker_i in range(num_workers): 27 | tmp, p_value = scipy.stats.spearmanr(pred, all_human_scores[:, worker_i]) 28 | assert p_value < 0.01 29 | spearman += tmp 30 | spearman /= num_workers 31 | spearman = np.around(spearman, decimals=4) 32 | 33 | kendalltau = 0 34 | for worker_i in range(num_workers): 35 | tmp, p_value = scipy.stats.kendalltau(pred, all_human_scores[:, worker_i]) 36 | assert p_value < 0.01 37 | kendalltau += tmp 38 | kendalltau /= num_workers 39 | kendalltau = np.around(kendalltau, decimals=4) 40 | 41 | print('kendall: {}, spear: {}'.format(kendalltau, spearman)) 42 | return kendalltau, spearman 43 | 44 | def normalize_matrix(A): 45 | assert len(A.shape) == 2 46 | A_norm = torch.linalg.norm(A, dim=-1, keepdim=True) 47 | return A / A_norm 48 | 49 | def encode_video(video_file, preprocess, model, batch_size, device): 50 | 51 | cv_start_time = time.perf_counter() 52 | cap = cv2.VideoCapture(video_file) 53 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 54 | images = [] 55 | count = 0 56 | ret = True 57 | 58 | while (count < frameCount and ret): 59 | ret, frame = cap.read() 60 | if not ret: # if file is empty break loop 61 | break 62 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 63 | images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) 64 | count += 1 65 | 66 | cv_end_time = time.perf_counter() 67 | time_diff = cv_end_time-cv_start_time 68 | # print(f"cv done in {time_diff:.2f} seconds") 69 | 70 | 71 | image_embed_start_time = time.perf_counter() 72 | image_input = torch.tensor(np.stack(images)).to(device) 
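    # encode the frames with CLIP's image encoder in mini-batches (to bound GPU memory),
    # then L2-normalize each frame feature so the dot products used later equal cosine similarities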
73 | image_features_list = [] 74 | # bs = 256 75 | with torch.no_grad(): 76 | n_inter = math.ceil(len(image_input)/batch_size) 77 | for i in range(n_inter): 78 | image_features = model.encode_image(image_input[i*batch_size: (i+1)*batch_size]).float() 79 | image_features_list.append(image_features) 80 | image_features = torch.cat(image_features_list, dim=0) 81 | image_features /= image_features.norm(dim=-1, keepdim=True) 82 | cap.release() 83 | 84 | vid_feature = normalize_matrix(torch.mean(image_features, dim=0, keepdim=True)).squeeze() 85 | 86 | image_embed_end_time = time.perf_counter() 87 | time_diff = image_embed_end_time - image_embed_start_time 88 | # print(f"image embed done in {time_diff:.2f} seconds") 89 | 90 | return image_features, vid_feature 91 | 92 | def encode_text(vid_caps, model, tokenizer, idf_dict, device): 93 | text_input = tokenizer(vid_caps).to(device=device) 94 | with torch.no_grad(): 95 | text_features = model.encode_text(text_input, local=True).float() 96 | text_features /= text_features.norm(dim=-1, keepdim=True) 97 | 98 | # For special tokens, use [SOS] and [EOS] 99 | txt_len = text_input.argmax(dim=-1) 100 | mask = torch.zeros_like(text_input) 101 | for i in range(len(mask)): 102 | mask[i][0:txt_len[i]+1] = 1 103 | 104 | # For special tokens, only use [EOS] 105 | # txt_len = text_input.argmax(dim=-1) 106 | # mask = torch.zeros_like(text_input) 107 | # for i in range(len(mask)): 108 | # mask[i][1:txt_len[i]+1] = 1 109 | 110 | # # For special tokens, don't use [SOS] and [EOS] 111 | # txt_len = text_input.argmax(dim=-1) 112 | # mask = torch.zeros_like(text_input) 113 | # for i in range(len(mask)): 114 | # mask[i][1:txt_len[i]] = 1 115 | 116 | idf_weights = torch.tensor([[idf_dict[int(i)] for i in a] for a in text_input.cpu()]) 117 | 118 | return text_features, mask, idf_weights 119 | 120 | 121 | def process(a, tokenizer=None): 122 | if tokenizer is not None: 123 | a = tokenizer(a)[0].tolist() 124 | return set(a) 125 | 126 | 127 | def get_idf_dict(arr, tokenizer, nthreads=4): 128 | """ 129 | Returns mapping from word piece index to its inverse document frequency. 130 | 131 | 132 | Args: 133 | - :param: `arr` (list of str) : sentences to process. 134 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 135 | - :param: `nthreads` (int) : number of CPU threads to use 136 | """ 137 | idf_count = Counter() 138 | num_docs = len(arr) 139 | 140 | process_partial = partial(process, tokenizer=tokenizer) 141 | 142 | with Pool(nthreads) as p: 143 | idf_count.update(chain.from_iterable(p.map(process_partial, arr))) 144 | 145 | idf_dict = defaultdict(lambda: log((num_docs + 1) / (1))) 146 | idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()}) 147 | return idf_dict 148 | 149 | 150 | def refs_greedy_cos(ref_embedding, ref_masks, ref_idf, hyp_embedding, hyp_masks, hyp_idf, return_matched_idx): 151 | """ 152 | Compute greedy matching based on cosine similarity. 153 | 154 | Args: 155 | - :param: `ref_embedding` (torch.Tensor): 156 | embeddings of reference sentences, BxKxd, 157 | B: batch size, K: longest length, d: bert dimenison. 158 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for 159 | reference sentences. 160 | - :param: `hyp_embedding` (torch.Tensor): 161 | embeddings of candidate sentences, BxKxd, 162 | B: batch size, K: longest length, d: bert dimenison 163 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for 164 | candidate sentences. 
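    Sketch of the scores computed below (BERTScore-style greedy matching, with
    IDF weights normalized to sum to 1; x = candidate tokens, r = reference tokens):

        P = sum_i idf(x_i) * max_j cos(x_i, r_j)
        R = sum_j idf(r_j) * max_i cos(x_i, r_j)
        F = 2 * P * R / (P + R)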
165 | """ 166 | # ref_embedding and hyp_embedding are aleady L2-normalized. 167 | 168 | batch_size = ref_embedding.size(0) 169 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 170 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float()) 171 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 172 | masks = masks.float().to(sim.device) 173 | sim = sim * masks 174 | 175 | word_precision, matched_indices = sim.max(dim=2) 176 | word_recall = sim.max(dim=1)[0] 177 | 178 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True)) 179 | ref_idf.div_(ref_idf.sum(dim=1, keepdim=True)) 180 | precision_scale = hyp_idf.to(word_precision.device) 181 | recall_scale = ref_idf.to(word_recall.device) 182 | 183 | P = (word_precision * precision_scale).sum(dim=1) 184 | R = (word_recall * recall_scale).sum(dim=1) 185 | F = 2 * P * R / (P + R) 186 | 187 | if return_matched_idx: 188 | return P, R, F, matched_indices 189 | else: 190 | return P, R, F, torch.zeros_like(P) 191 | 192 | def vid_greedy_cos(ref_embedding, ref_masks, hyp_embedding, hyp_masks, hyp_idf, return_matched_idx): 193 | """ 194 | Compute greedy matching based on cosine similarity. 195 | 196 | Args: 197 | - :param: `ref_embedding` (torch.Tensor): 198 | embeddings of reference sentences, BxKxd, 199 | B: batch size, K: longest length, d: bert dimenison. 200 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for 201 | reference sentences. 202 | - :param: `hyp_embedding` (torch.Tensor): 203 | embeddings of candidate sentences, BxKxd, 204 | B: batch size, K: longest length, d: bert dimenison 205 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for 206 | candidate sentences. 207 | """ 208 | # ref_embedding and hyp_embedding are aleady L2-normalized. 209 | 210 | batch_size = ref_embedding.size(0) 211 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 212 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float()) 213 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 214 | masks = masks.float().to(sim.device) 215 | sim = sim * masks 216 | 217 | word_precision, matched_indices = sim.max(dim=2) 218 | word_recall = sim.max(dim=1)[0] 219 | 220 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True)) 221 | precision_scale = hyp_idf.to(word_precision.device) 222 | P = (word_precision * precision_scale).sum(dim=1) 223 | R = word_recall.sum(dim=1)/ref_masks.sum(dim=1) 224 | F = 2 * P * R / (P + R) 225 | 226 | if return_matched_idx: 227 | return P, R, F, matched_indices 228 | else: 229 | return P, R, F, torch.zeros_like(P) 230 | 231 | 232 | 233 | def em_cos_score( 234 | model, refs, hyps, ori_cands, ori_refs, vids, vid_feat_cache, tokenizer, idf_dict, preprocess, verbose=True, batch_size=64, device="cuda:0", return_matched_idx=False 235 | ): 236 | """ 237 | Compute EMScore. 238 | 239 | Args: 240 | - :param: `model` : a BERT model in `pytorch_pretrained_bert` 241 | - :param: `refs` (list of str): reference sentences 242 | - :param: `hyps` (list of str): candidate sentences 243 | - :param: `tokenzier` : a BERT tokenizer corresponds to `model` 244 | - :param: `verbose` (bool): turn on intermediate status update 245 | - :param: `batch_size` (int): bert score processing batch size 246 | - :param: `device` (str): device to use, e.g. 
'cpu' or 'cuda' 247 | """ 248 | 249 | refs_preds_local = [] 250 | refs_pred_matched_idxs = [] 251 | refs_preds_global = [] 252 | 253 | vid_preds_local = [] 254 | vid_pred_matched_idxs = [] 255 | vid_preds_global = [] 256 | 257 | 258 | """process text""" 259 | def dedup_and_sort(l): 260 | return sorted(list(set(l)), key=lambda x: len(x.split(" ")), reverse=True) 261 | 262 | sentences = dedup_and_sort(refs + hyps) 263 | embs = [] 264 | iter_range = range(0, len(sentences), batch_size) 265 | if verbose: 266 | print("computing text embedding.") 267 | iter_range = tqdm(iter_range) 268 | text_local_stats_dict = dict() 269 | text_global_stats_dict = dict() 270 | for batch_start in iter_range: 271 | sen_batch = sentences[batch_start: batch_start + batch_size] 272 | embs, masks, text_idfs = encode_text(sen_batch, model, tokenizer, idf_dict, device=device) 273 | embs = embs.cpu() 274 | masks = masks.cpu() 275 | for i, sen in enumerate(sen_batch): 276 | sequence_len = masks[i].sum().item() 277 | 278 | # For special tokens, use [SOS] and [EOS] 279 | local_emb = embs[i, 0:sequence_len] 280 | global_emb = embs[i, sequence_len-1] 281 | idf = text_idfs[i, 0:sequence_len] 282 | 283 | # For special tokens, don't use any 284 | # local_emb = embs[i, 1:sequence_len+1] 285 | # global_emb = embs[i, sequence_len+1] 286 | # idf = text_idfs[i, 1:sequence_len+1] 287 | 288 | # For special tokens, only use [EOS] 289 | # local_emb = embs[i, 1:sequence_len+1] 290 | # global_emb = embs[i, sequence_len] 291 | # idf = text_idfs[i, 1:sequence_len+1] 292 | 293 | text_local_stats_dict[sen] = (local_emb, idf) 294 | text_global_stats_dict[sen] = global_emb 295 | 296 | 297 | """process video""" 298 | if vids: 299 | if vid_feat_cache: 300 | ori_vids = vids 301 | vid_local_stats_dict = vid_feat_cache 302 | vid_global_stats_dict = dict() 303 | for vid in vid_local_stats_dict: 304 | image_features = vid_local_stats_dict[vid] 305 | vid_feature = normalize_matrix(torch.mean(image_features, dim=0, keepdim=True)).squeeze() 306 | vid_global_stats_dict[vid] = vid_feature 307 | else: 308 | ori_vids = vids # video paths list 309 | unique_vids = list(set(vids)) 310 | if verbose: 311 | print("computing vid embedding.") 312 | vid_local_stats_dict = dict() 313 | vid_global_stats_dict = dict() 314 | for vid_i in tqdm(range(len(unique_vids))): 315 | video_file = unique_vids[vid_i] 316 | image_features, vid_feature = encode_video(video_file, preprocess, model, batch_size=512, device=device) 317 | # vid_name = video_file.split('/')[-1][:-4] 318 | vid_local_stats_dict[video_file] = image_features.cpu() 319 | vid_global_stats_dict[video_file] = vid_feature.cpu() 320 | 321 | 322 | def pad_local_batch_stats(sen_batch, stats_dict, device): 323 | stats = [stats_dict[s] for s in sen_batch] 324 | emb, idf = zip(*stats) 325 | emb = [e.to(device) for e in emb] 326 | lens = [e.size(0) for e in emb] 327 | emb_pad = pad_sequence(emb, batch_first=True, padding_value=0.0) 328 | idf_pad = pad_sequence(idf, batch_first=True) 329 | 330 | def length_to_mask(lens): 331 | lens = torch.tensor(lens, dtype=torch.long) 332 | max_len = max(lens) 333 | base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len) 334 | return base < lens.unsqueeze(1) 335 | 336 | pad_mask = length_to_mask(lens).to(device) 337 | return emb_pad, pad_mask, idf_pad 338 | 339 | def pad_vid_local_batch_stats(sen_batch, stats_dict, device): 340 | stats = [stats_dict[s] for s in sen_batch] 341 | emb = stats 342 | emb = [e.to(device) for e in emb] 343 | lens = [e.size(0) for e in emb] 344 
| emb_pad = pad_sequence(emb, batch_first=True, padding_value=0.0) 345 | 346 | def length_to_mask(lens): 347 | lens = torch.tensor(lens, dtype=torch.long) 348 | max_len = max(lens) 349 | base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len) 350 | return base < lens.unsqueeze(1) 351 | 352 | pad_mask = length_to_mask(lens).to(device) 353 | return emb_pad, pad_mask 354 | 355 | def pad_global_batch_stats(sen_batch, stats_dict, device): 356 | stats = [stats_dict[s] for s in sen_batch] 357 | emb = stats 358 | emb = [e.to(device) for e in emb] 359 | emb_pad = pad_sequence(emb, batch_first=True, padding_value=0.0) 360 | return emb_pad 361 | 362 | """ if references are avaliable """ 363 | if refs: 364 | iter_range = range(0, len(hyps), batch_size) 365 | if verbose: 366 | print("computing greedy matching, references as ground truth.") 367 | iter_range = tqdm(iter_range) 368 | 369 | with torch.no_grad(): 370 | for batch_start in iter_range: 371 | batch_hyps = hyps[batch_start: batch_start + batch_size] 372 | hyp_stats_local = pad_local_batch_stats(batch_hyps, text_local_stats_dict, device) 373 | hyp_stats_global = pad_global_batch_stats(batch_hyps, text_global_stats_dict, device) 374 | 375 | batch_refs = refs[batch_start: batch_start + batch_size] 376 | ref_stats_local = pad_local_batch_stats(batch_refs, text_local_stats_dict, device) 377 | ref_stats_global = pad_global_batch_stats(batch_refs, text_global_stats_dict, device) 378 | 379 | P, R, F1, matched_indices = refs_greedy_cos(*ref_stats_local, *hyp_stats_local, return_matched_idx) 380 | refs_preds_local.append(torch.stack((P, R, F1), dim=-1).cpu()) 381 | refs_pred_matched_idxs.append(matched_indices) 382 | 383 | refs_s_cogr = torch.bmm(hyp_stats_global.unsqueeze(1), ref_stats_global.unsqueeze(1).transpose(1,2)).squeeze() 384 | refs_preds_global.append(refs_s_cogr) 385 | 386 | 387 | """ if video used as ground truth """ 388 | if vids: 389 | if verbose: 390 | print("computing greedy matching, video as ground truth.") 391 | iter_range = range(0, len(ori_cands), batch_size) 392 | with torch.no_grad(): 393 | for batch_start in iter_range: 394 | batch_ori_hyp = ori_cands[batch_start: batch_start + batch_size] 395 | ori_hyp_stats_local = pad_local_batch_stats(batch_ori_hyp, text_local_stats_dict, device) 396 | ori_hyp_stats_global = pad_global_batch_stats(batch_ori_hyp, text_global_stats_dict, device) 397 | 398 | batch_ori_vids = ori_vids[batch_start: batch_start + batch_size] 399 | ori_vids_stats_local = pad_vid_local_batch_stats(batch_ori_vids, vid_local_stats_dict, device) 400 | ori_vids_stats_global = pad_global_batch_stats(batch_ori_vids, vid_global_stats_dict, device) 401 | 402 | P, R, F1, matched_indices = vid_greedy_cos(*ori_vids_stats_local, *ori_hyp_stats_local, return_matched_idx) 403 | vid_preds_local.append(torch.stack((P, R, F1), dim=-1).cpu()) 404 | vid_pred_matched_idxs.append(matched_indices) 405 | 406 | vid_s_cogr = torch.bmm(ori_hyp_stats_global.unsqueeze(1), ori_vids_stats_global.unsqueeze(1).transpose(1, 2)).squeeze() 407 | vid_preds_global.append(vid_s_cogr) 408 | 409 | 410 | results = dict() 411 | """ if references are avaliable """ 412 | if refs: 413 | refs_preds_local = torch.cat(refs_preds_local, dim=0).cpu() 414 | if len(refs) != 1: 415 | refs_preds_global = torch.cat(refs_preds_global, dim=0).cpu() 416 | else: 417 | refs_preds_global = refs_preds_global[0].cpu() 418 | results['refs_result'] = {} 419 | results['refs_result']['figr'] = refs_preds_local 420 | results['refs_result']['cogr'] = refs_preds_global 
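        # matched_indices: for each candidate token, the index of its best-matching
        # reference token (all zeros when return_matched_idx is False)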
421 | results['refs_result']['matched_indices'] = torch.cat(refs_pred_matched_idxs) 422 | 423 | """ if video used as ground truth """ 424 | if vids: 425 | vid_preds_local = torch.cat(vid_preds_local, dim=0).cpu() 426 | if len(vids) != 1: 427 | vid_preds_global = torch.cat(vid_preds_global, dim=0).cpu() 428 | else: 429 | vid_preds_global = vid_preds_global[0].cpu() 430 | results['vid_result'] = {} 431 | results['vid_result']['figr'] = vid_preds_local 432 | results['vid_result']['cogr'] = vid_preds_global 433 | results['vid_result']['matched_indices'] = torch.cat(vid_pred_matched_idxs) 434 | 435 | 436 | return results --------------------------------------------------------------------------------