├── .gitignore ├── .gitmodules ├── asset ├── char_recog.jpg ├── charbank_top10_pr_curve_PerActor.jpg ├── v1_figure.png ├── v2_figure.jpg └── v3_figure.jpg ├── autoad_i ├── model_ad.py └── model_tfm.py ├── autoad_ii ├── character_recognition │ ├── MAD_charbank_2023mar.json │ ├── build_exemplar_mad.py │ ├── build_exemplar_movienet.py │ ├── config.py │ ├── extract_clip_face.py │ ├── name_loader.py │ ├── readme.md │ └── recog_main.py ├── readme.md └── recall_within_neighbours.py ├── autoad_iii ├── metrics │ ├── cast_list_for_eval.json │ ├── cmd_fn_to_imdb.json │ ├── critic_metric.py │ ├── llm_ad_eval_gpt.py │ └── llm_ad_eval_llama.py └── readme.md ├── datasets ├── mad2imdb.json └── mad_split.json ├── models └── gpt_utils.py └── readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.tar 2 | *.jpg 3 | *.npy 4 | *.zip 5 | *.pyc 6 | *.csv 7 | *.json 8 | .idea/ 9 | logs/ 10 | .DS_Store 11 | */.DS_Store 12 | .vscode/ 13 | log* 14 | */logs* 15 | *.cluster 16 | wandb/ 17 | */wandb/ 18 | *.mp4 19 | *.pkl 20 | *.out 21 | tmp/ 22 | *.pth 23 | *.pkl 24 | data/ 25 | */sps-*/ 26 | **/sps-*/ 27 | */sps-*.tar.gz 28 | **/sps-*.tar.gz 29 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "coco-caption"] 2 | path = coco_caption 3 | url = https://github.com/LuoweiZhou/coco-caption.git 4 | -------------------------------------------------------------------------------- /asset/char_recog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TengdaHan/AutoAD/9cf094ab32c549510cc9b8c09395d1767dcf5837/asset/char_recog.jpg -------------------------------------------------------------------------------- /asset/charbank_top10_pr_curve_PerActor.jpg: -------------------------------------------------------------------------------- 
class VideoCaptionModel(nn.Module):
    """Visual-prefix generator for GPT-2 based audio-description captioning.

    A ``PerceiverEncoder`` summarises per-frame visual features into
    ``num_latents`` vectors, which a linear layer projects into the GPT-2
    token-embedding space so they can be used as a soft prompt.

    Args:
        num_latents: number of latent (prefix) vectors produced per sample.
        num_layers: number of transformer layers inside the perceiver.
        prefix_size: channel size of the visual features / perceiver latents.
        use_context_perceiver: 0 or 1; 1 creates the learnable <BOS>/<EOS>
            embeddings used by ``wrap_context``.
        use_subtitle_perceiver: 0, 3 or 4; 3 and 4 create the learnable
            <BOS>/<EOS> embeddings used by ``wrap_subtitle``.
        **kwargs: ignored (a warning is printed) so that shared config
            dictionaries can be passed unchanged.
    """

    def __init__(self,
                 num_latents: int = 10,
                 num_layers: int = 2,
                 prefix_size: int = 512,
                 use_context_perceiver: int = 0,
                 use_subtitle_perceiver: int = 0,
                 **kwargs,
                 ):
        super().__init__()
        if kwargs:
            print(f'WARNING [VideoCaptionModel] kwargs not used: {kwargs}')
        self.num_layers = num_layers
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        ### visual ###
        self.perceiver = PerceiverEncoder(
            num_latents=num_latents,
            d_latents=prefix_size,
            num_layers=num_layers,
            nhead=prefix_size//64)
        self.project = nn.Linear(prefix_size, self.gpt_embedding_size)
        nn.init.normal_(self.project.weight, std=prefix_size ** -0.5)
        nn.init.zeros_(self.project.bias)

        ### context ###
        self.use_context_perceiver = use_context_perceiver
        assert use_context_perceiver in [0, 1]
        if use_context_perceiver == 1:
            # produce <BOS>, <EOS> around the context features
            self.context_special_token = nn.Embedding(2, embedding_dim=self.gpt_embedding_size)

        ### subtitle ###
        self.use_subtitle_perceiver = use_subtitle_perceiver
        assert use_subtitle_perceiver in [0, 3, 4]
        if use_subtitle_perceiver in [3, 4]:
            # produce <BOS>, <EOS> around the subtitle features
            self.subtitle_special_token = nn.Embedding(2, embedding_dim=self.gpt_embedding_size)

        ### BOS token for AD generation
        self.bos_token = nn.Embedding(1, embedding_dim=self.gpt_embedding_size)

    @staticmethod
    def _wrap_with_special_tokens(special_token, embed):
        """Return ``[BOS; embed; EOS]`` along dim 1.

        ``special_token`` is a 2-row ``nn.Embedding``: row 0 is used as BOS and
        row 1 as EOS. ``embed`` is indexed as (batch, length, channel).
        """
        batch_size = embed.shape[0]
        # expand() avoids materialising the copies; torch.cat copies anyway.
        bos = special_token.weight[None, 0:1].expand(batch_size, -1, -1)
        eos = special_token.weight[None, 1:2].expand(batch_size, -1, -1)
        return torch.cat((bos, embed, eos), dim=1)

    def wrap_context(self, context_embed, prompt=None):
        """assume context_embed: B,N,C. Add <BOS> <EOS> on it.

        ``prompt`` is not supported and must be None. Requires the model to
        have been built with ``use_context_perceiver=1``.
        """
        assert prompt is None
        return self._wrap_with_special_tokens(self.context_special_token, context_embed)

    def wrap_subtitle(self, subtitle_embed):
        """assume subtitle_embed: B,N,C. Add <BOS> <EOS> on it.

        Requires the model to have been built with
        ``use_subtitle_perceiver`` in (3, 4).
        """
        return self._wrap_with_special_tokens(self.subtitle_special_token, subtitle_embed)

    def forward(self, visual_feature, mask=None, labels=None):
        """purely for visual prompt.

        ``mask`` and ``labels`` are accepted for interface compatibility but
        are not used here.
        """
        # visual_feature: b t c
        # prefix_vector: b k c
        latent_vector = self.perceiver(visual_feature)
        prefix_vector = self.project(latent_vector)
        return prefix_vector


if __name__ == '__main__':
    # UNIT TEST
    from gpt_utils import generate_beam, generate_greedy
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = VideoCaptionModel()
    prefix_vector = model(torch.randn(1, 1, 512))
    print(generate_greedy(model, tokenizer, embed=prefix_vector))
    print(generate_beam(model, tokenizer, embed=prefix_vector))
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock_Step(nn.Module):
    """Pre-norm self-attention + MLP residual block.

    Besides the block output it also returns the layer-normed block input,
    which ``TemporalEncoder`` collects as an intermediate feature.
    """

    def __init__(self, d_model: int, n_head: int,):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)

    def with_pos_embed(self, tensor, pos: Optional[torch.Tensor]):
        """Add ``pos`` to ``tensor`` when a positional embedding is given."""
        return tensor if pos is None else tensor + pos

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None, pos: torch.Tensor = None):
        """Self-attention where ``pos`` is added to queries/keys but not to values."""
        key_padding_mask = key_padding_mask.to(device=x.device) if key_padding_mask is not None else None
        q = k = self.with_pos_embed(x, pos)
        return self.attn(q, k, x, need_weights=False, key_padding_mask=key_padding_mask)[0]

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None, pos: torch.Tensor = None):
        """Return ``(block_output, layer_normed_input)``."""
        x_norm = self.ln_1(x)
        x = x + self.attention(x_norm, key_padding_mask=key_padding_mask, pos=pos)
        x = x + self.mlp(self.ln_2(x))
        return x, x_norm


class TemporalEncoder(nn.Module):
    """Stack of ``ResidualAttentionBlock_Step`` layers over a (seq, batch, channel) input."""

    def __init__(self, width: int, layers: int, heads: int,):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList([ResidualAttentionBlock_Step(width, heads) for _ in range(layers)])

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None, pos: torch.Tensor = None):
        """Return a list of per-layer features; the last entry is the final output.

        Entry ``i`` (for ``i < layers - 1``) is the layer-normed input of block
        ``i + 1``; the normed input of the first block is dropped. With zero
        layers the list is just ``[x]``.
        """
        intermediate = []
        for block in self.resblocks:
            x, x_norm = block(x, key_padding_mask, pos)
            intermediate.append(x_norm)
        if intermediate:
            # drop the normed raw input; guarded so layers == 0 does not raise
            intermediate.pop(0)
        intermediate.append(x)
        return intermediate


class PerceiverEncoder(nn.Module):
    """Perceiver-like module, with TransformerEncoder([latent; features])

    Args:
        num_latents: number of learnable latent vectors (= output tokens).
        d_latents: channel size of latents and of the input features.
        nhead: attention heads per layer.
        num_layers: number of transformer layers.
        max_temporal_len: size of the learnable temporal position table, i.e.
            the longest supported input sequence (default 512, as before).
    """

    def __init__(self, num_latents=16, d_latents=768, nhead=8, num_layers=2, max_temporal_len=512):
        super().__init__()
        self.num_latents = num_latents
        self.latent = nn.Parameter(torch.empty(num_latents, d_latents))
        self.temporal_pos_embed = nn.Parameter(torch.empty(max_temporal_len, d_latents))
        self.encoder = TemporalEncoder(width=d_latents, layers=num_layers, heads=nhead)
        self.visual_prenorm = LayerNorm(d_latents)
        self.initialize_parameters()

    def initialize_parameters(self):
        """Normal-initialise latents, position table and the attention/MLP weights."""
        nn.init.normal_(self.latent, mean=0, std=1)
        nn.init.normal_(self.temporal_pos_embed, mean=0, std=1.0)
        if len(self.encoder.resblocks) == 0:
            # nothing to initialise; also avoids (2 * 0) ** -0.5 -> ZeroDivisionError
            return
        proj_std = (self.encoder.width ** -0.5) * ((2 * self.encoder.layers) ** -0.5)
        attn_std = self.encoder.width ** -0.5
        fc_std = (2 * self.encoder.width) ** -0.5
        for block in self.encoder.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

    def forward(self, visual_feature, key_padding_mask=None):
        """Encode ``visual_feature`` (b, t, c) into latents (b, num_latents, c).

        Args:
            visual_feature: tensor indexed as (batch, time, channel).
            key_padding_mask: optional mask where True/non-zero marks padded
                positions. It may cover only the ``t`` frames (it is then
                extended with "not padded" entries for the latents) or the
                full ``num_latents + t`` sequence.

        Raises:
            ValueError: if the input is not 3-D or is longer than the
                temporal position table.
        """
        if visual_feature.dim() != 3:
            raise ValueError(
                f'visual_feature must be (batch, time, channel), got shape {tuple(visual_feature.shape)}')
        B, T, _ = visual_feature.shape
        max_len = self.temporal_pos_embed.shape[0]
        if T > max_len:
            raise ValueError(f'sequence length {T} exceeds the temporal position table size {max_len}')

        visual_feature = visual_feature.transpose(0, 1)  # b t c -> t b c
        temp_pos = self.temporal_pos_embed[0:T, None, :]
        visual_feature = self.visual_prenorm(visual_feature) + temp_pos
        latent = self.latent[:,None,:].repeat(1,B,1)  # k,b,c
        concat = torch.cat((latent, visual_feature), dim=0)

        if key_padding_mask is not None and key_padding_mask.shape[-1] == T:
            # the attention sequence is [latent; frames]: latents are never padding
            latent_mask = key_padding_mask.new_zeros((B, self.num_latents))
            key_padding_mask = torch.cat((latent_mask, key_padding_mask), dim=1)

        enc_out = self.encoder(concat, key_padding_mask, pos=None)[-1]  # last layer output

        latent_out = enc_out[0:self.num_latents, :]
        return latent_out.transpose(0, 1)  # k b c -> b k c
-------------------------------------------------------------------------------- /autoad_ii/character_recognition/build_exemplar_mad.py: -------------------------------------------------------------------------------- 1 | """Find in-movie exemplars for character's profile pictures in MAD, 2 | by taking top-K nearest neighbour""" 3 | 4 | import pandas as pd 5 | import json 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | import os 10 | import matplotlib.pyplot as plt 11 | import glob 12 | 13 | 14 | def draw_sim_curve(curves, char_list): 15 | assert len(curves) == len(char_list) 16 | if isinstance(curves[0], torch.Tensor): 17 | curves = torch.stack(curves, 0) 18 | else: 19 | curves = np.stack(curves, 0) 20 | fig, ax = plt.subplots(figsize=(18,4)) 21 | for i in range(len(char_list)): 22 | ax.plot(curves[i], label=char_list[i]) 23 | ax.legend() 24 | return fig, ax 25 | 26 | 27 | MAD_dir = "/scratch/shared/beegfs/maxbain/datasets/MAD_Language_Grounding_Movie_Audio_Descriptions" 28 | audiovault_dir = "/scratch/shared/beegfs/maxbain/datasets/audiovault" 29 | char_bank_exp_fn = "MAD_charbank_2023mar.json" 30 | TOP_K_CHARS = 10 31 | 32 | # load charbank json 33 | with open(os.path.join(char_bank_exp_fn)) as fobj: 34 | charbank_dict = json.load(fobj) 35 | 36 | # available at: wget http://www.robots.ox.ac.uk/~htd/autoad/MAD_id2imdb.json 37 | with open(os.path.join(MAD_dir, "MAD_id2imdb.json"), "r") as fobj: 38 | imdb_map = json.load(fobj) 39 | 40 | # available at: wget http://www.robots.ox.ac.uk/~htd/autoad/audiovault_actors.csv 41 | char_map = pd.read_csv(os.path.join(audiovault_dir, "audiovault_actors.csv")) 42 | 43 | # using CLIP-L-14 44 | # available at: wget http://www.robots.ox.ac.uk/~htd/autoad/audiovault_face_ViT-L-14.pth.tar 45 | face_data = torch.load('audiovault_face_ViT-L-14.pth.tar') 46 | profile_list = face_data['filenames'] 47 | profile_list = [i.split('.')[0] for i in profile_list] 48 | profile_features = face_data['clip_embedding'] 49 | assert 
len(profile_list) == profile_features.shape[0] 50 | 51 | char_dict = {} 52 | TOP_K = 5 53 | failed_avid = 0 54 | failed_char = 0 55 | 56 | print(f"computing charbank for top{TOP_K_CHARS} characters for top{TOP_K} features") 57 | 58 | save_root = '/scratch/shared/beegfs/htd/MAD/charbank_cos_top10_cal_2023jul' 59 | os.makedirs(save_root, exist_ok=True) 60 | feature_root = '/scratch/shared/beegfs/htd/DATA/MAD/CLIP_L14_frames_features_5fps' 61 | avid_list = sorted(glob.glob(os.path.join(feature_root, '*.npy'))) 62 | avid_list = [os.path.basename(i).replace('.npy', '') for i in avid_list] 63 | print(f'Get {len(avid_list)} movie features from {feature_root}') 64 | 65 | 66 | for avid in tqdm.tqdm(avid_list): 67 | imdbid = imdb_map[avid] 68 | cdf = char_map[char_map['imdbid']==imdbid] 69 | cdf = cdf.reset_index() 70 | cdf['cast_index'] = cdf.index 71 | 72 | if avid not in charbank_dict: 73 | print(f'{avid} is not downloaded in charbank_dict') 74 | failed_avid += 1 75 | continue 76 | 77 | imdb_dump = charbank_dict[avid] 78 | cast_info = imdb_dump[:TOP_K_CHARS] 79 | 80 | movie_feature = torch.from_numpy(np.load(os.path.join(feature_root, f'{avid}.npy'))).float().cuda() 81 | movie_feature_normed = movie_feature / movie_feature.norm(dim=-1, keepdim=True) 82 | 83 | curves = [] 84 | curves_self = [] 85 | char_dicts = [] 86 | all_cos_curves = [] 87 | 88 | for char_info in cast_info: 89 | char_id_int = int(char_info['id'].replace('nm', '')) 90 | crows = cdf[cdf['id'] == char_id_int] 91 | assert len(crows) == 1 92 | crow = crows.iloc[0] 93 | assert int(char_info['id'].replace('nm','')) == crow['id'] 94 | curr_role = char_info['role'] 95 | 96 | # ignore multi-role actor cases for now 97 | if isinstance(curr_role, list): 98 | curr_role = curr_role[0] 99 | crow_id_str = f"{crow['id']:07d}" 100 | 101 | if char_info['id'] in profile_list: 102 | profile_idx = profile_list.index(char_info['id']) 103 | profile_ftr = profile_features[profile_idx].float()[None,:].cuda() 104 | 105 | # 
cos_curve = torch.nn.functional.cosine_similarity(curr_MAD_ftrs, ftr_profile) 106 | ftr_profile_normed = profile_ftr / profile_ftr.norm(dim=-1, keepdim=True) 107 | cos_curve = movie_feature_normed @ ftr_profile_normed[0,] 108 | 109 | # get top5 features in the movie then average 110 | # adaptive top5 for diversity: with a gap of 10second (5FPS) 111 | HALF_WINDOW = 10 * 5 112 | cos_curve_copy = cos_curve.clone() 113 | topkidx_list = [] 114 | topkval_list = [] 115 | for _ in range(5): 116 | max_val, max_idx = torch.max(cos_curve_copy, dim=-1) 117 | topkidx_list.append(max_idx) 118 | topkval_list.append(max_val) 119 | cos_curve_copy[max(0, max_idx-HALF_WINDOW): min(cos_curve_copy.shape[0], max_idx+HALF_WINDOW)] = -1 120 | topkidx = torch.stack(topkidx_list, 0) 121 | 122 | # average top5 features as the exemplar 123 | avg_profile = movie_feature_normed[topkidx].mean(0, keepdim=True) 124 | cos_curve_self = movie_feature_normed @ avg_profile[0,] 125 | else: 126 | print(f'{avid} {curr_role} cannot be found in profile_list') 127 | failed_char += 1 128 | cos_curve = torch.ones(movie_feature_normed.shape[0], device='cuda') * -float('inf') 129 | cos_curve_self = torch.ones(movie_feature_normed.shape[0], device='cuda') * -1 130 | topkidx = torch.ones(5, device='cuda') * -1 131 | 132 | char_info_array = { 133 | "id": int(char_info['id'].replace('nm','')), 134 | "name": str(char_info['name']), 135 | "long imdb name": str(char_info['name']), 136 | "role": str(char_info['role']), 137 | "cos_curve": cos_curve.cpu(), 138 | "cos_curve_self_top5": cos_curve_self.cpu(), 139 | 'top5_idx': topkidx.cpu(), 140 | } 141 | curves.append(char_info_array['cos_curve']) 142 | curves_self.append(char_info_array['cos_curve_self_top5']) 143 | char_dicts.append(char_info_array) 144 | 145 | stack_cos = torch.stack([i['cos_curve_self_top5'] for i in char_dicts], 0) 146 | per_movie_info = { 147 | 'roles': [i['role'] for i in char_dicts], 148 | 'names': [i['name'] for i in char_dicts], 149 | 'cos': 
stack_cos, 150 | 'top5_idx': [i['top5_idx'] for i in char_dicts], 151 | } 152 | char_dict[avid] = per_movie_info 153 | 154 | save_name = f'{avid}.charbank.pth.tar' 155 | torch.save(per_movie_info, os.path.join(save_root, save_name)) 156 | 157 | print('finished') 158 | print(f'failed avid: {failed_avid}') 159 | print(f'failed char: {failed_char}') 160 | print(f'saved to {save_root}') 161 | 162 | -------------------------------------------------------------------------------- /autoad_ii/character_recognition/build_exemplar_movienet.py: -------------------------------------------------------------------------------- 1 | """Find in-movie exemplars for character's profile pictures in MovieNet. 2 | The source is from AudioVault charbank download. 3 | This script works on the intersection: 4 | (set(MovieNet) + set(AudioVault)) - set(MAD) 5 | and process them. 6 | Output: the {imdbid}.charbank.pth.tar files in /scratch/shared/beegfs/htd/DATA/MovieNet/charbank_cos_top10 7 | """ 8 | 9 | import pandas as pd 10 | import json 11 | import numpy as np 12 | import torch 13 | import tqdm 14 | import os 15 | import sys 16 | import matplotlib.pyplot as plt 17 | from glob import glob 18 | import re 19 | 20 | 21 | def draw_sim_curve(curves, char_list): 22 | assert len(curves) == len(char_list) 23 | if isinstance(curves[0], torch.Tensor): 24 | curves = torch.stack(curves, 0) 25 | else: 26 | curves = np.stack(curves, 0) 27 | fig, ax = plt.subplots(figsize=(18,4)) 28 | for i in range(len(char_list)): 29 | ax.plot(curves[i], label=char_list[i]) 30 | ax.legend() 31 | return fig, ax 32 | 33 | 34 | def get_imdb_to_process(): 35 | """a set operation: (set(MovieNet) + set(AudioVault)) - set(MAD)""" 36 | # audiovault 37 | av_id = pd.read_csv('/scratch/shared/beegfs/maxbain/datasets/audiovault/audiovault_imdbids.csv') 38 | av_id_matched = av_id[av_id['exact_match']]['imdbid'].tolist() 39 | # movienet 40 | movienet_anno = pd.read_csv('/work/htd/Desktop_tmp/AutoMad/movienet/movienet_face_anno.csv') 
41 | mvn_id = movienet_anno['movie_id'].unique().tolist() 42 | # mad 43 | mad_imdb = pd.read_csv('/scratch/shared/beegfs/htd/DATA/MAD/mad_imdb_info.csv') 44 | mad_id = mad_imdb['imdb'].unique().tolist() 45 | 46 | imdb_list = list((set(mvn_id).intersection(set(av_id_matched))) - set(mad_id)) 47 | print(f"{len(imdb_list)=}") 48 | return sorted(imdb_list), "v1" 49 | 50 | 51 | def get_imdb_to_process_outside_av(): 52 | """a set operation: not in audiovault, but in the downloaded address""" 53 | # audiovault 54 | av_id = pd.read_csv('/scratch/shared/beegfs/maxbain/datasets/audiovault/audiovault_imdbids.csv') 55 | av_id_matched = av_id[av_id['exact_match']]['imdbid'].tolist() 56 | # movienet 57 | movienet_anno = pd.read_csv('/work/htd/Desktop_tmp/AutoMad/movienet/movienet_face_anno.csv') 58 | mvn_id = movienet_anno['movie_id'].unique().tolist() 59 | # mad 60 | mad_imdb = pd.read_csv('/scratch/shared/beegfs/htd/DATA/MAD/mad_imdb_info.csv') 61 | mad_id = mad_imdb['imdb'].unique().tolist() 62 | # recent download 63 | download_list = glob("/scratch/shared/beegfs/maxbain/datasets/audiovault/cast_mainpage_selenium/*.json") 64 | download_id = [os.path.basename(i).split('.')[0] for i in download_list] 65 | 66 | imdb_list = list((set(mvn_id) - set(av_id_matched) - set(mad_id)).intersection(set(download_id))) 67 | print(f"{len(imdb_list)=}") 68 | return sorted(imdb_list), "v2" 69 | 70 | 71 | MAD_dir = "/scratch/shared/beegfs/maxbain/datasets/MAD_Language_Grounding_Movie_Audio_Descriptions" 72 | audiovault_dir = "/scratch/shared/beegfs/maxbain/datasets/audiovault" 73 | movienet_charbank_json_dir = "/scratch/shared/beegfs/maxbain/datasets/audiovault/cast_mainpage_selenium" 74 | char_bank_exp_fn = "../data_post_proc/MAD_charbank_2023mar.json" 75 | TOP_K_CHARS = 10 76 | 77 | 78 | # Load raw info from audiovault 79 | char_map = pd.read_csv(os.path.join(audiovault_dir, "audiovault_actors.csv")) 80 | # char_map = char_map.groupby("imdbid").head(TOP_K_CHARS) 81 | 82 | face_data = 
torch.load('audiovault_face_ViT-L-14.pth.tar') 83 | profile_list = face_data['filenames'] 84 | profile_list = [i.split('.')[0] for i in profile_list] 85 | profile_features = face_data['clip_embedding'] 86 | assert len(profile_list) == profile_features.shape[0] 87 | 88 | 89 | # two batches of download 90 | # get imdb_id 91 | for get_imdb_fn in [get_imdb_to_process, get_imdb_to_process_outside_av]: 92 | # iterate: 93 | # movienet_imdb_list, version = get_imdb_to_process() 94 | # movienet_imdb_list, version = get_imdb_to_process_outside_av() 95 | 96 | movienet_imdb_list, version = get_imdb_fn() 97 | # check feature exists 98 | feature_root = "/scratch/shared/beegfs/htd/DATA/MovieNet/keyframe_feat/openai-clip-vit-l-14" 99 | print(f"{len(movienet_imdb_list)=}") 100 | movienet_imdb_list = [i for i in movienet_imdb_list if os.path.exists(os.path.join(feature_root, f"{i}.npy"))] 101 | print(f"{len(movienet_imdb_list)=}") 102 | 103 | failed_char = 0 104 | char_dict = {} 105 | save_root = "/scratch/shared/beegfs/htd/MovieNet/charbank_cos_top10" 106 | os.makedirs(save_root, exist_ok=True) 107 | 108 | # what do we get: char info array 109 | for imdbid in tqdm.tqdm(movienet_imdb_list): 110 | if version == 'v1': 111 | pkl_path = os.path.join(audiovault_dir, f'imdb/{imdbid}.pkl') 112 | if not os.path.exists(pkl_path): 113 | print(f"{pkl_path} does not exist") 114 | continue 115 | imdb_dump = np.load(pkl_path, allow_pickle=True) 116 | assert imdbid.replace('tt','') == imdb_dump['imdbID'] 117 | cast_info = imdb_dump['cast'][:TOP_K_CHARS] 118 | elif version == 'v2': 119 | json_path = os.path.join(movienet_charbank_json_dir, f'{imdbid}.json') 120 | cast_info = json.load(open(json_path)) 121 | cast_info = cast_info[:TOP_K_CHARS] 122 | 123 | # get CLIP feature: default L14 124 | v_feature = torch.from_numpy(np.load(f"{feature_root}/{imdbid}.npy")).float().cuda() 125 | v_feature_normed = v_feature / v_feature.norm(dim=-1, keepdim=True) 126 | char_dicts = [] 127 | 128 | for cdx, cinfo in 
enumerate(cast_info): 129 | if isinstance(cinfo, dict): 130 | curr_role = cinfo['role'] 131 | role = curr_role 132 | role = re.sub(r'\(.*?\)', '', role) 133 | person_id = cinfo['id'].replace('nm', '') 134 | name = cinfo['name'] 135 | long_name = name 136 | else: 137 | curr_role = cinfo._get_currentRole() 138 | # ignore multi-role actor cases for now 139 | if isinstance(curr_role, list): 140 | curr_role = curr_role[0] 141 | role = curr_role['name'] 142 | role = re.sub(r'\(.*?\)', '', role) 143 | person_id = cinfo.personID 144 | name = cinfo['name'] 145 | long_name = cinfo['long imdb name'] 146 | 147 | char_info_array = { 148 | "id": int(person_id), 149 | "name": str(name), 150 | "long imdb name": str(long_name), 151 | "role": str(role), 152 | } 153 | 154 | if 'nm'+str(person_id) in profile_list: 155 | ftr_idx = profile_list.index('nm'+str(person_id)) 156 | ftr_profile = profile_features[ftr_idx].float()[None,:].cuda() 157 | 158 | ftr_profile_normed = ftr_profile / ftr_profile.norm(dim=-1, keepdim=True) 159 | cos_curve = v_feature_normed @ ftr_profile_normed[0,] 160 | # get top5 faces in the movie then average 161 | # adaptive top5 for diversity: with a gap of 10 shots (3 KeyFrames per shot) 162 | HALF_WINDOW = 10 * 3 163 | cos_curve_copy = cos_curve.clone() 164 | topkidx_list = [] 165 | topkval_list = [] 166 | for _ in range(5): 167 | max_val, max_idx = torch.max(cos_curve_copy, dim=-1) 168 | topkidx_list.append(max_idx) 169 | topkval_list.append(max_val) 170 | cos_curve_copy[max(0, max_idx-HALF_WINDOW): min(cos_curve_copy.shape[0], max_idx+HALF_WINDOW)] = -1 171 | topkidx = torch.stack(topkidx_list, 0) 172 | # _, topkidx = torch.topk(cos_curve, k=5, dim=-1) 173 | 174 | avg_profile = v_feature_normed[topkidx].mean(0, keepdim=True) 175 | cos_curve_self = v_feature_normed @ avg_profile[0,] 176 | else: 177 | print(f'{imdbid} {char_info_array["role"]} cannot be found in profile_list') 178 | failed_char += 1 179 | cos_curve = torch.ones(v_feature_normed.shape[0], 
device='cuda') * -float('inf') 180 | cos_curve_self = torch.ones(v_feature_normed.shape[0], device='cuda') * -1 181 | topkidx = torch.ones(5, device='cuda') * -1 182 | 183 | char_info_array.update({ 184 | "cos_curve": cos_curve.cpu(), 185 | "cos_curve_self_top5": cos_curve_self.cpu(), 186 | 'top5_idx': topkidx.cpu()} 187 | ) 188 | char_dicts.append(char_info_array) 189 | 190 | try: 191 | stack_cos = torch.stack([i['cos_curve_self_top5'] for i in char_dicts], 0) 192 | except: 193 | print("warning: stack COS failed") 194 | stack_cos = torch.zeros(TOP_K_CHARS, v_feature_normed.shape[0]) 195 | 196 | per_movie_info = {'roles': [i['role'] for i in char_dicts], 197 | 'names': [i['name'] for i in char_dicts], 198 | 'cos': stack_cos, 199 | 'top5_idx': [i['top5_idx'] for i in char_dicts], 200 | } 201 | 202 | per_movie_info_simple = [{'id': 'nm'+f"{item['id']:07d}", 'name': item['name'], 'role': item['role']} for item in char_dicts] 203 | char_dict[imdbid] = per_movie_info_simple 204 | save_name = f'{imdbid}.charbank.pth.tar' 205 | torch.save(per_movie_info, os.path.join(save_root, save_name)) 206 | 207 | print(f"{failed_char=}") 208 | print(f"total char should be {len(movienet_imdb_list) * 10}") 209 | print(f"failed ratio = {failed_char/(len(movienet_imdb_list)*10)}") 210 | 211 | if version == 'v1': 212 | with open("MovieNet_charbank_300.json", 'w') as fobj: 213 | json.dump(char_dict, fobj) 214 | elif version == 'v2': 215 | with open("MovieNet_charbank_patch_148.json", 'w') as fobj: 216 | json.dump(char_dict, fobj) 217 | 218 | print(f'finished, saved to {save_root}') 219 | sys.exit(0) 220 | -------------------------------------------------------------------------------- /autoad_ii/character_recognition/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import json 5 | from datetime import datetime 6 | import torch 7 | import numpy as np 8 | import random 9 | from transformers import 
def parse_args(argv=None):
    """Parse the command-line options of the training / evaluation scripts.

    Args:
        argv: optional list of argument strings; ``None`` (default) reads
            ``sys.argv[1:]`` as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpt', default='gpt2', type=str)
    parser.add_argument('--model', default='dec', type=str)
    parser.add_argument('--seed', default=888,type=int)
    parser.add_argument('--language_model', default='gpt', type=str)
    parser.add_argument('--dataset', default='mad', type=str)
    parser.add_argument('--video_filter', default=1, type=int)
    parser.add_argument('--num_frames', default=8, type=int)
    parser.add_argument('--num_latents', default=10, type=int)
    parser.add_argument('--num_layers', default=2, type=int)
    parser.add_argument('--fps', default=5, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-7, type=float)
    parser.add_argument('--loss', default='nce', type=str)
    parser.add_argument('--schedule', default=[10000], nargs='*', type=int,
                        help='learning rate schedule (when to drop lr by 10x)')
    parser.add_argument('--wd', default=1e-5, type=float)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--test', default='', type=str)
    parser.add_argument('--pretrain', default='', type=str)
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--clip_grad', default=0, type=float)
    parser.add_argument('--prefix', default='tmp', type=str)
    parser.add_argument('--gpu', default=None, type=str)
    parser.add_argument('--unit_test', default=None, type=str)
    parser.add_argument('-j', '--num_workers', default=2, type=int)
    parser.add_argument('--train_what', default='all', type=str)
    parser.add_argument('--name_prefix', default='', type=str)
    parser.add_argument('--sim', default='cos', type=str)
    parser.add_argument('--sentence_mode', default='cls', type=str)
    parser.add_argument('--eval_freq', default=1, type=int)
    parser.add_argument('--runtime_save_iter', default=1000, type=int)
    parser.add_argument('--aux_loss', default=1, type=int)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('--clip_temp_mode', default='avg', type=str)
    parser.add_argument('--optim_policy', default='default', type=str)
    parser.add_argument('--optim', default='adamw', type=str)
    parser.add_argument('--amp', default=1, type=int)
    parser.add_argument('--num_clips', default=4, type=int)
    parser.add_argument('--num_hypo', default=1, type=int)
    parser.add_argument('--use_context', default=1, type=int)
    parser.add_argument('--use_border_adapter', default=0, type=int)
    parser.add_argument('--remove_vision', default=0, type=int)
    # type=int like the sibling 0/1 switches: without it a CLI value stays a
    # string and '0' would be truthy
    parser.add_argument('--use_charbank', default=0, type=int)
    parser.add_argument('--use_unl', default=0, type=int)
    parser.add_argument('--lookahead', default=0, type=int)
    parser.add_argument('--rephrase', default=0, type=int)
    parser.add_argument('--use_bos', default=1, type=int)

    parser.add_argument('--num_history', default=64, type=int)
    parser.add_argument('--version', default='raw', type=str)
    parser.add_argument('--test_version', default=None, type=str)
    parser.add_argument('--clip_version', default='B32', type=str)
    parser.add_argument('--dev_split', default=0, type=int)

    parser.add_argument('--backprop_freq', default=1, type=int)
    parser.add_argument('--test_num_clips', default=1, type=int)
    parser.add_argument('--perceiver_type', default=2, type=int)
    parser.add_argument('--context_perceiver_type', default=0, type=int)
    parser.add_argument('--subtitle_perceiver_type', default=0, type=int)
    parser.add_argument('--context_feature_type', default='gpt', type=str)

    parser.add_argument('--freezeBN', action='store_true')
    parser.add_argument('--downstream', action='store_true')
    parser.add_argument('--single_video', action='store_true')
    parser.add_argument('--single_align_video', action='store_true')
    parser.add_argument('--cross_video', action='store_true')
    parser.add_argument('--inference', action='store_true')
    parser.add_argument('--convert_from_frozen_bn', action='store_true')
    parser.add_argument('--keep_bn_eval', action='store_true')
    parser.add_argument('--init_s3d', action='store_true')
    parser.add_argument('--extract_feature', action='store_true')
    parser.add_argument('--feature_root', default='feature_coin/timesformer_8f_1fps', type=str)

    parser.add_argument('--test_mode', default='default-val', type=str)
    parser.add_argument('--save_video', action='store_true')

    args = parser.parse_args(argv)
    return args
no_vision_tag = f'NoVision_' if args.remove_vision else '' 122 | rephrase_tag = f'Rephrase_' if args.rephrase else '' 123 | exp_path = (f"log-{args.prefix}/{name_prefix}{no_vision_tag}{context_tag}{unit_test_tag}{dt_string}_" 124 | f"{args.model}-{args.gpt}-P{args.perceiver_type}C{args.context_perceiver_type}S{args.subtitle_perceiver_type}_BOS{args.use_bos}_layer{args.num_layers}_latent{args.num_latents}_" 125 | f"Loss-{args.loss}_CharBank{args.use_charbank}_{rephrase_tag}Ahead{args.lookahead}_{args.language_model}_" 126 | f"token-{args.sentence_mode}_sim-{args.sim}_hypo{args.num_hypo}_{args.dataset}{version_tag}_{CLIP_tag}_DEV{args.dev_split}_{clip_tag}frames{args.num_frames}_" 127 | f"policy-{args.optim_policy}_" 128 | f"bs{args.batch_size}_lr{args.lr}") 129 | 130 | pre_prefix = '' 131 | log_path = os.path.join(pre_prefix, exp_path, 'log') 132 | model_path = os.path.join(pre_prefix, exp_path, 'model') 133 | exp_path = os.path.join(pre_prefix, exp_path) 134 | if not os.path.exists(log_path): 135 | os.makedirs(log_path) 136 | if not os.path.exists(model_path): 137 | os.makedirs(model_path) 138 | 139 | with open(f'{log_path}/running_command.txt', 'a') as f: 140 | json.dump({'command_time_stamp':dt_string, **args.__dict__}, f, indent=2) 141 | f.write('\n') 142 | 143 | return log_path, model_path, exp_path 144 | 145 | 146 | def setup(args): 147 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 148 | if torch.cuda.is_available(): 149 | if args.gpu is None: 150 | args.gpu = str(os.environ["CUDA_VISIBLE_DEVICES"]) 151 | else: 152 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 153 | device = torch.device('cuda') 154 | 155 | num_gpu = len(str(args.gpu).split(',')) 156 | args.num_gpu = num_gpu 157 | args.batch_size = num_gpu * args.batch_size 158 | print('=> Effective BatchSize = %d' % args.batch_size) 159 | else: 160 | args.num_gpu = 0 161 | device = torch.device('cpu') 162 | print('=> Run with CPU') 163 | 164 | torch.manual_seed(args.seed) 165 | 
def setup(args):
    """Prepare device, seeds, output paths, tensorboard plotters and tokenizer.

    Mutates ``args`` in place: ``gpu``, ``num_gpu``, ``batch_size`` (scaled by
    the number of GPUs), ``log_path``/``model_path``/``exp_path``,
    ``train_plotter``/``val_plotter``, ``language_model``, ``tokenizer``,
    ``iteration`` and ``stop_token_int``.

    Returns:
        torch.device: ``cuda`` when available, otherwise ``cpu``.

    Raises:
        ValueError: if ``args.language_model`` is not supported and no
            tokenizer was already attached to ``args``.
    """
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    if torch.cuda.is_available():
        if args.gpu is None:
            # CUDA_VISIBLE_DEVICES may be unset on a workstation; fall back to
            # every device torch can see instead of raising KeyError.
            visible = os.environ.get("CUDA_VISIBLE_DEVICES")
            if visible is None:
                visible = ','.join(str(i) for i in range(torch.cuda.device_count()))
            args.gpu = str(visible)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
        device = torch.device('cuda')

        num_gpu = len(str(args.gpu).split(','))
        args.num_gpu = num_gpu
        args.batch_size = num_gpu * args.batch_size
        print('=> Effective BatchSize = %d' % args.batch_size)
    else:
        args.num_gpu = 0
        device = torch.device('cpu')
        print('=> Run with CPU')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.benchmark = True

    args.log_path, args.model_path, args.exp_path = set_path(args)

    writer_train = SummaryWriter(logdir=os.path.join(args.log_path, 'train'),
                                 flush_secs=60)
    args.train_plotter = TB.PlotterThread(writer_train)
    writer_val = SummaryWriter(logdir=os.path.join(args.log_path, 'val'),
                               flush_secs=60)
    args.val_plotter = TB.PlotterThread(writer_val)

    # The visual backbone dictates the language model for these two cases.
    if args.model == 'clip':
        args.language_model = 'clip'
    elif args.model == 'timesformer':
        args.language_model = 'mpnet'

    if args.language_model == 'bert':
        args.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    elif args.language_model in ['distilbert']:
        args.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    elif args.language_model == 'mpnet':
        args.tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
    elif args.language_model == 'clip':
        import clip

        class Tokenizer():
            """Thin adapter giving ``clip.tokenize`` a HuggingFace-like call signature."""

            def __call__(self, str_list, return_tensors='pt', **kwargs):
                token = clip.tokenize(str_list, truncate=True)
                if return_tensors != 'pt':
                    token = token.numpy()
                return {'input_ids': token}
        args.tokenizer = Tokenizer()
    elif args.language_model == 'gpt':
        args.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    elif not hasattr(args, 'tokenizer'):
        # Previously this fell through and crashed below with an opaque
        # AttributeError on ``args.tokenizer``.
        raise ValueError(f'unsupported language_model: {args.language_model!r}')

    args.iteration = 1
    # NOTE(review): the CLIP adapter above returns a plain dict, which has no
    # ``.input_ids`` attribute, so this line looks like it fails for
    # language_model == 'clip' -- confirm the intended stop token for CLIP.
    args.stop_token_int = args.tokenizer(".").input_ids[0]

    # '_' holds the launching executable in most shells but is not guaranteed
    # to exist (e.g. under some process managers), hence .get().
    if '/srun' in os.environ.get('_', ''):  # in sbatch
        print('running command: {')
        for key, item in args.__dict__.items():
            print(f'  "{key}": {item}')
        print('}')

    return device
def optim_policy(args, model, policy='default', version='gpt2'):
    """Build optimizer parameter groups, deciding which GPT weights to train.

    Args:
        args: namespace providing ``lr`` and ``wd``.
        model: object exposing ``named_parameters()``.
        policy: one of
            ``'default'`` - train everything outside ``gpt.`` plus ``.xattn``
            modules; all other GPT weights are left out of the optimizer.
            ``'pos'`` - like ``'default'`` but also train parameters whose
            name contains ``wpe.``.
            ``'half'`` - additionally train the second half of the GPT blocks
            (``wte``, ``wpe`` and the first half of ``h.<i>.`` stay frozen).
        version: GPT variant, only used by ``'half'`` (``'gpt2'`` has 12
            blocks, ``'gpt2-medium'`` has 24).

    Returns:
        list[dict]: two groups, first without weight decay (names matching
        layer-norm/batch-norm/bias/scale patterns), then with ``args.wd``.

    Raises:
        NotImplementedError: for an unknown policy (including ``'all'``) or,
            with ``'half'``, an unknown ``version``.
    """
    no_decay = ('.ln_', '.bn', '.bias', '.logit_scale', '.entropy_scale')

    if policy == 'default':
        def is_excluded(name):
            return ('gpt.' in name) and ('.xattn' not in name)
    elif policy == 'pos':
        def is_excluded(name):
            return ('gpt.' in name) and ('wpe.' not in name) and ('.xattn' not in name)
    elif policy == 'half':
        num_blocks_by_version = {'gpt2': 12, 'gpt2-medium': 24}
        if version not in num_blocks_by_version:
            # Previously this surfaced as an UnboundLocalError on NUM_BLOCKS.
            raise NotImplementedError(f"policy 'half' does not support version {version!r}")
        num_blocks = num_blocks_by_version[version]
        freeze_blocks = ['gpt.transformer.wte', 'gpt.transformer.wpe'] + \
            [f'gpt.transformer.h.{i}.' for i in range(num_blocks // 2)]

        def is_excluded(name):
            return (('gpt.' in name) and ('.xattn' not in name)
                    and any(block in name for block in freeze_blocks))
    else:  # 'all' (train all gpt weights) is not implemented either
        raise NotImplementedError

    param_group_no_decay = []
    param_group_with_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f'Param not requires_grad: {name}')
            continue
        if is_excluded(name):
            continue  # never touch these gpt weights
        print(f'Param to optimize: {name}')
        if any(pattern in name for pattern in no_decay):
            param_group_no_decay.append(param)
        else:
            param_group_with_decay.append(param)

    return [
        {'params': param_group_no_decay, 'lr': args.lr, 'weight_decay': 0.0},
        {'params': param_group_with_decay, 'lr': args.lr, 'weight_decay': args.wd},
    ]
class ImageFolder(Dataset):
    """Flat folder of images, yielding preprocessed tensors plus a status flag.

    Every entry directly under ``data_root`` is treated as one image.
    Unreadable or corrupt files do not abort iteration: they are returned as
    an all-zero placeholder tagged ``"ERROR"`` so the caller can mask them out.
    """

    def __init__(self, data_root: str, preprocess):
        """
        Args:
            data_root: directory whose direct children are the image files.
            preprocess: callable applied to the opened ``PIL.Image``.
        """
        print(f'building dataset from {data_root} ...')
        self.data_root = data_root
        self.all_paths = sorted(glob(os.path.join(self.data_root, '*')))
        self.preprocess = preprocess
        # Placeholder returned for unreadable images so batches stay collatable.
        self.dummy = torch.zeros(3, 224, 224)

    def __len__(self):
        return len(self.all_paths)

    def __getitem__(self, index: int):
        """Return ``(image, status, file_name)``; ``status`` is ``'YES'`` or ``'ERROR'``."""
        image_path = self.all_paths[index]
        if not os.path.exists(image_path):
            # Explicit check instead of ``assert`` (which is stripped under -O).
            raise FileNotFoundError(image_path)
        file_name = os.path.basename(image_path)
        try:
            # The context manager closes the file handle once preprocessing
            # has consumed the pixels.
            with Image.open(image_path) as img:
                image = self.preprocess(img)
        except Exception:
            # Best effort by design: any decoding/preprocessing failure marks
            # the sample as invalid. ``Exception`` (not ``BaseException``) so
            # Ctrl-C and SystemExit still propagate.
            return self.dummy, "ERROR", file_name
        return image, 'YES', file_name
if __name__ == '__main__':
    # Extract CLIP image embeddings for every readable actor profile picture
    # and store them, together with the file names, in one checkpoint file.
    # Usage: python extract_clip_face.py

    # clip_model_type = 'ViT-B/32'
    clip_model_type = 'ViT-L/14'

    # model (falls back to CPU so the script is usable without a GPU)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
    clip_model = clip_model.eval()

    # dataset
    image_root = '/scratch/shared/beegfs/htd/audiovault/actor_profiles'
    # available at: wget http://www.robots.ox.ac.uk/~htd/autoad/actor_profiles.tar
    ds = ImageFolder(image_root, preprocess)
    dl = torch.utils.data.DataLoader(ds, batch_size=200, shuffle=False, num_workers=8, drop_last=False)

    clip_model_name = clip_model_type.replace('/', '-')
    out_data_path = f"audiovault_face_{clip_model_name}.pth.tar"

    # main loop
    all_embeddings = []
    all_filenames = []
    all_valid_mask = []

    progress = tqdm(total=len(dl))
    for images, status, image_names in dl:
        images = images.to(device)
        with torch.no_grad():
            feature = clip_model.encode_image(images).cpu()
        # Samples the dataset could not decode are tagged "ERROR"; drop them.
        is_valid = [flag != "ERROR" for flag in status]
        mask = torch.tensor(is_valid)
        all_embeddings.append(feature[mask])
        all_filenames.extend(name for name, ok in zip(image_names, is_valid) if ok)
        all_valid_mask.append(mask)
        progress.update()
    progress.close()

    all_valid_mask = torch.cat(all_valid_mask, dim=0)
    clip_embedding = torch.cat(all_embeddings, dim=0)
    # Verify alignment BEFORE writing, so a corrupt file is never saved.
    if clip_embedding.shape[0] != len(all_filenames):
        raise RuntimeError(
            f'embedding/filename mismatch: {clip_embedding.shape[0]} vs {len(all_filenames)}')
    torch.save({"clip_embedding": clip_embedding, "filenames": all_filenames}, out_data_path)

    print(f'finished extracting {clip_model_type} features from {image_root}')
    print(f'Success rate: {all_valid_mask.float().mean()}')
class CharRecog(nn.Module):
    """Score each candidate-character exemplar against a clip of visual features.

    The exemplar ("face") features are the decoder input and the visual
    features are the decoder memory of a ``TemporalDecoder``; a linear head
    turns each decoded exemplar token into a single logit.
    """
    def __init__(self, dim, num_layers, use_proj=False):
        """
        Args:
            dim: feature width; must be a multiple of 64 (one head per 64 channels).
            num_layers: number of ``TemporalDecoder`` layers.
            use_proj: if True, apply a learned linear projection to the inputs.
        """
        super().__init__()
        self.dim = dim
        self.num_layers = num_layers
        assert dim % 64 == 0
        self.tfm = TemporalDecoder(width=dim, layers=num_layers, heads=dim//64)
        # learned positional table; forward() reads the first T rows
        self.pos_enc = nn.Embedding(256, embedding_dim=dim)
        nn.init.normal_(self.pos_enc.weight, mean=0, std=0.5)
        self.norm = nn.LayerNorm(dim)
        self.classifier = nn.Linear(dim, 1)
        nn.init.normal_(self.classifier.weight, std=0.01)
        self.use_proj = use_proj
        if use_proj:
            self.face_proj = nn.Linear(dim, dim)
            nn.init.normal_(self.face_proj.weight, std=0.01)
            nn.init.zeros_(self.face_proj.bias)

    def forward(self, face, face_padding_mask, visual):
        """Return one logit per exemplar slot.

        face: B,10,C
        face_padding_mask: B,10
        visual: B,T,C
        """
        # warning: tfm expects SEQ, B, C
        T = visual.shape[1]
        # NOTE(review): the table has 256 rows, so for T > 256 this slice is
        # shorter than T -- confirm callers never exceed that.
        pos = self.pos_enc.weight[None, 0:T]
        if self.use_proj:
            # the same projection layer is applied to exemplars and visual features
            face = self.face_proj(face)
            visual = self.face_proj(visual)
        out = self.tfm(x=face.transpose(0,1), memory=visual.transpose(0,1), tgt_key_padding_mask=face_padding_mask, pos=pos.transpose(0,1))
        # keep only the last element of what tfm returns
        # (presumably per-layer outputs -- TODO confirm against TemporalDecoder)
        out = out[-1].transpose(0,1)
        logits = self.classifier(self.norm(out))
        return logits
def train(loader, model, optimizer, lr_scheduler, grad_scaler, device, epoch, args, val_loader=None):
    """Run one training epoch of the character recognition model.

    For every batch the exemplar features are scored against the clip
    features, a class-balanced binary cross-entropy is computed over the
    non-padded exemplar slots, and one optimizer + scheduler step is taken.
    Per-iteration metrics go to ``args.train_plotter`` every 5 iterations,
    epoch averages at the end. ``args.iteration`` is advanced once per batch.

    ``grad_scaler`` and ``val_loader`` are accepted for interface
    compatibility and are not used.

    Returns:
        float: average training loss over the epoch.
    """
    batch_time = AverageMeter('Time',':.2f')
    data_time = AverageMeter('Data',':.2f')
    losses = AverageMeter('Loss',':.4f')
    progress = ProgressMeter(
        len(loader), [batch_time, data_time],
        prefix='Epoch:[{}]'.format(epoch))
    model.train()
    end = time.time()
    tic = time.time()
    optimizer.zero_grad()
    # Created lazily on the first batch; stays empty for an empty loader.
    avg_meters = {}

    for idx, input_data in enumerate(loader):
        data_time.update(time.time() - end)
        video_seq = input_data['video'].to(device, non_blocking=True)
        B,N,T,C = video_seq.shape
        # fold the clip dimension into the batch dimension
        video_seq = rearrange(video_seq, 'b n t c -> (b n) t c')

        exemplar_feature = input_data['exampler_feature'].to(device, non_blocking=True)
        exemplar_feature = rearrange(exemplar_feature, 'b n t c -> (b n) t c')
        exemplar_attn_mask = input_data['exampler_attn_mask'].to(device, non_blocking=True)
        exemplar_attn_mask = rearrange(exemplar_attn_mask, 'b n t -> (b n) t')
        valid_slots = exemplar_attn_mask.bool()

        tgt = input_data['binary_tgt'].to(device, non_blocking=True)
        tgt = rearrange(tgt, 'b n t -> (b n) t')

        logits = model(exemplar_feature, face_padding_mask=~valid_slots, visual=video_seq)
        assert logits.shape[2] == 1
        # only non-padded exemplar slots contribute to loss and metrics
        logits_flatten = logits[:,:,0][valid_slots]
        tgt_flatten = tgt[valid_slots]

        # Class balancing: positives and negatives each get half of the total weight.
        pos_rate = tgt_flatten.float().mean()
        if pos_rate.item() > 0:
            weight_flatten = torch.ones_like(logits_flatten) * 0.5/(1-pos_rate)
            weight_flatten.masked_fill_(tgt_flatten.bool(), value=0.5/pos_rate)
        else:
            print('warning: all zero labels')
            weight_flatten = torch.ones_like(logits_flatten)
        loss = F.binary_cross_entropy_with_logits(logits_flatten, tgt_flatten.float(), weight=weight_flatten)

        # precision / recall at the logit > 0 (probability > 0.5) threshold
        pred_pos = logits_flatten > 0
        true_pos = (pred_pos * tgt_flatten).float().sum()
        prec = true_pos / torch.clamp(pred_pos.float().sum(), min=1e-5)
        recall = true_pos / torch.clamp(tgt_flatten.float().sum(), min=1e-5)
        loss_dict = {'loss': loss.detach(), 'prec': prec, 'recall': recall}

        if not avg_meters:
            avg_meters = {k: AverageMeter(f'{k}:',':.4f') for k in loss_dict.keys()}
        for metric, value in loss_dict.items():
            avg_meters[metric].update(value.item(), B)
        # This meter feeds the return value; it was previously never updated.
        losses.update(loss.item(), B)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_time.update(time.time() - end)
        progress.display(idx)
        print('\t' + ' '.join([f"{k}:{v.item():.3f}" for k,v in loss_dict.items()]))
        lr_scheduler.step()

        if args.iteration % 5 == 0:
            for k, v in loss_dict.items():
                args.train_plotter.add_data(f'local/{k}', v.item(), args.iteration)

        end = time.time()
        args.iteration += 1

    print(f'epoch {epoch} finished, takes {time.time() - tic} seconds')
    for metric_name, avg_meter in avg_meters.items():
        args.train_plotter.add_data(f'global/{metric_name}', avg_meter.avg, epoch)
    return losses.avg
@torch.no_grad()
def evaluate(loader, model, device, epoch, args, prefix=None):
    """Evaluate the character recognition model on ``loader``.

    Computes the same class-balanced BCE loss, precision and recall as
    training. When ``args.test`` is set, per-clip probabilities are also
    collected, written to ``MAD_eval_prob.json`` and the process exits.
    Otherwise the epoch averages are sent to ``args.val_plotter``.

    ``prefix`` is accepted for interface compatibility and is not used.

    Returns:
        float: average validation loss (only reached when ``args.test`` is falsy).
    """
    model.eval()
    all_predictions = []
    # Created lazily on the first batch; stays empty for an empty loader.
    avg_meters = {}

    for idx, input_data in tqdm(enumerate(loader), total=len(loader)):
        video_seq = input_data['video'].to(device, non_blocking=True)
        B,N,T,C = video_seq.shape
        video_seq = rearrange(video_seq, 'b n t c -> (b n) t c')

        exemplar_feature = input_data['exampler_feature'].to(device, non_blocking=True)
        exemplar_feature = rearrange(exemplar_feature, 'b n t c -> (b n) t c')
        exemplar_attn_mask = input_data['exampler_attn_mask'].to(device, non_blocking=True)
        exemplar_attn_mask = rearrange(exemplar_attn_mask, 'b n t -> (b n) t')
        valid_slots = exemplar_attn_mask.bool()

        if isinstance(loader.dataset, LSMDC_NameLoader):
            # this loader provides text only, so labels are derived by name matching
            tgt = text_to_label(input_data['char_text'], input_data['text'])
            tgt = tgt.to(device, non_blocking=True)
        else:
            tgt = input_data['binary_tgt'].to(device, non_blocking=True)
        tgt = rearrange(tgt, 'b n t -> (b n) t')

        logits = model(exemplar_feature, face_padding_mask=~valid_slots, visual=video_seq)
        assert logits.shape[2] == 1
        logits_flatten = logits[:,:,0][valid_slots]
        tgt_flatten = tgt[valid_slots]

        # Class balancing: positives and negatives each get half of the total weight.
        pos_rate = tgt_flatten.float().mean()
        if pos_rate.item() > 0:
            weight_flatten = torch.ones_like(logits_flatten) * 0.5/(1-pos_rate)
            weight_flatten.masked_fill_(tgt_flatten.bool(), value=0.5/pos_rate)
        else:
            print('warning: all zero labels')
            weight_flatten = torch.ones_like(logits_flatten)
        loss = F.binary_cross_entropy_with_logits(logits_flatten, tgt_flatten.float(), weight=weight_flatten)

        pred_pos = logits_flatten > 0
        true_pos = (pred_pos * tgt_flatten).float().sum()
        prec = true_pos / torch.clamp(pred_pos.float().sum(), min=1e-5)
        recall = true_pos / torch.clamp(tgt_flatten.float().sum(), min=1e-5)
        loss_dict = {'loss': loss.detach(), 'prec': prec, 'recall': recall}

        if not avg_meters:
            avg_meters = {k: AverageMeter(f'{k}:',':.4f') for k in loss_dict.keys()}
        for metric, value in loss_dict.items():
            avg_meters[metric].update(value.item(), B)

        if args.test:
            probs = logits[:,:,0].sigmoid()
            # padded exemplar slots are marked with -1 instead of a probability
            probs = probs.masked_fill(~valid_slots, -1)
            # NOTE(review): indexing ``[0]`` below reads only the first sample
            # of the batch; main() builds the test loader with batch_size=1.
            for n_idx in range(N):
                all_predictions.append({'vid': input_data['vid'][0],
                                        'prob': probs[n_idx].tolist(),
                                        'start': input_data['start'][0][n_idx],
                                        'end': input_data['end'][0][n_idx],
                                        'movienet_tgt': tgt_flatten.tolist()})

    print(' '.join([f'{metric_name}: {avg_meter.avg}' for metric_name, avg_meter in avg_meters.items()]))
    if args.test:
        for pred_item in all_predictions:
            # numpy integers are not JSON serializable
            if isinstance(pred_item['vid'], np.integer):
                pred_item['vid'] = int(pred_item['vid'])
        # (a leftover interactive ipdb breakpoint used to sit here and blocked
        # every non-interactive test run)
        with open('MAD_eval_prob.json', 'w') as fobj:
            json.dump(all_predictions, fobj)
        sys.exit(0)

    for metric_name, avg_meter in avg_meters.items():
        args.val_plotter.add_data(f'global/{metric_name}', avg_meter.avg, epoch)
    return avg_meters['loss'].avg
def text_to_label(char_text, tgt_text, char_sep=',', max_chars=10):
    """Convert candidate-character prompts and target names into binary labels.

    Args:
        char_text: nested list ``[B][N]`` of strings; everything after the
            last ``'possible characters:'`` marker is the candidate list.
        tgt_text: nested list ``[B][N]`` of comma-separated target names.
            Entries containing ``'unknown'`` yield an all-zero row.
        char_sep: separator between candidate names.
            NOTE(review): the previous code called ``str.split('')``, which
            always raises ``ValueError: empty separator``. ``','`` mirrors the
            separator used for ``tgt_text`` -- confirm against the prompt
            format produced by the data loader.
        max_chars: number of label slots per clip; extra candidates are dropped.

    Returns:
        torch.LongTensor of shape ``(B, N, max_chars)``; slot ``i`` is 1 when
        the i-th candidate (punctuation removed, stripped) equals any target name.

    Raises:
        ValueError: if the two inputs do not have matching lengths.
    """
    if len(char_text) != len(tgt_text):
        raise ValueError(f'batch size mismatch: {len(char_text)} vs {len(tgt_text)}')
    B = len(char_text)
    N = len(char_text[0]) if B else 0
    tgt_tensor = torch.zeros(B, N, max_chars, dtype=torch.long)

    punct_table = str.maketrans('', '', string.punctuation)

    def clean_names(raw_names):
        """Strip punctuation and whitespace, dropping names that become empty."""
        cleaned = (name.translate(punct_table).strip() for name in raw_names)
        return [name for name in cleaned if name]

    for b_idx, (char_t, tgt_t) in enumerate(zip(char_text, tgt_text)):
        if len(char_t) != len(tgt_t):
            raise ValueError(f'clip count mismatch in sample {b_idx}: {len(char_t)} vs {len(tgt_t)}')
        for t_idx, (char, tgt) in enumerate(zip(char_t, tgt_t)):
            if 'unknown' in tgt:
                continue
            tgt_names = set(clean_names(tgt.split(',')))
            if not tgt_names:
                continue
            candidates = clean_names(char.split('possible characters:')[-1].split(char_sep))
            candidates = candidates[:max_chars]  # never write past the label slots
            if not candidates:
                continue
            tgt_tensor[b_idx, t_idx, 0:len(candidates)] = torch.tensor(
                [int(name in tgt_names) for name in candidates], dtype=torch.long)
    return tgt_tensor
# Translation table that deletes every ASCII punctuation character.
translator_rm_punct = str.maketrans('', '', string.punctuation)


def rm_punct(s):
    """Return ``s`` with all ASCII punctuation removed."""
    return s.translate(translator_rm_punct)


def get_dataset(args):
    """Build the train/val datasets and their data loaders.

    The training dataset class is selected by ``args.dataset``; validation
    always uses ``MovieNet_NameLoader`` with 16 clips and version
    ``'lsmdc_named'``.

    Returns:
        tuple: ``(train_dataset, val_dataset, train_loader, val_loader)``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name (previously an
            ``UnboundLocalError`` on ``trainD``).
    """
    batch_size = -1
    tokenizer = args.tokenizer

    if args.dataset == 'lsmdc_name':
        trainD = LSMDC_NameLoader
    elif args.dataset == 'mad_name':
        trainD = MAD_NameLoader
    elif args.dataset == 'movienet_name':
        trainD = MovieNet_NameLoader
    else:
        raise ValueError(
            f"unsupported dataset {args.dataset!r}; expected one of "
            "'lsmdc_name', 'mad_name', 'movienet_name'")

    train_dataset = trainD(
        tokenizer=tokenizer,
        mode='train',
        num_frames=args.num_frames,
        num_clips=args.num_clips,
        batch_size=batch_size,
        version=args.version,
        return_gpt_feature=False,  # args.context_perceiver_type!=0,
        clip_version=args.clip_version,
        return_subtitle_gpt_feature=args.subtitle_perceiver_type!=0,
        context_feature_type=args.context_feature_type,
        use_charbank=args.use_charbank,
        lookahead=args.lookahead,
        rephrase=args.rephrase,
        load_history=int(args.perceiver_type==4),
        force_resample=True,
    )

    val_dataset = MovieNet_NameLoader(
        tokenizer=tokenizer,
        mode='val',
        num_frames=args.num_frames,
        num_clips=16,  # args.num_clips,
        batch_size=batch_size,
        version='lsmdc_named',  # args.version,
        return_gpt_feature=False,  # args.context_perceiver_type!=0,
        clip_version=args.clip_version,
        return_subtitle_gpt_feature=args.subtitle_perceiver_type!=0,
        context_feature_type=args.context_feature_type,
        use_charbank=args.use_charbank,
        lookahead=args.lookahead,
        load_history=int(args.perceiver_type==4),
        force_resample=True,
    )

    # Shuffling is delegated to the samplers, so ``shuffle`` stays False.
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_loader = DataLoaderBG(train_dataset,
        batch_size=args.batch_size, num_workers=args.num_workers,
        collate_fn=train_dataset.collate_fn, pin_memory=True, drop_last=True,
        shuffle=False, sampler=train_sampler,
    )

    val_sampler = torch.utils.data.SequentialSampler(val_dataset)
    val_loader = DataLoaderBG(val_dataset,
        batch_size=args.batch_size, num_workers=args.num_workers//2,
        collate_fn=val_dataset.collate_fn, pin_memory=True, drop_last=False,
        shuffle=False, sampler=val_sampler,
    )
    return train_dataset, val_dataset, train_loader, val_loader
def get_model_card(tag):
    """Resolve a checkpoint alias to a path; returns ``(path, tag)``.

    The alias table is currently empty, so ``tag`` is returned unchanged.
    """
    model_card = {}
    return model_card.get(tag, tag), tag


def main(args):
    """Entry point: either run inference (``args.test``) or train the model.

    Test mode loads the checkpoint, evaluates on the test dataset and exits.
    Training mode builds the loaders, optionally resumes from ``args.resume``,
    then alternates train/eval per epoch and saves a checkpoint every
    ``args.eval_freq`` epochs and on the final epoch. Always ends with
    ``sys.exit(0)``.
    """
    device = setup(args)
    # feature width depends on the CLIP variant that produced the features
    if args.clip_version == 'L14':
        visual_dim = 768
    else:
        visual_dim = 512
    model = CharRecog(dim=visual_dim,
                      num_layers=args.num_layers,
                      use_proj=True)
    model.to(device)
    model_without_dp = model

    ### test ###
    if args.test:
        print(f"test from checkpoint {args.test}")
        args.test, _ = get_model_card(args.test)
        if os.path.exists(args.test):
            checkpoint = torch.load(args.test, map_location='cpu')
            state_dict = checkpoint['state_dict']
            args.start_epoch = checkpoint['epoch']+1
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint.get('best_acc', 0)
            try:
                model_without_dp.load_state_dict(state_dict)
            # NOTE(review): bare except also catches KeyboardInterrupt/SystemExit;
            # consider narrowing to ``except Exception``.
            except:
                print('[WARNING] Non-Equal load for resuming training!')
                neq_load_customized(model_without_dp, state_dict, verbose=True)
        else:
            print(f'{args.test} does not exists, test random init?')
            # NOTE(review): interactive debugger breakpoint; this blocks or
            # fails in non-interactive (batch) jobs.
            import ipdb; ipdb.set_trace()
            args.start_epoch = 1
            args.iteration = 0

        unit_test_feature_root = None
        print(f'test with {args.test_mode} mode')

        # The test dataset is switched by editing the lines below:
        # D = LSMDC_NameLoader
        # D = MovieNet_NameLoader
        # test_mode = 'test'

        # for MAD-TRAIN movies:
        # D = MAD_NameLoader; test_mode = 'train'

        # for MAD-EVAL movies:
        D = LSMDC_NameLoader; test_mode = 'test'

        test_dataset = D(
            tokenizer=args.tokenizer,
            mode=test_mode,
            num_frames=args.num_frames,
            num_clips=args.num_clips,
            unit_test_feature_root=unit_test_feature_root,
            version=args.version,
            clip_version=args.clip_version,
            test_version=args.test_version,
            return_gpt_feature=False,
            return_subtitle_gpt_feature=args.subtitle_perceiver_type!=0,
            context_feature_type=args.context_feature_type,
            use_charbank=args.use_charbank,
            lookahead=args.lookahead,
            load_history=int(args.perceiver_type==4),
            force_resample=True
        )
        test_sampler = torch.utils.data.SequentialSampler(test_dataset)

        # batch_size=1: evaluate() reads only element [0] of each batch when
        # collecting predictions
        loader = torch.utils.data.DataLoader(test_dataset,
            batch_size=1, num_workers=args.num_workers,
            collate_fn=test_dataset.collate_fn, pin_memory=True, drop_last=False,
            shuffle=False, sampler=test_sampler,
            worker_init_fn=set_worker_sharing_strategy
        )
        evaluate(loader, model, device, args.start_epoch, args)
        sys.exit(0)

    ### dataset ###
    _, _, train_loader, val_loader = get_dataset(args)

    ### optimizer ###
    params = model.parameters()
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    ### restart ###
    if args.resume:
        print(f"resume from checkpoint {args.resume}")
        args.resume, _ = get_model_card(args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        state_dict = checkpoint['state_dict']
        args.start_epoch = checkpoint['epoch']+1
        args.iteration = checkpoint['iteration']
        best_acc = checkpoint['best_acc']
        if args.convert_from_frozen_bn:
            # rename '.scale' to '.weight' for keys that contain '.bn'
            tmp_state_dict = {}
            for k,v in state_dict.items():
                if '.bn' in k:
                    tmp_state_dict[k.replace('.scale', '.weight')] = v
                else:
                    tmp_state_dict[k] = v
            state_dict = tmp_state_dict

        try:
            model_without_dp.load_state_dict(state_dict)
        # NOTE(review): bare except; see the note in the test branch above.
        except:
            # fall back to a non-strict load and ask the user before continuing
            missing_keys, unexpected_keys = model_without_dp.load_state_dict(state_dict, strict=False)
            if len(missing_keys):
                print(f'[Missing keys]:{"="*12}\n{chr(10).join(missing_keys)}\n{"="*20}')
            if len(unexpected_keys):
                print(f'[Unexpected keys]:{"="*12}\n{chr(10).join(unexpected_keys)}\n{"="*20}')
            user_input = input('[WARNING] Non-Equal load for resuming training, continue? [y/n]')
            if user_input.lower() == 'n':
                sys.exit()
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception as e:
            print(f'Not resuming optimizer states due to Error: {e}\nInitialized the optimizer instead...')
    ### restart ###

    args.decay_steps = args.epochs * len(train_loader)
    # warm up for the first 5% of the training epochs
    args.warmup_epochs = float(args.epochs / 20)
    def lr_schedule_fn(iteration, iter_per_epoch, args):
        """Linear warm-up followed by cosine decay; returns the LR multiplier."""
        if iteration < args.warmup_epochs * iter_per_epoch:
            lr_multiplier = iteration / (args.warmup_epochs * iter_per_epoch)
        else:
            lr_multiplier = 0.5 * \
                (1. + math.cos(math.pi * (iteration - args.warmup_epochs*iter_per_epoch) / (args.epochs*iter_per_epoch - args.warmup_epochs*iter_per_epoch)))
        return lr_multiplier
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, functools.partial(lr_schedule_fn, iter_per_epoch=len(train_loader), args=args)
    )
    if args.resume:
        lr_scheduler.step(args.iteration) # for resume mode
    # created here and passed to train(), which does not use it
    grad_scaler = amp.GradScaler()
    torch.manual_seed(0)

    # NOTE(review): this overwrites the best_acc restored from the checkpoint
    # in resume mode. The value tracks the lowest validation loss.
    best_acc = 100
    # baseline evaluation before any training, logged as epoch -1
    evaluate(val_loader, model, device, -1, args)
    # main loop
    for epoch in range(args.start_epoch, args.epochs):
        np.random.seed(epoch)
        random.seed(epoch)
        train(train_loader, model, optimizer, lr_scheduler, grad_scaler, device, epoch, args)
        val_loss = evaluate(val_loader, model, device, epoch, args)
        if (epoch % args.eval_freq == 0) or (epoch == args.epochs - 1):
            # is_best = val_loss < best_acc  # temporary use val loss
            is_best = False  # rewritten
            best_acc = min(val_loss, best_acc)
            state_dict = model_without_dp.state_dict()
            save_dict = {
                'epoch': epoch,
                'state_dict': state_dict,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'iteration': args.iteration}
            save_checkpoint(save_dict, is_best, args.eval_freq,
                filename=os.path.join(args.model_path, 'epoch%d.pth.tar' % epoch),
                keep_all=True)
    print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs))
    sys.exit(0)

if __name__ == '__main__':
    args = parse_args()
    main(args)
def recall_within_neighbours(sentences_gt, sentences_gen, topk=(1,5), N=16, scorer=None):
    """Compute R@k/N as described in AutoAD-II
    (https://www.robots.ox.ac.uk/~vgg/publications/2023/Han23a/han23a.pdf).

    This metric compares a (long) list of generated sentences with the aligned
    list of ground-truth sentences: the i-th generated sentence is expected to
    retrieve the i-th ground-truth sentence among its N neighbours.
    It uses BERTScore (https://github.com/Tiiiger/bert_score) to compute
    sentence-sentence similarity, but uses the relative BERTScore values to get
    a recall, for robustness.

    Args:
        sentences_gt: list of ground-truth sentences.
        sentences_gen: list of generated sentences, same length and order as
            `sentences_gt`.
        topk: iterable of k values; every k must be <= N.
        N: size of the sliding neighbourhood window (>= 2). Windows advance
            with a stride of N // 2.
        scorer: optional object exposing `score(cands, refs)` that returns a
            tuple whose last element is a 1-D tensor of similarities.
            Defaults to the module-level `bert_scorer`.

    Returns:
        A list of floats, the recall for each k in `topk`, averaged over all
        windows. The values are also printed.

    Raises:
        ValueError: if the two lists differ in length, if N < 2, if `topk` is
            empty or contains a k larger than N, or if there are fewer than N
            sentences (no complete window can be formed).
    """
    topk = tuple(topk)
    num_sent = len(sentences_gen)

    # Validate up front: otherwise these cases surface as obscure errors from
    # range() / torch.stack() / Tensor.topk() or as silently wrong windows.
    if len(sentences_gt) != num_sent:
        raise ValueError(
            f"sentences_gt and sentences_gen must be aligned and of equal length, "
            f"got {len(sentences_gt)} and {num_sent}")
    if N < 2:
        raise ValueError(f"N must be at least 2 (the window stride is N // 2), got {N}")
    if not topk or max(topk) > N:
        raise ValueError(f"topk must be non-empty with every k <= N={N}, got {topk}")
    if num_sent < N:
        raise ValueError(
            f"need at least N={N} sentences to form one window, got {num_sent}")

    if scorer is None:
        scorer = bert_scorer  # module-level default scorer

    # get sentence-sentence BertScore:
    # row r holds the score of generated sentence r against every GT sentence
    # (the last element of the scorer output is used as the similarity).
    score_rows = [
        scorer.score(sentences_gt, [sent] * len(sentences_gt))[-1]
        for sent in sentences_gen
    ]
    ss_score = torch.stack(score_rows, dim=0)

    # Within each N x N window along the diagonal, the correct "class" of the
    # j-th generated sentence is the j-th ground-truth sentence.
    target = torch.arange(N, device=ss_score.device)
    topk_output = [
        calc_topk_accuracy(ss_score[i:i + N, i:i + N], target, topk=topk)
        for i in range(0, num_sent - N + 1, N // 2)
    ]

    topk_avg = torch.stack(topk_output, 0).mean(0).tolist()
    for k, res in zip(topk, topk_avg):
        print(f"Recall@{k}/{N}: {res:.3f}")
    return topk_avg


def calc_topk_accuracy(output, target, topk=(1,)):
    """
    Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e
    Given predicted and ground truth labels, calculate top-k accuracies.

    Args:
        output: 2-D score tensor, one row per sample and one column per class.
        target: 1-D tensor with the correct column index for each row.
        topk: iterable of k values.

    Returns:
        A 1-D tensor with one accuracy (fraction in [0, 1]) per k in `topk`.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk best-scoring columns per row, best first
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (maxk, batch): row j holds every sample's rank-j guess
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # a sample is a hit at k if its target appears among the first k guesses
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1 / batch_size))
    return torch.stack(res)


if __name__ == '__main__':
    # Example
    # in practice, we put all the ADs of a movie in sentences_gt and sentences_gen
    sentences_gt = [
        "Inside the flat, Tony switches on the light.",
        "Mark and Margo and Inspector Hubbard stand facing him.",
        "He quickly opens the door.",
        "A detective stands outside.",
        "Tony slowly closes the door again.",
        "Distressed, Margo turns away.",
        "Tony shrugs, looks at the drinks on the sideboard, then walks calmly to the desk and picks up a bottle of whisky.",
        "He puts the key on the desk.",
        "Taking a glass from the sideboard, he pours himself a large drink.",
    ]
    sentences_gen = [
        "Tony stands.",
        "They stand facing him.",
        "He stand.",
        "A man stands.",
        "Tony closes the door.",
        "Margo turns away.",
        "Tony walks around.",
        "He walks towards the desk.",
        "He holds a book.",
    ]
    result = recall_within_neighbours(sentences_gt, sentences_gen, topk=(1,3), N=4)
    # Should get
    # Recall@1/4: 0.667
    # Recall@3/4: 0.917
Dubose", 44 | "Tom Robinson", 45 | "Calpurnia", 46 | "Judge Taylor", 47 | "Mayella Violet Ewell", 48 | "Bob Ewell" 49 | ], 50 | "tt0059742": [ 51 | "Maria", 52 | "Captain Georg von Trapp", 53 | "The Baroness", 54 | "Max Detweiler", 55 | "Mother Abbess", 56 | "Liesl von Trapp", 57 | "Louisa von Trapp", 58 | "Friedrich von Trapp", 59 | "Kurt von Trapp", 60 | "Brigitta von Trapp" 61 | ], 62 | "tt0064115": [ 63 | "Butch Cassidy", 64 | "The Sundance Kid", 65 | "Etta Place", 66 | "Percy Garris", 67 | "Bike Salesman", 68 | "Sheriff Bledsoe", 69 | "Woodcock", 70 | "Agnes", 71 | "Harvey Logan", 72 | "Marshal" 73 | ], 74 | "tt0066999": [ 75 | "Harry", 76 | "Bressler", 77 | "Chico", 78 | "The Mayor", 79 | "Killer", 80 | "Chief", 81 | "De Georgio", 82 | "Mrs. Russell", 83 | "Norma", 84 | "Bus Driver" 85 | ], 86 | "tt0079592": [ 87 | "Sherlock Holmes", 88 | "Dr. John H. Watson", 89 | "Inspector Foxborough", 90 | "Mary Kelly", 91 | "Sir Charles Warren", 92 | "Prime Minister Lord Salisbury", 93 | "Inspector Lestrade", 94 | "Robert Lees", 95 | "Annie Crook", 96 | "Doctor Hardy" 97 | ], 98 | "tt0081398": [ 99 | "Jake La Motta", 100 | "Vickie La Motta", 101 | "Joey", 102 | "Salvy", 103 | "Tommy Como", 104 | "Lenore", 105 | "Mario", 106 | "Patsy", 107 | "Guido", 108 | "Toppy" 109 | ], 110 | "tt0082495": [ 111 | "Laurie Strode", 112 | "Sam Loomis", 113 | "Leigh Brackett", 114 | "Graham", 115 | "Jimmy", 116 | "Karen", 117 | "Gary Hunt", 118 | "The Shape / Patrolman #3", 119 | "Budd", 120 | "Mrs. Alves" 121 | ], 122 | "tt0083929": [ 123 | "Jeff Spicoli", 124 | "Stacy Hamilton", 125 | "Brad Hamilton", 126 | "Mike Damone", 127 | "Mark 'Rat' Ratner", 128 | "Linda Barrett", 129 | "Mr. Hand", 130 | "Arnold", 131 | "Mr. Vargas", 132 | "Lisa" 133 | ], 134 | "tt0088763": [ 135 | "Marty McFly", 136 | "Dr. 
Emmett Brown", 137 | "Lorraine Baines", 138 | "George McFly", 139 | "Biff Tannen", 140 | "Jennifer Parker", 141 | "Dave McFly", 142 | "Linda McFly", 143 | "Sam Baines", 144 | "Stella Baines" 145 | ], 146 | "tt0092005": [ 147 | "Gordie Lachance", 148 | "Chris Chambers", 149 | "Teddy Duchamp", 150 | "Vern Tessio", 151 | "Ace Merrill", 152 | "Billy Tessio", 153 | "Charlie Hogan", 154 | "Eyeball Chambers", 155 | "Vince Desjardins", 156 | "Mr. Lachance" 157 | ], 158 | "tt0096969": [ 159 | "Ron Kovic", 160 | "Young Ron", 161 | "Mr. Kovic", 162 | "Mrs. Kovic", 163 | "Tommy Kovic", 164 | "Young Tommy", 165 | "Jimmy Kovic", 166 | "Young Jimmy", 167 | "Susanne Kovic", 168 | "Young Susanne" 169 | ], 170 | "tt0099204": [ 171 | "Joey O'Brien", 172 | "Larry", 173 | "Tina", 174 | "Joy Munchack", 175 | "Harry Munchack", 176 | "Lila", 177 | "Donna", 178 | "Little Jack Turgeon", 179 | "Big Jack Turgeon", 180 | "Benny" 181 | ], 182 | "tt0101912": [ 183 | "Johnny", 184 | "Frankie", 185 | "Nick", 186 | "Tim", 187 | "Cora", 188 | "Nedda", 189 | "Tino", 190 | "Luther", 191 | "Artemis", 192 | "Jorge" 193 | ], 194 | "tt0101917": [ 195 | "Freddy Krueger", 196 | "Maggie Burroughs", 197 | "John Doe", 198 | "Tracy", 199 | "Carlos", 200 | "Spencer", 201 | "Doc", 202 | "Childless Man", 203 | "Childless Woman", 204 | "Orphanage Woman" 205 | ], 206 | "tt0103919": [ 207 | "Helen Lyle", 208 | "The Candyman", 209 | "Trevor Lyle", 210 | "Bernadette Walsh", 211 | "Anne-Marie McCoy", 212 | "Jake", 213 | "Clara", 214 | "Billy", 215 | "Monica", 216 | "Student" 217 | ], 218 | "tt0106598": [ 219 | "Air Traffic Controller", 220 | "Captain Air Traffic", 221 | "F-16 Pilot", 222 | "Beldar Conehead / Donald R. 
DeCicco", 223 | "Prymatt Conehead / Mary Margaret DeCicco", 224 | "Ang Pilot", 225 | "Motel Clerk", 226 | "Customer", 227 | "Otto", 228 | "Marlax" 229 | ], 230 | "tt0107616": [ 231 | "Leonato", 232 | "Hero", 233 | "Margaret", 234 | "Friar Francis", 235 | "Antonio", 236 | "George Seacole", 237 | "Francis Seacole", 238 | "Hugh Oatcake", 239 | "Ursula", 240 | "Beatrice" 241 | ], 242 | "tt0110366": [ 243 | "Spanky", 244 | "Stymie", 245 | "Froggy", 246 | "Porky", 247 | "Buckwheat", 248 | "Uh-Huh", 249 | "Butch", 250 | "Woim", 251 | "Waldo", 252 | "Mary Ann" 253 | ], 254 | "tt0110413": [ 255 | "Leon", 256 | "Stansfield", 257 | "Mathilda", 258 | "Tony", 259 | "Malky", 260 | "1st Stansfield Man", 261 | "2nd Stansfield Man", 262 | "3rd Stansfield Man", 263 | "4th Stansfield Man", 264 | "Mathilda's Father" 265 | ], 266 | "tt0112442": [ 267 | "Girl Decoy", 268 | "Carjacker", 269 | "Eddie Dominguez", 270 | "Fouchet", 271 | "Noah Trafficante", 272 | "Kuni", 273 | "Ferguson", 274 | "Casper", 275 | "Andy", 276 | "Marcus Burnett" 277 | ], 278 | "tt0112572": [ 279 | "Carol Brady", 280 | "Mike Brady", 281 | "Marcia Brady", 282 | "Greg Brady", 283 | "Jan Brady", 284 | "Peter Brady", 285 | "Cindy Brady", 286 | "Bobby Brady", 287 | "Alice Nelson", 288 | "Sam Franklin" 289 | ], 290 | "tt0112573": [ 291 | "Young William", 292 | "Malcolm Wallace", 293 | "John Wallace", 294 | "Campbell", 295 | "MacClannough", 296 | "Elder Stewart", 297 | "Young Hamish", 298 | "Mother MacClannough", 299 | "Priest No. 
1", 300 | "Young Murron" 301 | ], 302 | "tt0114608": [ 303 | "Jeryline", 304 | "Cordelia", 305 | "Irene", 306 | "Uncle Willy", 307 | "Roach", 308 | "Sheriff Tupper", 309 | "Wally Enfield", 310 | "Homer", 311 | "Wanda", 312 | "Danny" 313 | ], 314 | "tt0114814": [ 315 | "McManus", 316 | "Keaton", 317 | "Fenster", 318 | "Hockney", 319 | "Verbal", 320 | "Dave Kujan", 321 | "Kobayashi", 322 | "Edie Finneran", 323 | "Jack Baer", 324 | "Jeff Rabin" 325 | ], 326 | "tt0115783": [ 327 | "Keats", 328 | "Moses", 329 | "Colton", 330 | "Bledsoe", 331 | "Capt. Jensen", 332 | "Traci", 333 | "Detective Sulliman", 334 | "Detective Jones", 335 | "Finch", 336 | "Charles" 337 | ], 338 | "tt0116705": [ 339 | "Howard Langston", 340 | "Myron Larabee", 341 | "Ted Maltin", 342 | "Liz Langston", 343 | "Officer Hummell", 344 | "D.J.", 345 | "Jamie Langston", 346 | "Mall Santa", 347 | "Johnny", 348 | "First Lady" 349 | ], 350 | "tt0119528": [ 351 | "Fletcher Reede", 352 | "Audrey Reede", 353 | "Max Reede", 354 | "Jerry", 355 | "Greta", 356 | "Samantha Cole", 357 | "Miranda", 358 | "Judge Marshall Stevens", 359 | "Dana Appleton", 360 | "Mr. Allan" 361 | ], 362 | "tt0120694": [ 363 | "Laurie Strode / Keri Tate", 364 | "Will Brennan", 365 | "Molly", 366 | "Charlie", 367 | "Sarah", 368 | "Norma", 369 | "John", 370 | "Ronny", 371 | "Jimmy", 372 | "Tony" 373 | ], 374 | "tt0120784": [ 375 | "Porter", 376 | "Val Resnick", 377 | "Rosie", 378 | "Arthur Stegman", 379 | "Det. Hicks", 380 | "Mrs. Lynn Porter", 381 | "Phil", 382 | "Carter", 383 | "Pearl", 384 | "Det. 
Leary" 385 | ], 386 | "tt0120787": [ 387 | "Steven Taylor", 388 | "Emily Bradford Taylor", 389 | "David Shaw", 390 | "Mohamed Karaman", 391 | "Raquel Martinez", 392 | "Bobby Fain", 393 | "Ambassador Alice Wills", 394 | "Sandra Bradford", 395 | "Jason Gates", 396 | "Ann Gates" 397 | ], 398 | "tt0134119": [ 399 | "Tom Ripley", 400 | "Marge Sherwood", 401 | "Dickie Greenleaf", 402 | "Meredith Logue", 403 | "Freddie Miles", 404 | "Peter Smith-Kingsley", 405 | "Herbert Greenleaf", 406 | "Inspector Roverini", 407 | "Alvin MacCarron", 408 | "Aunt Joan" 409 | ], 410 | "tt0159273": [ 411 | "Burnett", 412 | "Reigart", 413 | "Stackhouse", 414 | "Rodway", 415 | "Piquet", 416 | "O'Malley", 417 | "Lokar", 418 | "Tracker", 419 | "Bazda", 420 | "Petty Officer Kennedy" 421 | ], 422 | "tt0161081": [ 423 | "Claire Spencer", 424 | "Caitlin Spencer", 425 | "Mary Feur", 426 | "Warren Feur", 427 | "Norman Spencer", 428 | "Beatrice", 429 | "Jody", 430 | "PhD Student #1", 431 | "PhD Student #2", 432 | "Teddy" 433 | ], 434 | "tt0168501": [ 435 | "Harper Stewart", 436 | "Jordan Armstrong", 437 | "Lance Sullivan", 438 | "Julian Murch", 439 | "Quentin", 440 | "Robin", 441 | "Mia Morgan", 442 | "Shelby", 443 | "Anita", 444 | "Candy" 445 | ], 446 | "tt0181689": [ 447 | "Chief John Anderton", 448 | "Director Lamar Burgess", 449 | "Jad", 450 | "Fletcher", 451 | "Knott", 452 | "Evanna", 453 | "Pre-Crime Cop", 454 | "Pre-Crime Cop", 455 | "Pre-Crime Cop", 456 | "Pre-Crime Cop" 457 | ], 458 | "tt0186566": [ 459 | "Frank Corvin", 460 | "Hawk Hawkins", 461 | "Jerry O'Neill", 462 | "Tank Sullivan", 463 | "Bob Gerson", 464 | "Sara Holland", 465 | "Eugene Davis", 466 | "Ethan Glance", 467 | "Roger Hines", 468 | "Barbara Corvin" 469 | ], 470 | "tt0212338": [ 471 | "Jack Byrnes", 472 | "Greg Focker", 473 | "Pam Byrnes", 474 | "Dina Byrnes", 475 | "Deborah Byrnes", 476 | "Denny Byrnes", 477 | "Kevin Rawley", 478 | "Dr. Larry Banks", 479 | "Dr. 
Bob Banks", 480 | "Linda Banks" 481 | ], 482 | "tt0243133": [ 483 | "Ed Crane", 484 | "Doris Crane", 485 | "Frank", 486 | "Big Dave Brewster", 487 | "Ann Nirdlinger Brewster", 488 | "Creighton Tolliver", 489 | "Birdy Abundas", 490 | "Walter Abundas", 491 | "Freddy Riedenschneider", 492 | "Officer Persky" 493 | ], 494 | "tt0257076": [ 495 | "Sgt. Dan 'Hondo' Harrelson", 496 | "Jim Street", 497 | "Chris Sanchez", 498 | "Deacon 'Deke' Kaye", 499 | "T.J. McCabe", 500 | "Brian Gamble", 501 | "Michael Boxer", 502 | "Alex Montel", 503 | "Lt. Greg Velasquez", 504 | "Capt. Thomas Fuller" 505 | ], 506 | "tt0268380": [ 507 | "Manfred", 508 | "Sid", 509 | "Diego", 510 | "Soto", 511 | "Zeke", 512 | "Carl", 513 | "Frank / Start", 514 | "Oscar", 515 | "Lenny / Oscar / Dab", 516 | "Jennifer" 517 | ], 518 | "tt0274166": [ 519 | "Johnny English", 520 | "Exotic Woman", 521 | "Bough", 522 | "Agent One", 523 | "Carlos Vendetta", 524 | "Dieter Klein", 525 | "Official at Funeral", 526 | "Prime Minister", 527 | "Pegasus", 528 | "Pegasus' Secretary" 529 | ], 530 | "tt0312528": [ 531 | "The Cat", 532 | "Quinn", 533 | "Mom", 534 | "Sally", 535 | "Conrad", 536 | "Mrs. Kwan", 537 | "Mr. Humberfloob / Voice of the Fish", 538 | "Thing One", 539 | "Thing One", 540 | "Thing Two" 541 | ], 542 | "tt0322259": [ 543 | "Brian O'Conner", 544 | "Roman Pearce", 545 | "Monica Fuentes", 546 | "Carter Verone", 547 | "Tej", 548 | "Agent Bilkins", 549 | "Agent Markham", 550 | "Suki", 551 | "Orange Julius", 552 | "Slap Jack" 553 | ], 554 | "tt0356634": [ 555 | "Jon", 556 | "Liz", 557 | "Happy Chapman", 558 | "Garfield", 559 | "Wendell", 560 | "Christopher Mello", 561 | "Miss Ace Hardware", 562 | "Announcer", 563 | "Dog Owner #1", 564 | "Dog Owner #2" 565 | ], 566 | "tt0367652": [ 567 | "Deuce Bigalow", 568 | "T.J. 
Hicks", 569 | "Gaspar Voorsboch", 570 | "Heinz Hummer", 571 | "Chadsworth Buckingham, III", 572 | "Rodrigo", 573 | "Gian-Carlo", 574 | "Eva", 575 | "Enzo Giarraputo", 576 | "Assapopoulos Mariolis" 577 | ], 578 | "tt0375173": [ 579 | "Alfie", 580 | "Lu Schnitman", 581 | "Dorie", 582 | "Phil", 583 | "Julie", 584 | "Terry", 585 | "Max", 586 | "Marlon", 587 | "Lonette", 588 | "Wing" 589 | ], 590 | "tt0377091": [ 591 | "Sam", 592 | "Clyde", 593 | "Marty", 594 | "Rocky", 595 | "George", 596 | "Millie", 597 | "Kile", 598 | "Maggie Tooney", 599 | "Jasper", 600 | "Cashier" 601 | ], 602 | "tt0377107": [ 603 | "Catherine", 604 | "Robert", 605 | "Harold Dobbs - Hal", 606 | "Cop", 607 | "Claire", 608 | "Limo Driver", 609 | "Professor Barrow", 610 | "Friend at Party", 611 | "Friend at Party", 612 | "Theoretical Physicist" 613 | ], 614 | "tt0387575": [ 615 | "Jennifer Tilly / Tiffany", 616 | "Chucky", 617 | "Glen / Glenda", 618 | "Redman", 619 | "Joan", 620 | "Pete Peters", 621 | "Psychs", 622 | "Stan", 623 | "Tony Gardner", 624 | "Santa" 625 | ], 626 | "tt0389860": [ 627 | "Michael Newman", 628 | "Donna Newman", 629 | "Morty", 630 | "Ammer", 631 | "Ted Newman", 632 | "Trudy Newman", 633 | "Bill", 634 | "Ben Newman at 7 Years Old", 635 | "Ben at 17 Years Old", 636 | "Ben Newman at 22-30 Years Old" 637 | ], 638 | "tt0402022": [ 639 | "Aeon Flux", 640 | "Trevor Goodchild", 641 | "Oren Goodchild", 642 | "Sithandra", 643 | "Handler", 644 | "Keeper", 645 | "Una Flux", 646 | "Freya", 647 | "Claudius", 648 | "Giroux" 649 | ], 650 | "tt0407887": [ 651 | "Billy", 652 | "Colin", 653 | "Costello", 654 | "Dignam", 655 | "Queenan", 656 | "Mr. French", 657 | "Madolyn", 658 | "Brown", 659 | "Ellerby", 660 | "Cousin Sean" 661 | ], 662 | "tt0409459": [ 663 | "Laurie Jupiter / Silk Spectre II", 664 | "Dr. 
Manhattan / Jon Osterman", 665 | "Adrian Veidt / Ozymandias", 666 | "Rorschach", 667 | "Edward Blake / Comedian", 668 | "Dan Dreiberg / Nite Owl", 669 | "Sally Jupiter / Silk Spectre", 670 | "Moloch", 671 | "Hollis Mason", 672 | "Janey Slater" 673 | ], 674 | "tt0414853": [ 675 | "Otis the Cow", 676 | "Daisy the Cow", 677 | "Ben the Cow", 678 | "Miles the Mule", 679 | "Bessy the Cow", 680 | "Etta the Hen", 681 | "Dag the Coyote", 682 | "Pip the Mouse", 683 | "Freddy the Ferret", 684 | "Peck the Rooster / Gopher" 685 | ], 686 | "tt0441773": [ 687 | "Po", 688 | "Shifu", 689 | "Tigress", 690 | "Tai Lung", 691 | "Monkey", 692 | "Mantis", 693 | "Viper", 694 | "Crane", 695 | "Oogway", 696 | "Mr. Ping" 697 | ], 698 | "tt0452608": [ 699 | "Jensen Ames", 700 | "Hennessey", 701 | "Coach", 702 | "Machine Gun Joe", 703 | "Case", 704 | "Pachenko", 705 | "Ulrich", 706 | "Lists", 707 | "Gunner", 708 | "Travis Colt" 709 | ], 710 | "tt0454945": [ 711 | "Viola", 712 | "Duke", 713 | "Olivia Lennox", 714 | "Dinklage", 715 | "Gold", 716 | "Daphne", 717 | "Justin", 718 | "Monique", 719 | "Paul", 720 | "Kia" 721 | ], 722 | "tt0478304": [ 723 | "Mr. O'Brien", 724 | "Jack", 725 | "Mrs. O'Brien", 726 | "Young Jack", 727 | "R.L.", 728 | "Steve", 729 | "Grandmother", 730 | "Guide", 731 | "Mr. Reynolds", 732 | "Architect" 733 | ], 734 | "tt0493464": [ 735 | "Wesley", 736 | "Sloan", 737 | "Fox", 738 | "Pekwarsky", 739 | "Cross", 740 | "The Gunsmith", 741 | "Cathy", 742 | "The Repairman", 743 | "Mr. 
X", 744 | "The Exterminator" 745 | ], 746 | "tt0758758": [ 747 | "Chris McCandless", 748 | "Billie McCandless", 749 | "Walt McCandless", 750 | "Carine McCandless / Additional Narrator", 751 | "Rainey", 752 | "Jan Burres", 753 | "Wayne Westerberg", 754 | "Tracy Tatro", 755 | "Ron Franz", 756 | "Jim Gallien" 757 | ], 758 | "tt0803096": [ 759 | "Anduin Lothar", 760 | "Garona", 761 | "Medivh", 762 | "Llane Wrynn", 763 | "Durotan / Antonidas", 764 | "Khadgar", 765 | "Orgrim", 766 | "Blackhand", 767 | "Gul'dan", 768 | "Lady Taria" 769 | ], 770 | "tt0872230": [ 771 | "Bug", 772 | "Alex", 773 | "Jerome", 774 | "Penelope", 775 | "Brandon", 776 | "Brittany", 777 | "Jay", 778 | "Fang", 779 | "Abel", 780 | "May" 781 | ], 782 | "tt0887912": [ 783 | "Staff Sergeant William James", 784 | "Sergeant JT Sanborn", 785 | "Specialist Owen Eldridge", 786 | "Sergeant Matt Thompson", 787 | "Contractor Team Leader", 788 | "Colonel Reed", 789 | "Connie James", 790 | "Colonel John Cambridge", 791 | "Black Suit Man", 792 | "Beckham" 793 | ], 794 | "tt0898367": [ 795 | "Man", 796 | "Boy", 797 | "Old Man", 798 | "Veteran", 799 | "Motherly Woman", 800 | "Thief", 801 | "Gang Member", 802 | "Woman", 803 | "Bearded Man", 804 | "Archer's Woman" 805 | ], 806 | "tt0901476": [ 807 | "Liv", 808 | "Emma", 809 | "Nate", 810 | "Fletcher", 811 | "Daniel", 812 | "Marion", 813 | "Deb", 814 | "Kevin", 815 | "Colson", 816 | "Kathy" 817 | ], 818 | "tt0905372": [ 819 | "Kate Lloyd", 820 | "Carter", 821 | "Dr. 
Sander Halvorson", 822 | "Adam Finch", 823 | "Jameson", 824 | "Griggs", 825 | "Edvard Wolner", 826 | "Juliette", 827 | "Lars", 828 | "Olav" 829 | ], 830 | "tt0963794": [ 831 | "Jeff", 832 | "Amy", 833 | "Stacy", 834 | "Eric", 835 | "Mathias", 836 | "Lead Mayan", 837 | "Mayan Bowman", 838 | "Mayan Horseman", 839 | "Dimitri", 840 | "Taxi Driver" 841 | ], 842 | "tt0981227": [ 843 | "Nick", 844 | "Norah", 845 | "Thom", 846 | "Dev", 847 | "Caroline", 848 | "Tris", 849 | "Beefy Guy", 850 | "Gary", 851 | "Tal", 852 | "Bishop Allen" 853 | ], 854 | "tt1013743": [ 855 | "Roy Miller", 856 | "June Havens", 857 | "Fitzgerald", 858 | "Antonio", 859 | "Director George", 860 | "Simon Feck", 861 | "Bernhard", 862 | "Rodney", 863 | "Braces", 864 | "April Havens" 865 | ], 866 | "tt1092026": [ 867 | "Young Tara", 868 | "Graeme Willy", 869 | "Clive Gollings", 870 | "Sword Vendor", 871 | "Adam Shadowchild", 872 | "Security Guard", 873 | "Adam Shadowchild Fan", 874 | "Jorge", 875 | "Valet", 876 | "Pat Stevens" 877 | ], 878 | "tt1127180": [ 879 | "Christine Brown", 880 | "Clay Dalton", 881 | "Mrs. Ganush", 882 | "Rham Jas", 883 | "Mr. Jacks", 884 | "Shaun San Dena", 885 | "Leonard Dalton", 886 | "Stu Rubin", 887 | "Trudy Dalton", 888 | "Ilenka Ganush" 889 | ], 890 | "tt1261945": [ 891 | "Carrie Bradshaw", 892 | "Charlotte York", 893 | "Miranda Hobbes", 894 | "Samantha Jones", 895 | "Bergdorf Salesgirl", 896 | "Mr. 
Big", 897 | "Steve", 898 | "Harry", 899 | "Lily", 900 | "Lily" 901 | ], 902 | "tt1263670": [ 903 | "Bad Blake", 904 | "Manager", 905 | "Barmaid", 906 | "Jack Greene", 907 | "Bill Wilson", 908 | "Tony", 909 | "Jo Ann", 910 | "Wesley Barnes", 911 | "Jean Craddock", 912 | "Ann" 913 | ], 914 | "tt1291584": [ 915 | "Brendan Conlon", 916 | "Tommy Conlon", 917 | "Paddy Conlon", 918 | "Tess Conlon", 919 | "Frank Campana", 920 | "Principal Zito", 921 | "Colt Boyd", 922 | "Bryan Callen", 923 | "Sam Sheridan", 924 | "Fenroy" 925 | ], 926 | "tt1386703": [ 927 | "Douglas Quaid / Hauser", 928 | "Lori Quaid", 929 | "Melina", 930 | "Cohaagen", 931 | "Harry", 932 | "Matthias", 933 | "McClane", 934 | "Marek", 935 | "Resistance Fighter", 936 | "Military Adjutant" 937 | ], 938 | "tt1412386": [ 939 | "Evelyn Greenslade", 940 | "Graham Dashwood", 941 | "Graham's Colleague", 942 | "Judge", 943 | "Estate Agent", 944 | "Douglas Ainslie", 945 | "Jean Ainslie", 946 | "Muriel Donnelly", 947 | "Staff Nurse", 948 | "Dr. Ghujarapartidar" 949 | ], 950 | "tt1536044": [ 951 | "Surveillance Camera Expert", 952 | "Daniel Rey", 953 | "Ali Rey", 954 | "Katie", 955 | "Brad", 956 | "Kristi Rey", 957 | "Hunter Rey", 958 | "Hunter Rey", 959 | "Micah", 960 | "Martine" 961 | ], 962 | "tt1600195": [ 963 | "CIA Man", 964 | "Riah", 965 | "Thermal", 966 | "Mara", 967 | "Game Announcer", 968 | "Hot Dog Vendor", 969 | "Driver", 970 | "Mrs. Murphy", 971 | "Kozlow's Tech", 972 | "Gregory" 973 | ], 974 | "tt1602098": [ 975 | "Albert Nobbs", 976 | "Emmy", 977 | "Helen", 978 | "Mrs. Baker", 979 | "Mary", 980 | "Sean Casey", 981 | "Patrick", 982 | "Mrs. Moore", 983 | "Mr. 
Moore", 984 | "Milady" 985 | ], 986 | "tt1655420": [ 987 | "Marilyn Monroe", 988 | "Colin Clark", 989 | "Vivien Leigh", 990 | "Sir Laurence Olivier", 991 | "Sir Kenneth Clark", 992 | "Lady Jane Clark", 993 | "Hugh Perceval", 994 | "Vanessa", 995 | "Jack Cardiff", 996 | "Cotes-Preedy" 997 | ], 998 | "tt1655460": [ 999 | "George Gergenblatt", 1000 | "Linda Gergenblatt", 1001 | "Seth", 1002 | "Carvin Waggie", 1003 | "Eva", 1004 | "Rick Gergenblatt", 1005 | "Wayne Davidson", 1006 | "Karen", 1007 | "Kathy", 1008 | "Almond Cohen" 1009 | ], 1010 | "tt1764234": [ 1011 | "Jackie", 1012 | "Frankie", 1013 | "Russell", 1014 | "Mickey", 1015 | "Driver", 1016 | "Johnny Amato", 1017 | "Markie Trattman", 1018 | "Steve Caprio", 1019 | "Barry Caprio", 1020 | "Dillon" 1021 | ], 1022 | "tt1840309": [ 1023 | "Tris", 1024 | "Four", 1025 | "Natalie", 1026 | "Eric", 1027 | "Marcus", 1028 | "Christina", 1029 | "Peter", 1030 | "Andrew", 1031 | "Caleb", 1032 | "Tori" 1033 | ], 1034 | "tt1999890": [ 1035 | "Jodi", 1036 | "The Other", 1037 | "Natalie", 1038 | "Brooke", 1039 | "Taylor", 1040 | "Quinn", 1041 | "Asher", 1042 | "Gavin", 1043 | "Gate Guard", 1044 | "Britney" 1045 | ], 1046 | "tt2377322": [ 1047 | "Sarchie", 1048 | "Mendoza", 1049 | "Jen", 1050 | "Jimmy", 1051 | "Gordon", 1052 | "Santino", 1053 | "Butler", 1054 | "Nadler", 1055 | "Christina", 1056 | "Jane" 1057 | ], 1058 | "tt2719848": [ 1059 | "Rob Hall", 1060 | "Ang Dorjee", 1061 | "Michael Groom", 1062 | "Andy 'Harold' Harris", 1063 | "Neal Beidleman", 1064 | "Lene Gammelgaard", 1065 | "Lopsang", 1066 | "Charlotte Fox", 1067 | "Tim Madsen", 1068 | "Klev Schoening" 1069 | ], 1070 | "tt3062096": [ 1071 | "Robert Langdon", 1072 | "Sienna Brooks", 1073 | "Christoph Bouchard", 1074 | "Harry Sims", 1075 | "Elizabeth Sinskey", 1076 | "Bertrand Zobrist", 1077 | "Vayentha", 1078 | "Marta Alvarez", 1079 | "Dr. 
Marconi", 1080 | "Florence Hospital Taxi Driver" 1081 | ], 1082 | "tt3470600": [ 1083 | "Buster Moon", 1084 | "Rosita", 1085 | "Mike", 1086 | "Ash", 1087 | "Eddie", 1088 | "Johnny", 1089 | "Meena", 1090 | "Nana", 1091 | "Young Nana", 1092 | "Miss Crawly / Additional Voices" 1093 | ], 1094 | "tt3530002": [ 1095 | "Ethan", 1096 | "Isaac", 1097 | "Chris", 1098 | "Betsy", 1099 | "Diana", 1100 | "Cindy", 1101 | "Mr. Green", 1102 | "Sarah", 1103 | "Rebecca Grinch", 1104 | "Tommy Owens" 1105 | ], 1106 | "tt4196776": [ 1107 | "Jason Bourne", 1108 | "CIA Director Robert Dewey", 1109 | "Heather Lee", 1110 | "Asset", 1111 | "Nicky Parsons", 1112 | "Aaron Kalloor", 1113 | "Craig Jeffers", 1114 | "Director NI Edwin Russell", 1115 | "Malcolm Smith", 1116 | "Christian Dassault" 1117 | ], 1118 | "tt4382872": [ 1119 | "Leonard", 1120 | "Harry Turner", 1121 | "Victoria", 1122 | "Robertson", 1123 | "Drake", 1124 | "Sitterson", 1125 | "Higgins", 1126 | "Agent Stevens", 1127 | "Dmitri", 1128 | "Purvis" 1129 | ], 1130 | "tt4633694": [ 1131 | "Miles Morales", 1132 | "Peter B. Parker", 1133 | "Gwen Stacy", 1134 | "Uncle Aaron", 1135 | "Jefferson Davis", 1136 | "Aunt May", 1137 | "Rio Morales", 1138 | "Mary Jane", 1139 | "Spider-Ham", 1140 | "Peni Parker" 1141 | ], 1142 | "tt5952594": [ 1143 | "Roman", 1144 | "Henry", 1145 | "Myles", 1146 | "Martha", 1147 | "Psychologist", 1148 | "Dan", 1149 | "Tom", 1150 | "Elijah", 1151 | "Roberto", 1152 | "Inmate at Anger Management" 1153 | ], 1154 | "tt6324278": [ 1155 | "Yi", 1156 | "Peng", 1157 | "Jin", 1158 | "Everest", 1159 | "Dr. 
Zara", 1160 | "Burnish", 1161 | "Nai Nai", 1162 | "Yi's Mom", 1163 | "Goon Leader", 1164 | "Yak Herder" 1165 | ], 1166 | "tt7014006": [ 1167 | "Kayla Day", 1168 | "Mark Day", 1169 | "Olivia", 1170 | "Gabe", 1171 | "Riley", 1172 | "Trevor", 1173 | "Aniyah", 1174 | "Aiden", 1175 | "Kennedy", 1176 | "Steph" 1177 | ] 1178 | } -------------------------------------------------------------------------------- /autoad_iii/metrics/critic_metric.py: -------------------------------------------------------------------------------- 1 | """ CRITIC (Co-Referencing In Text for Identifying Characters) metric. 2 | 3 | This script evaluates the character identification between list of predicted sentences and reference sentences. 4 | It has two steps: 5 | Step1: build synonym set for each identity in each sentence (based on "fastcoref"), for both GT and prediction 6 | Step2: compute identity IoU for each sentence, then aggregate 7 | 8 | This metric is based on the co-reference package "F-COREF": https://pypi.org/project/fastcoref/ 9 | """ 10 | 11 | import pandas as pd 12 | from fastcoref import FCoref 13 | import numpy as np 14 | from tqdm import tqdm 15 | import json 16 | import argparse 17 | 18 | coref_model = FCoref(device='cuda:0', enable_progress_bar=False) 19 | 20 | 21 | def build_synonym(coref_data, source_idx, role_names, drop_pronouns=True): 22 | """ Function to extract clusters containing any of the character names. """ 23 | res = [] 24 | coref_text = coref_data.text 25 | total_rows = np.max(source_idx) + 1 26 | synonym_rows = {idx: [] for idx in np.arange(total_rows)} 27 | synonym_rows_cid = {idx: [] for idx in np.arange(total_rows)} 28 | synonym_rows_origin = {idx: [] for idx in np.arange(total_rows)} 29 | 30 | for _, cluster in enumerate(coref_data.get_clusters(as_strings=False)): 31 | cluster_name = None 32 | # some cluster is char name; some is not (e.g. 
a letter, it, the letter) 33 | cluster_str_origin = [coref_text[x[0]:x[1]] for x in cluster] 34 | cluster_str = [coref_text[x[0]:x[1]] for x in cluster] 35 | match_role_set = set(cluster_str).intersection(set(role_names)) 36 | IS_CHAR = len(match_role_set) > 0 37 | if IS_CHAR: 38 | if not len(match_role_set) == 1: 39 | # if a coref result match multiple characters, it is not a good coref; discard it 40 | print(f'Warning: found a bad coref {set(cluster_str)} vs. {role_names}, continue') 41 | continue 42 | cluster_name = list(match_role_set)[0] 43 | # assign the synonym set back to each data row 44 | cluster_source_idx = [source_idx[x[0]:x[1]] for x in cluster] 45 | if len(cluster_name.split()) > 1: # has "first last", or "title last", or "title first last" 46 | cluster_str.extend([cluster_name.split()[0], cluster_name.split()[-1], 47 | cluster_name.split()[0].lower(), cluster_name.split()[-1].lower(), 48 | cluster_name.split()[0].upper(), cluster_name.split()[-1].upper()]) 49 | synonym_set = list(set(cluster_str)) 50 | if drop_pronouns: 51 | synonym_set = [i for i in synonym_set if i.lower() not in ['she', 'he', 'her', 'his', 'they']] 52 | 53 | for item, text in zip(cluster_source_idx, cluster_str_origin): 54 | if not np.mean(item) == np.max(item): 55 | continue 56 | if item[0] != -1: 57 | if cluster_name not in synonym_rows_cid[item[0]]: # dedup 58 | synonym_rows[item[0]].append(synonym_set) 59 | synonym_rows_cid[item[0]].append(cluster_name) 60 | synonym_rows_origin[item[0]].append(text) 61 | 62 | res = {k:list(v) for k,v in synonym_rows.items()} 63 | return res, synonym_rows_origin, synonym_rows_cid 64 | 65 | 66 | def get_iou(list1, list2): 67 | """Get IoU of two lists of strings""" 68 | intersection = set(list1).intersection(set(list2)) 69 | union = set(list1).union(set(list2)) 70 | if len(union) == 0: 71 | return 0 72 | else: 73 | return len(intersection)/len(union) 74 | 75 | 76 | def coref_metric(df, character_list): 77 | roles_str = "" 78 | if 
def _flatten_with_source_idx(sentences):
    """Join sentences with single spaces and map every character to its sentence index.

    e.g. ["Jack smiles", "He stands up"]
         --> text "Jack smiles He stands up"
         --> idx  [0]*11 + [-1] + [1]*12
    The -1 entries mark the joining spaces, which belong to no sentence.
    """
    text = ' '.join(sentences)
    source_idx = []
    for row, sentence in enumerate(sentences):
        if row != 0:
            source_idx.append(-1)
        source_idx.extend([row] * len(sentence))
    assert len(source_idx) == len(text)
    return text, source_idx


def _rewrite_with_fullnames(sentences, origin_by_row, cid_by_row):
    """Replace each resolved mention in every sentence with its matched character name."""
    assert len(sentences) == len(origin_by_row)
    rewritten = []
    for row, sentence in enumerate(sentences):
        for mention, cid in zip(origin_by_row[row], cid_by_row[row]):
            sentence = sentence.replace(mention, cid)
        rewritten.append(sentence)
    return rewritten


def coref_metric(df, character_list):
    """Compute a per-sentence character IoU between ground-truth and generated text.

    Args:
        df: dataframe with string columns 'text_gt' and 'text_gen'. Two columns,
            'fullname_gt' and 'fullname_pred', are written onto it.
        character_list: list of character names; it is turned into a
            "A, B and C." sentence that is prepended to the text handed to the
            coreference model.

    Returns:
        (iou_list, df): one IoU per row whose ground truth has at least one
        character cluster, and the dataframe with the two added columns.

    Relies on the module-level `coref_model`, `build_synonym`, `get_iou` and
    `tqdm`.
    """
    roles_str = ""
    if len(character_list) > 1:
        roles_str = ', '.join(character_list[:-1]) + ' and '
    if len(character_list) > 0:
        roles_str += character_list[-1] + '.'

    ### prepare GT and pred
    # Keep the row index of every character of the joined text, because the
    # coref output is expressed as string positions and we need to know which
    # sentence each mention comes from.
    gt_sentence_list = df['text_gt'].tolist()
    pred_sentence_list = df['text_gen'].tolist()
    gt_text, gt_source_idx = _flatten_with_source_idx(gt_sentence_list)
    pred_text, pred_source_idx = _flatten_with_source_idx(pred_sentence_list)

    # The cast sentence plus one separating space is prepended to the text, so
    # the index lists get the same number of -1 placeholders in front.
    cast_padding = [-1] * (len(roles_str) + 1)
    gt_source_idx = cast_padding + gt_source_idx
    pred_source_idx = cast_padding + pred_source_idx

    ### Compute coref, get identity clusters
    coref_gts = coref_model.predict(
        texts=[f"{roles_str} {gt_text}"]
    )[0]
    assert len(gt_source_idx) == len(coref_gts.text)

    coref_preds = coref_model.predict(
        texts=[f"{roles_str} {pred_text}"]
    )[0]
    assert len(pred_source_idx) == len(coref_preds.text)

    ### Compute synonym set for each sentence
    synonym_rows_gt, synonym_origin_gt, synonym_cid_gt = build_synonym(
        coref_gts, gt_source_idx, character_list)
    assert len(df) == len(synonym_rows_gt)
    synonym_rows_pred, synonym_origin_pred, synonym_cid_pred = build_synonym(
        coref_preds, pred_source_idx, character_list)
    assert len(df) == len(synonym_rows_pred)

    # Rewrite text with full names to reduce ambiguity
    df['fullname_gt'] = _rewrite_with_fullnames(
        gt_sentence_list, synonym_origin_gt, synonym_cid_gt)
    df['fullname_pred'] = _rewrite_with_fullnames(
        pred_sentence_list, synonym_origin_pred, synonym_cid_pred)

    ### Aggregate results for each sentence
    iou_list = []
    for row_idx in tqdm(range(len(df))):
        # rows without any ground-truth character cluster are not scored
        if len(synonym_rows_gt[row_idx]) == 0:
            continue
        iou_list.append(get_iou(synonym_cid_gt[row_idx], synonym_cid_pred[row_idx]))

    return iou_list, df
# Execute function with a timeout
# OpenAI API call sometimes freezes
def my_function(result_queue, fn, *args, **kwargs):
    """Child-process entry point: run ``fn`` and send its return value back through ``result_queue``."""
    result = fn(*args, **kwargs)
    result_queue.put(result)


def run_with_timeout(fn, *args, timeout=3):
    """Run ``fn(*args)`` in a child process, restarting it whenever it hangs.

    Args:
        fn: callable to execute; it and its arguments are handed to a
            ``multiprocessing.Process``.
        *args: positional arguments forwarded to ``fn``.
        timeout: seconds to wait for one attempt before the child is killed
            and the call is retried (keyword-only, defaults to 3).

    Returns:
        The value returned by ``fn``.

    Raises:
        RuntimeError: if the child exits with a non-zero exit code (for
            example because ``fn`` raised). In that case nothing is put on the
            queue, and waiting for a result would block forever.
    """
    while True:
        # A queue that was in use by a terminated process may be corrupted,
        # so every attempt gets a fresh one.
        result_queue = Queue()
        p = Process(target=my_function, args=(result_queue, fn, *args))
        p.start()
        p.join(timeout=timeout)  # Set the timeout value (in seconds)

        if p.is_alive():
            p.terminate()
            p.join()
            print("Function timed out, redo ...")
            continue

        if p.exitcode != 0:
            raise RuntimeError(
                f"worker process exited with code {p.exitcode} before returning a result")

        return result_queue.get()
def eval_each(text_gt, text_pred, client):
    """Ask the chat model to rate one (ground truth, prediction) pair.

    The reply text is parsed with ``ast.literal_eval`` and returned wrapped in
    a one-element list; parsing errors propagate to the caller.
    """
    system_prompt = (
        "You are an intelligent chatbot designed for evaluating the quality of generative outputs for movie audio descriptions. "
        "Your task is to compare the predicted audio descriptions with the correct audio descriptions and determine its level of match, considering mainly the visual elements like actions, objects and interactions. Here's how you can accomplish the task:"
        "------"
        "##INSTRUCTIONS: "
        "- Check if the predicted audio description covers the main visual events from the movie, especially focusing on the verbs and nouns.\n"
        "- Evaluate whether the predicted audio description includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n"
        "- Consider synonyms or paraphrases as valid matches. Consider pronouns like 'he' or 'she' as valid matches with character names. Consider different character names as valid matches. \n"
        "- Provide a single evaluation score that reflects the level of match of the prediction, considering the visual elements like actions, objects and interactions."
    )
    user_prompt = (
        "Please evaluate the following movie audio description pair:\n\n"
        f"Correct Audio Description: {text_gt}\n"
        f"Predicted Audio Description: {text_pred}\n\n"
        "Provide your evaluation only as a matching score where the matching score is an integer value between 0 and 5, with 5 indicating the highest level of match. "
        "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the matching score in INTEGER, not STRING."
        "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
        "For example, your response should look like this: {'score': }."
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    reply = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    # Convert the reply text to a Python object (expected: a dict with key 'score').
    parsed = ast.literal_eval(reply.choices[0].message.content)
    return [parsed]
def main(args):
    """Score every (ground truth, prediction) pair of a CSV and print the mean score.

    Reads ``args.path`` (columns 'text_gt' and 'text_gen'), evaluates each pair
    through ``run_with_timeout(eval_each, ...)``, writes a JSON checkpoint of
    all results gathered so far before every chunk after the first, and
    finally writes ``tmp/log_final.json``.
    """
    client = OpenAI()
    eval_fn = partial(eval_each, client=client)

    pred_df = pd.read_csv(args.path)
    gt_pred_pair = list(zip(pred_df['text_gt'], pred_df['text_gen']))

    # Checkpoints and the final log live in tmp/; create it up front so a
    # missing directory cannot discard results that were already paid for.
    os.makedirs('tmp', exist_ok=True)

    def dump_log(path):
        # file.write() only accepts str, so the list of results has to be
        # serialised. default=str keeps the dump from failing when a reply
        # parsed into something JSON cannot encode (e.g. a set) - this file is
        # a diagnostic log, not an interchange format.
        with open(path, 'w') as fobj:
            json.dump(all_output, fobj, default=str)

    all_output = []

    # save output regularly, in case api call breaks
    chunk_size = 200
    num_chunk = (len(gt_pred_pair) // chunk_size) + 1

    for chunk_idx in tqdm(range(num_chunk)):
        if chunk_idx != 0:
            dump_log(f'tmp/log_{chunk_idx:05d}.json')

        gt_pred_pair_current = gt_pred_pair[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size]
        for gt, pred in tqdm(gt_pred_pair_current, total=len(gt_pred_pair_current)):
            all_output.append(run_with_timeout(eval_fn, gt, pred))

    all_score = []
    for item in all_output:
        try:
            all_score.append(item[0]['score'])
        except (KeyError, IndexError, TypeError):
            # the reply was parseable but is not a dict carrying 'score'
            print(item, 'does not follow the format, skip.')

    print(np.mean(all_score))
    dump_log('tmp/log_final.json')
class Args:
    """Minimal stand-in for parsed command-line arguments.

    Carries the path of the customized YAML config used to load the Llama2
    model, plus an empty ``options`` override.
    """

    def __init__(self):
        self.options = None
        self.cfg_path = '../config_video_llama/llama2_eval.yaml'
@torch.no_grad()
def eval_each(model, text_gt, text_pred):
    """Generate the LLaMA judgement text for one (ground truth, prediction) pair.

    The instruction template is filled with the pair, tokenized, embedded and
    passed to ``model.llama_model.generate``; the decoded generations are
    returned as given by ``batch_decode``.
    """
    # NOTE(review): the bare '<>' markers in the template look like truncated
    # system-prompt delimiters - confirm against the upstream file.
    template = "[INST] <>\nYou are an intelligent chatbot designed for evaluating the quality of generative outputs for movie audio descriptions. Your task is to compare the predicted audio descriptions with the correct audio descriptions and determine its level of match, considering mainly the visual elements like actions, objects and interactions. Here's how you can accomplish the task:------##INSTRUCTIONS: - Check if the predicted audio description covers the main visual events from the movie, especially focusing on the verbs and nouns.\n- Evaluate whether the predicted audio description includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n- Consider synonyms or paraphrases as valid matches. Consider pronouns like 'he' or 'she' as valid matches with character names. Consider different character names as valid matches. \n- Provide a single evaluation score that reflects the level of match of the prediction, considering the visual elements like actions, objects and interactions. \n<>\n\n{} [/INST] "

    request = (
        "Please evaluate the following movie audio description pair:\n\n"
        f"Correct Audio Description: {text_gt}\n"
        f"Predicted Audio Description: {text_pred}\n\n"
        "Provide your evaluation only as a matching score where the matching score is an integer value between 0 and 5, with 5 indicating the highest level of match. "
        "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the matching score in INTEGER, not STRING."
        "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
        "For example, your response should look like this: {'score': }."
    )

    tokenizer = model.llama_tokenizer
    encoded = tokenizer(template.format(request), return_tensors="pt", add_special_tokens=False)
    token_ids = encoded.to('cuda').input_ids.cuda()
    embedded = model.llama_model.model.embed_tokens(token_ids)

    generation_settings = dict(
        inputs_embeds=embedded,
        max_new_tokens=8,
        stopping_criteria=None,
        num_beams=1,
        do_sample=True,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
    )
    generated = model.llama_model.generate(**generation_settings)

    return tokenizer.batch_decode(generated, add_special_tokens=False)
def main(args):
    """Score every row of the CSV at ``args.path`` with the LLaMA judge and print the mean.

    Each row's 'text_gt' / 'text_gen' pair is passed to ``eval_each``; the
    reply is parsed as a Python dict literal and its 'score' entry collected.
    Replies that cannot be parsed, or that carry no usable 'score', count as 0.
    """
    # read llama2-7b model by customized config
    cfg = Config(Args())
    cfg.pretty_print()
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg).cuda()
    model.eval()

    pred_df = pd.read_csv(args.path)

    # The original chained `.replace('', '')` calls were no-ops as written.
    # Presumably they were meant to drop the BOS/EOS markers left in the
    # decoded text - TODO confirm; the tokenizer's own markers are used here.
    tokenizer = model.llama_tokenizer
    special_tokens = [
        token
        for token in (getattr(tokenizer, 'bos_token', None), getattr(tokenizer, 'eos_token', None))
        if token
    ]

    all_output = []
    all_score = []

    rows = tqdm(pred_df.iterrows(), total=len(pred_df))
    for counter, (_, row) in enumerate(rows, start=1):
        if counter % 200 == 0:  # monitor progress
            print(np.mean(all_score), f'at {counter} sampeles')
        eval_out = eval_each(model=model, text_gt=row['text_gt'], text_pred=row['text_gen'])
        assert len(eval_out) == 1
        eval_out = eval_out[0]

        cleaned = eval_out
        for token in special_tokens:
            cleaned = cleaned.replace(token, '')

        try:
            # the lookup is inside the try so a reply that parses but is not
            # a dict with 'score' is handled like any other malformed reply
            score = ast.literal_eval(cleaned.strip())['score']
        except (ValueError, SyntaxError, TypeError, KeyError):
            print('get error from:', eval_out)
            score = 0
        all_output.append(eval_out)
        all_score.append(score)

    print(f"final average score: {np.mean(all_score)}")
Require 'text_gt' and 'text_gen' columns.") 93 | args = parser.parse_args() 94 | main(args) -------------------------------------------------------------------------------- /autoad_iii/readme.md: -------------------------------------------------------------------------------- 1 | ## AutoAD-III 2 | The codebase for *AutoAD III: The Prequel - Back to the Pixels*. 3 | [[project page]](https://www.robots.ox.ac.uk/~vgg/research/autoad/) 4 | [[AutoAD-III PDF]](https://www.robots.ox.ac.uk/~vgg/publications/2024/Han24/han24.pdf) 5 | 6 | 7 | 8 | ### Scripts & Datasets 9 | We are working on sharing training/testing scripts and datasets. They will be available shortly here. 10 | 11 | ### Models 12 | * Movie-Llama2: 13 | [[model weight]](http://www.robots.ox.ac.uk/~htd/ad3/ad3_moviellama2_ce_iter14000.pth.tar) 14 | [[AD output on CMD-AD]](http://www.robots.ox.ac.uk/~htd/ad3/inference_ad3_moviellama2_ce_iter14000.csv) 15 | 16 | ### Reference 17 | ```bibtex 18 | @InProceedings{han2024autoad3, 19 | title={{AutoAD III: The Prequel} - Back to the Pixels}, 20 | author={Tengda Han and Max Bain and Arsha Nagrani and G\"ul Varol and Weidi Xie and Andrew Zisserman}, 21 | booktitle={CVPR}, 22 | year={2024}} 23 | ``` -------------------------------------------------------------------------------- /datasets/mad2imdb.json: -------------------------------------------------------------------------------- 1 | { 2 | "0001_American_Beauty": "tt0169547", 3 | "0002_As_Good_As_It_Gets": "tt0119822", 4 | "0003_CASABLANCA": "tt0034583", 5 | "0004_Charade": "tt0056923", 6 | "0005_Chinatown": "tt0071315", 7 | "0006_Clerks": "tt0109445", 8 | "0007_DIE_NACHT_DES_JAEGERS": "tt0048424", 9 | "0008_Fargo": "tt0116282", 10 | "0009_Forrest_Gump": "tt0109830", 11 | "0010_Frau_Ohne_Gewissen": "tt0036775", 12 | "0011_Gandhi": "tt0083987", 13 | "0012_Get_Shorty": "tt0113161", 14 | "0013_Halloween": "tt0077651", 15 | "0014_Ist_das_Leben_nicht_schoen": "tt0038650", 16 | "0016_O_Brother_Where_Art_Thou": "tt0190590", 
17 | "0017_Pianist": "tt0253474", 18 | "0019_Pulp_Fiction": "tt0110912", 19 | "0020_Raising_Arizona": "tt0093822", 20 | "0021_Rear_Window": "tt0047396", 21 | "0022_Reservoir_Dogs": "tt0105236", 22 | "0023_THE_BUTTERFLY_EFFECT": "tt0289879", 23 | "0026_The_Big_Fish": "tt0319061", 24 | "0027_The_Big_Lebowski": "tt0118715", 25 | "0028_The_Crying_Game": "tt0104036", 26 | "0029_The_Graduate": "tt0061722", 27 | "0030_The_Hustler": "tt0054997", 28 | "0031_The_Lost_Weekend": "tt0037884", 29 | "0032_The_Princess_Bride": "tt0093779", 30 | "0033_Amadeus": "tt0086879", 31 | "0038_Psycho": "tt0054215", 32 | "0041_The_Sixth_Sense": "tt0167404", 33 | "0043_Thelma_and_Luise": "tt0103074", 34 | "0046_Chasing_Amy": "tt0118842", 35 | "0049_Hannah_and_her_sisters": "tt0091167", 36 | "0050_Indiana_Jones_and_the_last_crusade": "tt0097576", 37 | "0051_Men_in_black": "tt0119654", 38 | "0053_Rendezvous_mit_Joe_Black": "tt0119643", 39 | "1001_Flight": "tt1907668", 40 | "1002_Harry_Potter_and_the_Half-Blood_Prince": "tt0417741", 41 | "1003_How_to_Lose_Friends_and_Alienate_People": "tt0455538", 42 | "1004_Juno": "tt0467406", 43 | "1005_Signs": "tt0286106", 44 | "1006_Slumdog_Millionaire": "tt1010048", 45 | "1007_Spider-Man1": "tt0145487", 46 | "1008_Spider-Man2": "tt0316654", 47 | "1009_Spider-Man3": "tt0413300", 48 | "1010_TITANIC": "tt0120338", 49 | "1011_The_Help": "tt1454029", 50 | "1012_Unbreakable": "tt0217869", 51 | "10142": "tt6977338", 52 | "10149": "tt0264464", 53 | "1014_2012": "tt1190080", 54 | "1015_27_Dresses": "tt0988595", 55 | "1017_Bad_Santa": "tt0307987", 56 | "1018_Body_Of_Lies": "tt0758774", 57 | "1019_Confessions_Of_A_Shopaholic": "tt1093908", 58 | "10202": "tt0096446", 59 | "1020_Crazy_Stupid_Love": "tt1570728", 60 | "1026_Legion": "tt1038686", 61 | "1027_Les_Miserables": "tt1707386", 62 | "1028_No_Reservations": "tt0481141", 63 | "1031_Quantum_of_Solace": "tt0830515", 64 | "10322": "tt0311289", 65 | "1033_Sherlock_Holmes_A_Game_of_Shadows": "tt1515091", 66 | 
"1034_Super_8": "tt1650062", 67 | "1035_The_Adjustment_Bureau": "tt1385826", 68 | "1037_The_Curious_Case_Of_Benjamin_Button": "tt0421715", 69 | "1038_The_Great_Gatsby": "tt1343092", 70 | "1039_The_Queen": "tt0436697", 71 | "1040_The_Ugly_Truth": "tt1142988", 72 | "1042_Up_In_The_Air": "tt1193138", 73 | "1043_Vantage_Point": "tt0443274", 74 | "1045_An_education": "tt1174732", 75 | "1046_Australia": "tt0455824", 76 | "1047_Defiance": "tt1034303", 77 | "1048_Gran_Torino": "tt1205489", 78 | "1050_Harry_Potter_and_the_deathly_hallows_Disk_One": "tt0926084", 79 | "1051_Harry_Potter_and_the_goblet_of_fire": "tt0330373", 80 | "10527": "tt0463998", 81 | "1052_Harry_Potter_and_the_order_of_phoenix": "tt0373889", 82 | "10536": "tt0475293", 83 | "1054_Harry_Potter_and_the_prisoner_of_azkaban": "tt0304141", 84 | "1055_Marley_and_me": "tt0822832", 85 | "1057_Seven_pounds": "tt0814314", 86 | "1058_The_Damned_united": "tt1226271", 87 | "1059_The_devil_wears_prada": "tt0458352", 88 | "1060_Yes_man": "tt1068680", 89 | "1061_Harry_Potter_and_the_deathly_hallows_Disk_Two": "tt1201607", 90 | "1062_Day_the_Earth_stood_still": "tt0970416", 91 | "10784": "tt0339291", 92 | "10813": "tt0048280", 93 | "10836": "tt0866439", 94 | "10861": "tt0369610", 95 | "10894": "tt0085549", 96 | "10965": "tt0116136", 97 | "11003": "tt0286716", 98 | "11010": "tt0114898", 99 | "11099": "tt0892769", 100 | "11129": "tt0325980", 101 | "11139": "tt0120915", 102 | "11140": "tt0121765", 103 | "11143": "tt0121766", 104 | "11147": "tt0080684", 105 | "11148": "tt0086190", 106 | "11154": "tt0418279", 107 | "11318": "tt1323594", 108 | "11321": "tt0077631", 109 | "11345": "tt0183790", 110 | "11396": "tt0120689", 111 | "11430": "tt0103772", 112 | "11438": "tt1059786", 113 | "11530": "tt1305591", 114 | "11620": "tt2980516", 115 | "11727": "tt0433035", 116 | "11796": "tt0107076", 117 | "11962": "tt0290334", 118 | "12010": "tt0369339", 119 | "12079": "tt1340800", 120 | "12090": "tt0147800", 121 | "12125": "tt0425112", 122 | 
"12131": "tt4972582", 123 | "12132": "tt0071562", 124 | "12144": "tt0338013", 125 | "12147": "tt0477302", 126 | "12148": "tt0450259", 127 | "12186": "tt0163025", 128 | "12211": "tt0075686", 129 | "12220": "tt0365748", 130 | "12222": "tt0450385", 131 | "12263": "tt0125439", 132 | "12273": "tt1458175", 133 | "12294": "tt0114369", 134 | "12324": "tt1099212", 135 | "12358": "tt0431308", 136 | "12504": "tt0338094", 137 | "12563": "tt0319262", 138 | "12585": "tt0217869", 139 | "12618": "tt0454921", 140 | "12653": "tt0099253", 141 | "12658": "tt0103956", 142 | "12743": "tt0097637", 143 | "12852": "tt0388795", 144 | "12869": "tt0762107", 145 | "12900": "tt1823664", 146 | "12906": "tt1343727", 147 | "12911": "tt1013753", 148 | "12923": "tt0253474", 149 | "12958": "tt0058385", 150 | "13018": "tt0094226", 151 | "13027": "tt1895587", 152 | "13031": "tt0084434", 153 | "13045": "tt0212720", 154 | "13140": "tt0457510", 155 | "13146": "tt0118715", 156 | "13159": "tt0207201", 157 | "13165": "tt0101393", 158 | "13187": "tt0112453", 159 | "13191": "tt2170299", 160 | "13201": "tt0086250", 161 | "2723": "tt0115433", 162 | "2730": "tt1413492", 163 | "2731": "tt2024544", 164 | "2735": "tt4172430", 165 | "2738": "tt0450232", 166 | "2745": "tt1190080", 167 | "2750": "tt2304459", 168 | "2758": "tt0103596", 169 | "2768": "tt0453562", 170 | "2778": "tt3799372", 171 | "2787": "tt1136608", 172 | "2800": "tt1067106", 173 | "2801": "tt0087056", 174 | "2814": "tt4731136", 175 | "2818": "tt1753383", 176 | "2854": "tt3416532", 177 | "2869": "tt6054650", 178 | "2870": "tt6644200", 179 | "2873": "tt1837562", 180 | "2911": "tt1826590", 181 | "2913": "tt2194499", 182 | "2928": "tt6495770", 183 | "2934": "tt6306064", 184 | "2944": "tt1815862", 185 | "2948": "tt0027260", 186 | "2970": "tt0470982", 187 | "2986": "tt1698641", 188 | "2992": "tt1014759", 189 | "2996": "tt2316204", 190 | "3001": "tt0042192", 191 | "3001_21_JUMP_STREET": "tt1232829", 192 | "3002_30_MINUTES_OR_LESS": "tt1622547", 193 | 
"3003_40_YEAR_OLD_VIRGIN": "tt0405422", 194 | "3004_500_DAYS_OF_SUMMER": "tt1022603", 195 | "3005_ABRAHAM_LINCOLN_VAMPIRE_HUNTER": "tt1611224", 196 | "3007_A_THOUSAND_WORDS": "tt0763831", 197 | "3008_BAD_TEACHER": "tt1284575", 198 | "3009_BATTLE_LOS_ANGELES": "tt1217613", 199 | "3012_BRUNO": "tt0889583", 200 | "3013_BURLESQUE": "tt1126591", 201 | "3014": "tt5294550", 202 | "3014_CAPTAIN_AMERICA": "tt0458339", 203 | "3015_CHARLIE_ST_CLOUD": "tt1438254", 204 | "3016_CHASING_MAVERICKS": "tt1629757", 205 | "3017_CHRONICLE": "tt1706593", 206 | "3018_CINDERELLA_MAN": "tt0352248", 207 | "3020": "tt3640424", 208 | "3020_DEAR_JOHN": "tt0989757", 209 | "3021": "tt4649416", 210 | "3021_DEATH_AT_A_FUNERAL": "tt0795368", 211 | "3022_DINNER_FOR_SCHMUCKS": "tt0427152", 212 | "3023": "tt1243974", 213 | "3023_DISTRICT_9": "tt1136608", 214 | "3024_EASY_A": "tt1282140", 215 | "3025_FLIGHT": "tt1907668", 216 | "3026_FRIENDS_WITH_BENEFITS": "tt1632708", 217 | "3028_GHOST_RIDER_SPIRIT_OF_VENGEANCE": "tt1071875", 218 | "3030_GROWN_UPS": "tt1375670", 219 | "3031_HANSEL_GRETEL_WITCH_HUNTERS": "tt1428538", 220 | "3032_HOW_DO_YOU_KNOW": "tt1341188", 221 | "3033": "tt1231580", 222 | "3033_HUGO": "tt0970179", 223 | "3034_IDES_OF_MARCH": "tt1124035", 224 | "3035_INSIDE_MAN": "tt0454848", 225 | "3036_IN_TIME": "tt1637688", 226 | "3037_IRON_MAN2": "tt1228705", 227 | "3038_ITS_COMPLICATED": "tt1230414", 228 | "3039_JACK_AND_JILL": "tt0810913", 229 | "3040": "tt1961175", 230 | "3040_JULIE_AND_JULIA": "tt1135503", 231 | "3041_JUST_GO_WITH_IT": "tt1564367", 232 | "3042_KARATE_KID": "tt1155076", 233 | "3043_KATY_PERRY_PART_OF_ME": "tt2215719", 234 | "3045_LAND_OF_THE_LOST": "tt0071005", 235 | "3046_LARRY_CROWNE": "tt1583420", 236 | "3047_LIFE_OF_PI": "tt0454876", 237 | "3048_LITTLE_FOCKERS": "tt0970866", 238 | "3049": "tt1800241", 239 | "3049_MORNING_GLORY": "tt1126618", 240 | "3050": "tt3532216", 241 | "3050_MR_POPPERS_PENGUINS": "tt1396218", 242 | "3051_NANNY_MCPHEE_RETURNS": "tt1415283", 243 | 
"3052_NO_STRINGS_ATTACHED": "tt1411238", 244 | "3053_PARENTAL_GUIDANCE": "tt1047540", 245 | "3054_PERCY_JACKSON_LIGHTENING_THIEF": "tt0814255", 246 | "3055_PROMETHEUS": "tt1446714", 247 | "3056_PUBLIC_ENEMIES": "tt1152836", 248 | "3058_RUBY_SPARKS": "tt1839492", 249 | "3059": "tt2179136", 250 | "3060": "tt3316948", 251 | "3060_SANCTUM": "tt0881320", 252 | "3061_SNOW_FLOWER": "tt1541995", 253 | "3062_SORCERERS_APPRENTICE": "tt0963966", 254 | "3063_SOUL_SURFER": "tt1596346", 255 | "3066": "tt0043278", 256 | "3066_THE_ADVENTURES_OF_TINTIN": "tt0983193", 257 | "3067_THE_ART_OF_GETTING_BY": "tt1645080", 258 | "3069_THE_BOUNTY_HUNTER": "tt1038919", 259 | "3070": "tt6322922", 260 | "3070_THE_CALL": "tt1911644", 261 | "3071_THE_DESCENDANTS": "tt1033575", 262 | "3072_THE_GIRL_WITH_THE_DRAGON_TATTOO": "tt1568346", 263 | "3073_THE_GUILT_TRIP": "tt1694020", 264 | "3074_THE_ROOMMATE": "tt1265990", 265 | "3075_THE_SITTER": "tt1366344", 266 | "3076_THE_SOCIAL_NETWORK": "tt1285016", 267 | "3077_THE_VOW": "tt1606389", 268 | "3078_THE_WATCH": "tt1298649", 269 | "3079_THINK_LIKE_A_MAN": "tt1621045", 270 | "3081_THOR": "tt0800369", 271 | "3082_TITANIC1": "tt0120338", 272 | "3083_TITANIC2": "tt0120338", 273 | "3084_TOOTH_FAIRY": "tt0808510", 274 | "3085_TRUE_GRIT": "tt1403865", 275 | "3086_UGLY_TRUTH": "tt1142988", 276 | "3087_WE_BOUGHT_A_ZOO": "tt1389137", 277 | "3088_WHATS_YOUR_NUMBER": "tt0770703", 278 | "3089_XMEN_FIRST_CLASS": "tt1270798", 279 | "3090_YOUNG_ADULT": "tt1625346", 280 | "3091_ZOMBIELAND": "tt1156398", 281 | "3092_ZOOKEEPER": "tt1222817", 282 | "3103": "tt1781769", 283 | "3106": "tt5140878", 284 | "3113": "tt2798920", 285 | "3114": "tt2401878", 286 | "3117": "tt1549572", 287 | "3129": "tt5206170", 288 | "3138": "tt1519461", 289 | "3146": "tt2543164", 290 | "3153": "tt1430607", 291 | "3160": "tt2094766", 292 | "3170": "tt0230011", 293 | "3171": "tt2406566", 294 | "3209": "tt3890160", 295 | "3239": "tt1171222", 296 | "3253": "tt3628584", 297 | "3276": "tt4622512", 298 | 
"3277": "tt1532958", 299 | "3295": "tt2771200", 300 | "3314": "tt3174376", 301 | "3339": "tt1482393", 302 | "3340": "tt2224075", 303 | "3354": "tt2245084", 304 | "3376": "tt2562232", 305 | "3393": "tt1825683", 306 | "3401": "tt7349662", 307 | "3408": "tt1856101", 308 | "3414": "tt1620935", 309 | "3417": "tt0454084", 310 | "3447": "tt3722062", 311 | "3464": "tt4629266", 312 | "3480": "tt5884230", 313 | "3482": "tt3704700", 314 | "3500": "tt5716464", 315 | "3509": "tt3682448", 316 | "3510": "tt0398808", 317 | "3513": "tt1473832", 318 | "3521": "tt5805752", 319 | "3548": "tt6339042", 320 | "3575": "tt3707106", 321 | "3590": "tt5726616", 322 | "3599": "tt0092718", 323 | "3611": "tt2091256", 324 | "3625": "tt2989524", 325 | "3720": "tt1823672", 326 | "3743": "tt0371606", 327 | "3759": "tt0493405", 328 | "3773": "tt4575576", 329 | "3820": "tt2380307", 330 | "3834": "tt4682786", 331 | "3837": "tt2126235", 332 | "3858": "tt3322364", 333 | "3905": "tt2428170", 334 | "3911": "tt2554274", 335 | "3922": "tt3268340", 336 | "3977": "tt4555426", 337 | "4001": "tt3053228", 338 | "4007": "tt2101341", 339 | "4010": "tt1431045", 340 | "4017": "tt5463162", 341 | "4031": "tt1137450", 342 | "4043": "tt1860357", 343 | "4053": "tt1172049", 344 | "4061": "tt1837636", 345 | "4071": "tt5390504", 346 | "4080": "tt3172532", 347 | "4082": "tt6003368", 348 | "4143": "tt4160708", 349 | "4156": "tt0036775", 350 | "4200": "tt2096672", 351 | "4204": "tt5013056", 352 | "4210": "tt4701724", 353 | "4253": "tt1535108", 354 | "4266": "tt0375735", 355 | "4299": "tt0413099", 356 | "4303": "tt7026672", 357 | "4305": "tt2937696", 358 | "4368": "tt1966359", 359 | "4377": "tt2671706", 360 | "4378": "tt3411444", 361 | "4390": "tt5711148", 362 | "4423": "tt3401882", 363 | "4434": "tt2039338", 364 | "4451": "tt0119137", 365 | "4455": "tt2381941", 366 | "4460": "tt2112152", 367 | "4480": "tt1100089", 368 | "4489": "tt1142977", 369 | "4528": "tt2294629", 370 | "4535": "tt3977462", 371 | "4551": "tt2704998", 372 | 
"4576": "tt0322389", 373 | "4578": "tt1981128", 374 | "4587": "tt5052448", 375 | "4596": "tt1219827", 376 | "4597": "tt6111628", 377 | "4608": "tt4481414", 378 | "4611": "tt0210070", 379 | "4618": "tt3564472", 380 | "4634": "tt4824308", 381 | "4635": "tt6652708", 382 | "4638": "tt2404233", 383 | "4644": "tt2568862", 384 | "4664": "tt1653665", 385 | "4670": "tt2417712", 386 | "4671": "tt1051904", 387 | "4684": "tt4270516", 388 | "4702": "tt0947810", 389 | "4709": "tt2724532", 390 | "4719": "tt3896198", 391 | "4728": "tt0031398", 392 | "4740": "tt2119532", 393 | "4741": "tt0475290", 394 | "4753": "tt4856322", 395 | "4772": "tt0424136", 396 | "4778": "tt3072482", 397 | "4797": "tt0395571", 398 | "4798": "tt0124718", 399 | "4813": "tt2582782", 400 | "4815": "tt0411477", 401 | "4839": "tt4846340", 402 | "4880": "tt3671542", 403 | "4884": "tt4419364", 404 | "4888": "tt1981637", 405 | "4901": "tt0097523", 406 | "4902": "tt0119310", 407 | "4914": "tt1535438", 408 | "4925": "tt5478478", 409 | "4929": "tt2967224", 410 | "4933": "tt2510894", 411 | "4936": "tt0837562", 412 | "4950": "tt1292566", 413 | "4962": "tt6573444", 414 | "4970": "tt5022702", 415 | "4977": "tt0097550", 416 | "4982": "tt6791096", 417 | "4992": "tt1490785", 418 | "5014": "tt3416828", 419 | "5041": "tt1390411", 420 | "5055": "tt1628841", 421 | "5063": "tt1942884", 422 | "5074": "tt1969062", 423 | "5093": "tt5726086", 424 | "5101": "tt0816692", 425 | "5118": "tt3221698", 426 | "5139": "tt3393786", 427 | "5144": "tt1619029", 428 | "5217": "tt3348730", 429 | "5236": "tt4425200", 430 | "5237": "tt0040495", 431 | "5257": "tt2283362", 432 | "5259": "tt1640484", 433 | "5265": "tt1617661", 434 | "5270": "tt4881806", 435 | "5283": "tt0974015", 436 | "5293": "tt4139124", 437 | "5308": "tt1650554", 438 | "5335": "tt1972591", 439 | "5366": "tt3731562", 440 | "5367": "tt3850590", 441 | "5369": "tt4302938", 442 | "5417": "tt1216492", 443 | "5420": "tt3892172", 444 | "5432": "tt0486761", 445 | "5449": "tt5442430", 446 | 
"5461": "tt5619332", 447 | "5469": "tt4786282", 448 | "5473": "tt0275847", 449 | "5477": "tt0443272", 450 | "5494": "tt2361317", 451 | "5506": "tt3300542", 452 | "5510": "tt0490166", 453 | "5511": "tt1091191", 454 | "5522": "tt1276104", 455 | "5563": "tt5164432", 456 | "5565": "tt4669986", 457 | "5568": "tt1366338", 458 | "5574": "tt2872732", 459 | "5575": "tt0049456", 460 | "5577": "tt1870425", 461 | "5583": "tt1392190", 462 | "5594": "tt3471098", 463 | "5605": "tt2023587", 464 | "5607": "tt6911608", 465 | "5634": "tt2929690", 466 | "5641": "tt5175450", 467 | "5649": "tt5301662", 468 | "5677": "tt4046784", 469 | "5678": "tt4500922", 470 | "5682": "tt2097298", 471 | "5685": "tt2582496", 472 | "5700": "tt0290002", 473 | "5735": "tt0329374", 474 | "5737": "tt0238414", 475 | "5743": "tt4981636", 476 | "5749": "tt2649554", 477 | "5752": "tt2823054", 478 | "5758": "tt1647668", 479 | "5762": "tt2293640", 480 | "5792": "tt1935859", 481 | "5807": "tt2872462", 482 | "5814": "tt4209788", 483 | "5818": "tt2241351", 484 | "5819": "tt1210166", 485 | "5828": "tt3095734", 486 | "5852": "tt3045616", 487 | "5865": "tt0339412", 488 | "5872": "tt0119718", 489 | "5873": "tt0123179", 490 | "5898": "tt0279967", 491 | "5900": "tt0408306", 492 | "5913": "tt3402236", 493 | "5923": "tt5518906", 494 | "5950": "tt0110612", 495 | "5958": "tt5474644", 496 | "6012": "tt2369135", 497 | "6013": "tt3531824", 498 | "6022": "tt6408226", 499 | "6048": "tt2011159", 500 | "6055": "tt4550098", 501 | "6057": "tt2024469", 502 | "6076": "tt3110958", 503 | "6086": "tt1483013", 504 | "6090": "tt5164214", 505 | "6137": "tt3829920", 506 | "6153": "tt4158876", 507 | "6154": "tt1204977", 508 | "6156": "tt1018765", 509 | "6177": "tt1563742", 510 | "6186": "tt2557478", 511 | "6194": "tt3332064", 512 | "6224": "tt1355644", 513 | "6232": "tt4572514", 514 | "6319": "tt2058673", 515 | "6334": "tt3960412", 516 | "6394": "tt4341582", 517 | "6402": "tt0110932", 518 | "6491": "tt1855325", 519 | "6521": "tt1318514", 520 | 
"6607": "tt4466894", 521 | "6613": "tt4441150", 522 | "6617": "tt6951892", 523 | "6629": "tt1700841", 524 | "6636": "tt2140373", 525 | "6655": "tt0446029", 526 | "6656": "tt1727776", 527 | "6672": "tt1217213", 528 | "6685": "tt1020072", 529 | "6701": "tt1121096", 530 | "6706": "tt1956620", 531 | "6741": "tt2296777", 532 | "6769": "tt3397884", 533 | "6770": "tt5052474", 534 | "6775": "tt0490215", 535 | "6810": "tt1564585", 536 | "6811": "tt5758778", 537 | "6816": "tt2072233", 538 | "6819": "tt4573516", 539 | "6832": "tt4184878", 540 | "6833": "tt2398241", 541 | "6837": "tt2334871", 542 | "6859": "tt3778644", 543 | "6869": "tt0108186", 544 | "6870": "tt3210686", 545 | "6878": "tt5688932", 546 | "6890": "tt1164647", 547 | "6952": "tt2527338", 548 | "6959": "tt2488496", 549 | "6992": "tt3316960", 550 | "6994": "tt0348124", 551 | "7001": "tt4624424", 552 | "7005": "tt0480011", 553 | "7007": "tt4191054", 554 | "7026": "tt0491175", 555 | "7036": "tt3263904", 556 | "7050": "tt7690670", 557 | "7055": "tt0859635", 558 | "7131": "tt1291150", 559 | "7195": "tt1872181", 560 | "7196": "tt1440728", 561 | "7243": "tt1596363", 562 | "7682": "tt0800080", 563 | "7882": "tt4501454", 564 | "8152": "tt4052882", 565 | "8276": "tt1324999", 566 | "8295": "tt3488710", 567 | "8346": "tt0993846", 568 | "8496": "tt2109248", 569 | "8578": "tt7334528", 570 | "8587": "tt1496025", 571 | "8589": "tt3462710", 572 | "8593": "tt4761916", 573 | "8598": "tt0186654", 574 | "8601": "tt7153766", 575 | "8608": "tt6499752", 576 | "8616": "tt1524930", 577 | "8618": "tt2239822", 578 | "8637": "tt1976009", 579 | "8734": "tt0103241", 580 | "8766": "tt2582802", 581 | "8767": "tt3553442", 582 | "8811": "tt1072748", 583 | "9110": "tt0087469", 584 | "9277": "tt2911666", 585 | "9380": "tt4807408", 586 | "9384": "tt4154664", 587 | "9386": "tt8085790", 588 | "9387": "tt4126476", 589 | "9419": "tt1253864", 590 | "9421": "tt0384642", 591 | "9451": "tt0371746", 592 | "9456": "tt4154756", 593 | "9460": "tt6483364", 594 | 
"9461": "tt7083526", 595 | "9462": "tt6348138", 596 | "9481": "tt4154796", 597 | "9482": "tt2126355", 598 | "9488": "tt0265298", 599 | "9502": "tt0234215", 600 | "9504": "tt0094898", 601 | "9509": "tt6722030", 602 | "9510": "tt1946502", 603 | "9515": "tt0810913", 604 | "9519": "tt5884052", 605 | "9526": "tt0437086", 606 | "9528": "tt4913966", 607 | "9529": "tt5461944", 608 | "9535": "tt1430607", 609 | "9552": "tt0450405", 610 | "9555": "tt0120630", 611 | "9575": "tt1588173", 612 | "9576": "tt0073195", 613 | "9583": "tt0115820", 614 | "9595": "tt1399103", 615 | "9606": "tt2294449", 616 | "9615": "tt5113040", 617 | "9617": "tt1298644", 618 | "9618": "tt3741700", 619 | "9619": "tt7752126", 620 | "9620": "tt8385474", 621 | "9638": "tt0107120", 622 | "9642": "tt0995039", 623 | "9644": "tt5635086", 624 | "9647": "tt2283336", 625 | "9654": "tt0972785", 626 | "9659": "tt6146586", 627 | "9676": "tt4701182", 628 | "9689": "tt2717822", 629 | "9719": "tt3654796", 630 | "9724": "tt0407304", 631 | "9732": "tt8079248", 632 | "9733": "tt3361792", 633 | "9735": "tt0146316", 634 | "9737": "tt6320628", 635 | "9738": "tt0198781", 636 | "9741": "tt0298814", 637 | "9747": "tt0129290", 638 | "9750": "tt0290095", 639 | "9751": "tt0413895", 640 | "9754": "tt7456310", 641 | "9756": "tt0047296", 642 | "9761": "tt0389790", 643 | "9773": "tt0117218", 644 | "9774": "tt8350360", 645 | "9785": "tt0249462", 646 | "9799": "tt0120749", 647 | "9846": "tt0091042", 648 | "9896": "tt2713180", 649 | "9906": "tt0319343", 650 | "9920": "tt3315342", 651 | "9952": "tt0099653" 652 | } -------------------------------------------------------------------------------- /datasets/mad_split.json: -------------------------------------------------------------------------------- 1 | { 2 | "EVAL": [ 3 | "1005_Signs", 4 | "1026_Legion", 5 | "1027_Les_Miserables", 6 | "1051_Harry_Potter_and_the_goblet_of_fire", 7 | "3009_BATTLE_LOS_ANGELES", 8 | "3015_CHARLIE_ST_CLOUD", 9 | "3031_HANSEL_GRETEL_WITCH_HUNTERS", 10 | 
"3032_HOW_DO_YOU_KNOW", 11 | "3034_IDES_OF_MARCH", 12 | "3074_THE_ROOMMATE" 13 | ], 14 | "TRAIN": [ 15 | "10142", 16 | "10149", 17 | "10202", 18 | "10322", 19 | "10527", 20 | "10536", 21 | "10784", 22 | "10813", 23 | "10836", 24 | "10861", 25 | "10894", 26 | "10965", 27 | "11003", 28 | "11010", 29 | "11099", 30 | "11129", 31 | "11139", 32 | "11140", 33 | "11143", 34 | "11147", 35 | "11148", 36 | "11154", 37 | "11318", 38 | "11321", 39 | "11345", 40 | "11396", 41 | "11430", 42 | "11438", 43 | "11530", 44 | "11620", 45 | "11727", 46 | "11796", 47 | "11962", 48 | "12010", 49 | "12079", 50 | "12090", 51 | "12125", 52 | "12131", 53 | "12132", 54 | "12144", 55 | "12147", 56 | "12148", 57 | "12186", 58 | "12211", 59 | "12220", 60 | "12222", 61 | "12263", 62 | "12273", 63 | "12294", 64 | "12324", 65 | "12358", 66 | "12504", 67 | "12563", 68 | "12585", 69 | "12618", 70 | "12653", 71 | "12658", 72 | "12743", 73 | "12852", 74 | "12869", 75 | "12900", 76 | "12906", 77 | "12911", 78 | "12923", 79 | "12958", 80 | "13018", 81 | "13027", 82 | "13031", 83 | "13045", 84 | "13140", 85 | "13146", 86 | "13159", 87 | "13165", 88 | "13187", 89 | "13191", 90 | "13201", 91 | "2723", 92 | "2730", 93 | "2731", 94 | "2735", 95 | "2738", 96 | "2745", 97 | "2750", 98 | "2758", 99 | "2768", 100 | "2778", 101 | "2787", 102 | "2800", 103 | "2801", 104 | "2814", 105 | "2818", 106 | "2854", 107 | "2869", 108 | "2870", 109 | "2873", 110 | "2911", 111 | "2913", 112 | "2928", 113 | "2934", 114 | "2944", 115 | "2948", 116 | "2970", 117 | "2986", 118 | "2992", 119 | "2996", 120 | "3001", 121 | "3014", 122 | "3020", 123 | "3021", 124 | "3023", 125 | "3033", 126 | "3040", 127 | "3049", 128 | "3050", 129 | "3059", 130 | "3060", 131 | "3066", 132 | "3070", 133 | "3103", 134 | "3106", 135 | "3113", 136 | "3114", 137 | "3117", 138 | "3129", 139 | "3138", 140 | "3146", 141 | "3153", 142 | "3160", 143 | "3170", 144 | "3171", 145 | "3209", 146 | "3239", 147 | "3253", 148 | "3276", 149 | "3277", 150 | "3295", 151 
| "3314", 152 | "3339", 153 | "3340", 154 | "3354", 155 | "3376", 156 | "3393", 157 | "3401", 158 | "3408", 159 | "3414", 160 | "3417", 161 | "3447", 162 | "3464", 163 | "3480", 164 | "3482", 165 | "3500", 166 | "3509", 167 | "3510", 168 | "3513", 169 | "3521", 170 | "3548", 171 | "3575", 172 | "3590", 173 | "3599", 174 | "3611", 175 | "3625", 176 | "3720", 177 | "3743", 178 | "3759", 179 | "3773", 180 | "3820", 181 | "3834", 182 | "3837", 183 | "3858", 184 | "3905", 185 | "3911", 186 | "3922", 187 | "3977", 188 | "4001", 189 | "4007", 190 | "4010", 191 | "4017", 192 | "4031", 193 | "4043", 194 | "4053", 195 | "4061", 196 | "4071", 197 | "4080", 198 | "4082", 199 | "4143", 200 | "4156", 201 | "4200", 202 | "4204", 203 | "4210", 204 | "4253", 205 | "4266", 206 | "4299", 207 | "4303", 208 | "4305", 209 | "4368", 210 | "4377", 211 | "4378", 212 | "4390", 213 | "4423", 214 | "4434", 215 | "4451", 216 | "4455", 217 | "4460", 218 | "4480", 219 | "4489", 220 | "4528", 221 | "4535", 222 | "4551", 223 | "4576", 224 | "4578", 225 | "4587", 226 | "4596", 227 | "4597", 228 | "4608", 229 | "4611", 230 | "4618", 231 | "4634", 232 | "4635", 233 | "4638", 234 | "4644", 235 | "4664", 236 | "4670", 237 | "4671", 238 | "4684", 239 | "4702", 240 | "4709", 241 | "4719", 242 | "4728", 243 | "4740", 244 | "4741", 245 | "4753", 246 | "4772", 247 | "4778", 248 | "4797", 249 | "4798", 250 | "4813", 251 | "4815", 252 | "4839", 253 | "4880", 254 | "4884", 255 | "4888", 256 | "4901", 257 | "4902", 258 | "4914", 259 | "4925", 260 | "4929", 261 | "4933", 262 | "4936", 263 | "4950", 264 | "4962", 265 | "4970", 266 | "4977", 267 | "4982", 268 | "4992", 269 | "5014", 270 | "5041", 271 | "5055", 272 | "5063", 273 | "5074", 274 | "5093", 275 | "5101", 276 | "5118", 277 | "5139", 278 | "5144", 279 | "5217", 280 | "5236", 281 | "5237", 282 | "5257", 283 | "5259", 284 | "5265", 285 | "5270", 286 | "5283", 287 | "5293", 288 | "5308", 289 | "5335", 290 | "5366", 291 | "5367", 292 | "5369", 293 | "5417", 
294 | "5420", 295 | "5432", 296 | "5449", 297 | "5461", 298 | "5469", 299 | "5473", 300 | "5477", 301 | "5494", 302 | "5506", 303 | "5510", 304 | "5511", 305 | "5522", 306 | "5563", 307 | "5565", 308 | "5568", 309 | "5574", 310 | "5575", 311 | "5577", 312 | "5583", 313 | "5594", 314 | "5605", 315 | "5607", 316 | "5634", 317 | "5641", 318 | "5649", 319 | "5677", 320 | "5678", 321 | "5682", 322 | "5685", 323 | "5700", 324 | "5735", 325 | "5737", 326 | "5743", 327 | "5749", 328 | "5752", 329 | "5758", 330 | "5762", 331 | "5792", 332 | "5807", 333 | "5814", 334 | "5818", 335 | "5819", 336 | "5828", 337 | "5852", 338 | "5865", 339 | "5872", 340 | "5873", 341 | "5898", 342 | "5900", 343 | "5913", 344 | "5923", 345 | "5950", 346 | "5958", 347 | "6012", 348 | "6013", 349 | "6022", 350 | "6048", 351 | "6055", 352 | "6057", 353 | "6076", 354 | "6086", 355 | "6090", 356 | "6137", 357 | "6153", 358 | "6154", 359 | "6156", 360 | "6177", 361 | "6186", 362 | "6194", 363 | "6224", 364 | "6232", 365 | "6319", 366 | "6334", 367 | "6394", 368 | "6402", 369 | "6491", 370 | "6521", 371 | "6607", 372 | "6613", 373 | "6617", 374 | "6629", 375 | "6636", 376 | "6655", 377 | "6656", 378 | "6672", 379 | "6685", 380 | "6701", 381 | "6706", 382 | "6741", 383 | "6769", 384 | "6770", 385 | "6775", 386 | "6810", 387 | "6811", 388 | "6816", 389 | "6819", 390 | "6832", 391 | "6833", 392 | "6837", 393 | "6859", 394 | "6869", 395 | "6870", 396 | "6878", 397 | "6890", 398 | "6952", 399 | "6959", 400 | "6992", 401 | "6994", 402 | "7001", 403 | "7005", 404 | "7007", 405 | "7026", 406 | "7036", 407 | "7050", 408 | "7055", 409 | "7131", 410 | "7195", 411 | "7196", 412 | "7243", 413 | "7682", 414 | "7882", 415 | "8152", 416 | "8276", 417 | "8295", 418 | "8346", 419 | "8496", 420 | "8578", 421 | "8587", 422 | "8589", 423 | "8593", 424 | "8598", 425 | "8601", 426 | "8608", 427 | "8616", 428 | "8618", 429 | "8637", 430 | "8734", 431 | "8766", 432 | "8767", 433 | "8811", 434 | "9110", 435 | "9277", 436 | 
"9380", 437 | "9384", 438 | "9386", 439 | "9387", 440 | "9419", 441 | "9421", 442 | "9451", 443 | "9456", 444 | "9460", 445 | "9461", 446 | "9462", 447 | "9481", 448 | "9482", 449 | "9488", 450 | "9502", 451 | "9504", 452 | "9509", 453 | "9510", 454 | "9515", 455 | "9519", 456 | "9526", 457 | "9528", 458 | "9529", 459 | "9535", 460 | "9552", 461 | "9555", 462 | "9575", 463 | "9576", 464 | "9583", 465 | "9595", 466 | "9606", 467 | "9615", 468 | "9617", 469 | "9618", 470 | "9619", 471 | "9620", 472 | "9638", 473 | "9642", 474 | "9644", 475 | "9647", 476 | "9654", 477 | "9659", 478 | "9676", 479 | "9689", 480 | "9719", 481 | "9724", 482 | "9732", 483 | "9733", 484 | "9735", 485 | "9737", 486 | "9738", 487 | "9741", 488 | "9747", 489 | "9750", 490 | "9751", 491 | "9754", 492 | "9756", 493 | "9761", 494 | "9773", 495 | "9774", 496 | "9785", 497 | "9799", 498 | "9846", 499 | "9896", 500 | "9906", 501 | "9920", 502 | "9952" 503 | ], 504 | "V1-VAL": [ 505 | "0002_As_Good_As_It_Gets", 506 | "0003_CASABLANCA", 507 | "0006_Clerks", 508 | "0013_Halloween", 509 | "0014_Ist_das_Leben_nicht_schoen", 510 | "0020_Raising_Arizona", 511 | "0021_Rear_Window", 512 | "0023_THE_BUTTERFLY_EFFECT", 513 | "0026_The_Big_Fish", 514 | "0027_The_Big_Lebowski", 515 | "0031_The_Lost_Weekend", 516 | "0032_The_Princess_Bride", 517 | "0033_Amadeus", 518 | "0053_Rendezvous_mit_Joe_Black", 519 | "1010_TITANIC", 520 | "1015_27_Dresses", 521 | "1017_Bad_Santa", 522 | "1018_Body_Of_Lies", 523 | "1019_Confessions_Of_A_Shopaholic", 524 | "1026_Legion", 525 | "1027_Les_Miserables", 526 | "1033_Sherlock_Holmes_A_Game_of_Shadows", 527 | "1035_The_Adjustment_Bureau", 528 | "1038_The_Great_Gatsby", 529 | "1042_Up_In_The_Air", 530 | "1050_Harry_Potter_and_the_deathly_hallows_Disk_One", 531 | "1052_Harry_Potter_and_the_order_of_phoenix", 532 | "1058_The_Damned_united", 533 | "3003_40_YEAR_OLD_VIRGIN", 534 | "3005_ABRAHAM_LINCOLN_VAMPIRE_HUNTER", 535 | "3018_CINDERELLA_MAN", 536 | "3024_EASY_A", 537 | 
"3030_GROWN_UPS", 538 | "3034_IDES_OF_MARCH", 539 | "3038_ITS_COMPLICATED", 540 | "3041_JUST_GO_WITH_IT", 541 | "3042_KARATE_KID", 542 | "3043_KATY_PERRY_PART_OF_ME", 543 | "3047_LIFE_OF_PI", 544 | "3048_LITTLE_FOCKERS", 545 | "3049_MORNING_GLORY", 546 | "3053_PARENTAL_GUIDANCE", 547 | "3066_THE_ADVENTURES_OF_TINTIN", 548 | "3069_THE_BOUNTY_HUNTER", 549 | "3071_THE_DESCENDANTS", 550 | "3074_THE_ROOMMATE", 551 | "3077_THE_VOW", 552 | "3088_WHATS_YOUR_NUMBER", 553 | "3090_YOUNG_ADULT", 554 | "3092_ZOOKEEPER" 555 | ], 556 | "V1-TEST": [ 557 | "0001_American_Beauty", 558 | "0004_Charade", 559 | "0005_Chinatown", 560 | "0007_DIE_NACHT_DES_JAEGERS", 561 | "0008_Fargo", 562 | "0009_Forrest_Gump", 563 | "0010_Frau_Ohne_Gewissen", 564 | "0011_Gandhi", 565 | "0012_Get_Shorty", 566 | "0016_O_Brother_Where_Art_Thou", 567 | "0017_Pianist", 568 | "0019_Pulp_Fiction", 569 | "0022_Reservoir_Dogs", 570 | "0028_The_Crying_Game", 571 | "0029_The_Graduate", 572 | "0030_The_Hustler", 573 | "0038_Psycho", 574 | "0041_The_Sixth_Sense", 575 | "0043_Thelma_and_Luise", 576 | "0046_Chasing_Amy", 577 | "0049_Hannah_and_her_sisters", 578 | "0050_Indiana_Jones_and_the_last_crusade", 579 | "0051_Men_in_black", 580 | "1001_Flight", 581 | "1002_Harry_Potter_and_the_Half-Blood_Prince", 582 | "1003_How_to_Lose_Friends_and_Alienate_People", 583 | "1004_Juno", 584 | "1005_Signs", 585 | "1006_Slumdog_Millionaire", 586 | "1007_Spider-Man1", 587 | "1008_Spider-Man2", 588 | "1009_Spider-Man3", 589 | "1011_The_Help", 590 | "1012_Unbreakable", 591 | "1014_2012", 592 | "1020_Crazy_Stupid_Love", 593 | "1028_No_Reservations", 594 | "1031_Quantum_of_Solace", 595 | "1034_Super_8", 596 | "1037_The_Curious_Case_Of_Benjamin_Button", 597 | "1039_The_Queen", 598 | "1040_The_Ugly_Truth", 599 | "1043_Vantage_Point", 600 | "1045_An_education", 601 | "1046_Australia", 602 | "1047_Defiance", 603 | "1048_Gran_Torino", 604 | "1051_Harry_Potter_and_the_goblet_of_fire", 605 | "1054_Harry_Potter_and_the_prisoner_of_azkaban", 
606 | "1055_Marley_and_me", 607 | "1057_Seven_pounds", 608 | "1059_The_devil_wears_prada", 609 | "1060_Yes_man", 610 | "1061_Harry_Potter_and_the_deathly_hallows_Disk_Two", 611 | "1062_Day_the_Earth_stood_still", 612 | "3001_21_JUMP_STREET", 613 | "3002_30_MINUTES_OR_LESS", 614 | "3004_500_DAYS_OF_SUMMER", 615 | "3007_A_THOUSAND_WORDS", 616 | "3008_BAD_TEACHER", 617 | "3009_BATTLE_LOS_ANGELES", 618 | "3012_BRUNO", 619 | "3013_BURLESQUE", 620 | "3014_CAPTAIN_AMERICA", 621 | "3015_CHARLIE_ST_CLOUD", 622 | "3016_CHASING_MAVERICKS", 623 | "3017_CHRONICLE", 624 | "3020_DEAR_JOHN", 625 | "3021_DEATH_AT_A_FUNERAL", 626 | "3022_DINNER_FOR_SCHMUCKS", 627 | "3023_DISTRICT_9", 628 | "3025_FLIGHT", 629 | "3026_FRIENDS_WITH_BENEFITS", 630 | "3028_GHOST_RIDER_SPIRIT_OF_VENGEANCE", 631 | "3031_HANSEL_GRETEL_WITCH_HUNTERS", 632 | "3032_HOW_DO_YOU_KNOW", 633 | "3033_HUGO", 634 | "3035_INSIDE_MAN", 635 | "3036_IN_TIME", 636 | "3037_IRON_MAN2", 637 | "3039_JACK_AND_JILL", 638 | "3040_JULIE_AND_JULIA", 639 | "3045_LAND_OF_THE_LOST", 640 | "3046_LARRY_CROWNE", 641 | "3050_MR_POPPERS_PENGUINS", 642 | "3051_NANNY_MCPHEE_RETURNS", 643 | "3052_NO_STRINGS_ATTACHED", 644 | "3054_PERCY_JACKSON_LIGHTENING_THIEF", 645 | "3055_PROMETHEUS", 646 | "3056_PUBLIC_ENEMIES", 647 | "3058_RUBY_SPARKS", 648 | "3060_SANCTUM", 649 | "3061_SNOW_FLOWER", 650 | "3062_SORCERERS_APPRENTICE", 651 | "3063_SOUL_SURFER", 652 | "3067_THE_ART_OF_GETTING_BY", 653 | "3070_THE_CALL", 654 | "3072_THE_GIRL_WITH_THE_DRAGON_TATTOO", 655 | "3073_THE_GUILT_TRIP", 656 | "3075_THE_SITTER", 657 | "3076_THE_SOCIAL_NETWORK", 658 | "3078_THE_WATCH", 659 | "3079_THINK_LIKE_A_MAN", 660 | "3081_THOR", 661 | "3082_TITANIC1", 662 | "3083_TITANIC2", 663 | "3084_TOOTH_FAIRY", 664 | "3085_TRUE_GRIT", 665 | "3086_UGLY_TRUTH", 666 | "3087_WE_BOUGHT_A_ZOO", 667 | "3089_XMEN_FIRST_CLASS", 668 | "3091_ZOMBIELAND" 669 | ] 670 | } -------------------------------------------------------------------------------- /models/gpt_utils.py: 
"""Modified from https://colab.research.google.com/drive/1tuoAC5F4sC7qid56Z0ap-stR3rwdk0ZV?usp=sharing#scrollTo=OArDkm_24w4L """

import torch
import numpy as np
import torch.nn.functional as F

try:
    from tqdm import tqdm, trange
except ImportError:  # tqdm only drives an optional progress bar (see `verbose`)
    tqdm = None

    def trange(*args, **kwargs):
        """Fallback for ``tqdm.trange``: a plain ``range`` that ignores bar options."""
        return range(*args)
# from transformers.generation_utils import top_k_top_p_filtering
# (a local `top_k_top_p_filtering` is defined at the bottom of this module instead)


def _call_gpt(model, inputs_embeds, attention_mask, past_key_values, media):
    """Run ``model.gpt`` once, forwarding ``media`` only when it is provided."""
    if media is not None:
        return model.gpt(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                         past_key_values=past_key_values, media=media)
    return model.gpt(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                     past_key_values=past_key_values)


def _append_mask_column(attention_mask):
    """Return ``attention_mask`` extended by one column of ones (one new token).

    ``new_ones`` keeps the mask's dtype and device and, unlike slicing a template
    column with ``[:, -2:-1]``, also works when the mask has a single column.
    """
    new_column = attention_mask.new_ones((attention_mask.shape[0], 1))
    return torch.cat((attention_mask, new_column), dim=1)


def _decode_or_period(tokenizer, tokens, **decode_kwargs):
    """Decode ``tokens`` to text; return ``'.'`` when they cannot be decoded."""
    try:
        output_list = list(tokens.squeeze().cpu().numpy())
        return tokenizer.decode(output_list, **decode_kwargs)
    except Exception:
        # Deliberate best-effort fallback: e.g. `tokens` is None, or a single
        # token squeezes to a 0-d array that cannot be converted to a list.
        return '.'


@torch.no_grad()
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None, attention_mask=None,
                  entry_length=67, temperature=1., stop_token: str = '.', past_key_values=None, media=None,
                  repetition_penalty=1.2,
                  no_repeat_ngram_size=3,
                  history_tokens=None,):
    """Generate text with length-normalised beam search.

    Args:
        model: wrapper exposing ``model.gpt`` (the language model) and
            ``model.gpt.transformer.wte`` (its token embedding). Put in eval mode.
        tokenizer: object with ``encode`` / ``decode``.
        beam_size: number of hypotheses kept at every step.
        prompt: text to start from; only used when ``embed`` is None.
        embed: prefix embeddings to start from. The beam expansion uses
            ``expand``, so the leading (batch) dimension must be 1.
        attention_mask: optional mask for the prefix; it is repeated per beam
            and grown by one column per generated token.
        entry_length: maximum number of decoding steps.
        temperature: logits are divided by it when it is > 0.
        stop_token: accepted for API compatibility but NOT used; decoding stops
            on the hard-coded id 50256 (GPT-2's ``<|endoftext|>``).
        past_key_values: optional cache; when given, only the newest token
            embedding is fed at each step.
        media: forwarded to ``model.gpt`` as ``media=`` when not None.
        repetition_penalty: see ``enforce_repetition_penalty``.
        no_repeat_ngram_size: n-grams of this size are never repeated.
        history_tokens: earlier token ids that also count for the n-gram ban;
            presumably shaped ``(1, n)`` (it is repeated along dim 0).

    Returns:
        List of decoded hypotheses, best score first.
    """
    model.eval()
    # stop_token_index = tokenizer.encode(stop_token)[0]
    stop_token_index = 50256
    tokens = None
    scores = None
    prompt_length = 0  # number of leading prompt ids stored in `tokens`
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

    if embed is not None:
        generated = embed
    else:
        tokens = torch.tensor(tokenizer.encode(prompt)).long()
        tokens = tokens.unsqueeze(0).to(device)
        prompt_length = tokens.shape[1]
        generated = model.gpt.transformer.wte(tokens)

    for step in range(entry_length):
        outputs = _call_gpt(model, generated, attention_mask, past_key_values, media)

        if past_key_values is not None:
            past_key_values = outputs.past_key_values

        logits = outputs.logits
        logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
        logits = F.log_softmax(logits, -1)
        logits = enforce_repetition_penalty(logits, tokens, repetition_penalty=repetition_penalty)

        # Entire generated paragraph (history + current tokens) feeds the n-gram ban.
        if (history_tokens is None) or (history_tokens.numel() == 0):
            h_tokens = tokens
        elif tokens is None:
            h_tokens = history_tokens.repeat(beam_size, 1)
        else:
            # Match the current number of hypotheses: `tokens` still has a single
            # row on the first step when decoding starts from a prompt.
            h_tokens = torch.cat((history_tokens.repeat(tokens.shape[0], 1), tokens), dim=1)
        cul_len = 0 if h_tokens is None else h_tokens.shape[1]
        num_hypo = 1 if h_tokens is None else h_tokens.shape[0]
        banned_batch_tokens = calc_banned_ngram_tokens(
            h_tokens, num_hypo, no_repeat_ngram_size, cul_len
        )
        if len(banned_batch_tokens) > logits.shape[0]:
            # First step: only one row of logits exists yet.
            banned_batch_tokens = [banned_batch_tokens[0]]
        # NOTE: a dedicated loop variable is used here so the decoding step
        # counter checked below is not overwritten.
        for row, banned_tokens in enumerate(banned_batch_tokens):
            logits[row, banned_tokens] = -float("inf")

        # beam search
        if scores is None:
            scores, next_tokens = logits.topk(beam_size, -1)
            generated = generated.expand(beam_size, *generated.shape[1:])
            next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
            if tokens is None:
                tokens = next_tokens
            else:
                tokens = tokens.expand(beam_size, *tokens.shape[1:])
                tokens = torch.cat((tokens, next_tokens), dim=1)
        else:
            # Finished beams may only emit token 0 at zero cost, so their score
            # stays frozen while they remain candidates.
            logits[is_stopped] = -float("inf")
            logits[is_stopped, 0] = 0
            scores_sum = scores[:, None] + logits
            seq_lengths[~is_stopped] += 1
            scores_sum_average = scores_sum / seq_lengths[:, None]
            scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
            # Flat index -> (source beam, vocabulary id).
            next_tokens_source = torch.div(next_tokens, scores_sum.shape[1], rounding_mode='floor')
            seq_lengths = seq_lengths[next_tokens_source]
            next_tokens = next_tokens % scores_sum.shape[1]
            next_tokens = next_tokens.unsqueeze(1)
            tokens = tokens[next_tokens_source]
            tokens = torch.cat((tokens, next_tokens), dim=1)
            generated = generated[next_tokens_source]
            scores = scores_sum_average * seq_lengths
            is_stopped = is_stopped[next_tokens_source]

        next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
        if past_key_values is not None:
            generated = next_token_embed
            if step == 0:
                past_key_values = tuple([(tp[0].repeat(beam_size, 1, 1, 1), tp[1].repeat(beam_size, 1, 1, 1))
                                         for tp in past_key_values])
        else:
            generated = torch.cat((generated, next_token_embed), dim=1)

        if attention_mask is not None:
            if step == 0:
                attention_mask = attention_mask.repeat(beam_size, 1)
            attention_mask = _append_mask_column(attention_mask)
        is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
        if is_stopped.all():
            break

    if tokens is None or scores is None:
        # No decoding step ran (entry_length <= 0): nothing to return.
        return []

    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    # `seq_lengths` counts generated tokens only, so skip past the prompt ids
    # (if any) before truncating each hypothesis at its own length.
    output_texts = [
        tokenizer.decode(output[:prompt_length + int(length)], skip_special_tokens=True)
        for output, length in zip(output_list, seq_lengths)
    ]
    order = scores.argsort(descending=True).tolist()
    return [output_texts[idx] for idx in order]


@torch.no_grad()
def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.0):
    """Penalise, in place, the scores of tokens that were already produced.

    Args:
        lprobs: ``(num_rows, vocab)`` scores; modified in place and returned.
        prev_output_tokens: ``(num_rows, n)`` integer ids already produced per
            row, or None.
        repetition_penalty: negative scores (log-probabilities) are multiplied
            by it, non-negative scores are divided by it; 1.0 disables it.

    Returns:
        ``lprobs`` (the same tensor object).
    """
    if (prev_output_tokens is None) or (prev_output_tokens.numel() == 0) or (repetition_penalty == 1.0):
        return lprobs
    for i in range(lprobs.shape[0]):
        # Each distinct previous token is penalised once per row.
        seen = torch.unique(prev_output_tokens[i]).to(lprobs.device)
        values = lprobs[i, seen]
        lprobs[i, seen] = torch.where(values < 0, values * repetition_penalty, values / repetition_penalty)
    return lprobs


@torch.no_grad()
def calc_banned_ngram_tokens(prev_input_ids: torch.Tensor, num_hypos: int, no_repeat_ngram_size: int,
                             cur_len: int) -> list:
    """Copied from fairseq for no_repeat_ngram in beam_search.

    Args:
        prev_input_ids: ``(num_hypos, cur_len)`` ids seen so far, or None.
        num_hypos: number of hypotheses (rows) to process.
        no_repeat_ngram_size: size of the n-grams that must not repeat.
        cur_len: number of ids per hypothesis.

    Returns:
        One list per hypothesis with the ids that would complete an n-gram that
        already occurred in that hypothesis.
    """
    if (prev_input_ids is None) or (no_repeat_ngram_size <= 0) or (cur_len + 1 < no_repeat_ngram_size):
        # Nothing to ban: no history, banning disabled, or fewer than
        # no_repeat_ngram_size - 1 tokens available as context.
        return [[] for _ in range(num_hypos)]

    # Before decoding the next token, the last (n - 1) ids form the context.
    start_idx = cur_len + 1 - no_repeat_ngram_size
    banned_tokens = []
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        followers = {}  # (n - 1)-gram -> ids that followed it so far
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            followers.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
        context = tuple(gen_tokens[start_idx:cur_len])
        banned_tokens.append(followers.get(context, []))
    return banned_tokens


@torch.no_grad()
def generate_greedy(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        attention_mask=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
        verbose=False,
        past_key_values=None,
        media=None,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        history_tokens=None,
):
    """Generate text greedily (arg-max) with repetition and n-gram penalties.

    Args:
        model: wrapper exposing ``model.gpt`` and ``model.gpt.transformer.wte``.
        tokenizer: object with ``encode`` / ``decode``.
        tokens: optional ids to start from (used when ``embed`` is None).
        prompt: text encoded into ``tokens`` when both ``embed`` and ``tokens``
            are None.
        embed: prefix embeddings to start from.
        attention_mask: optional mask, grown by one column per generated token.
        entry_count: number of generation rounds; only the first result is
            returned.
        entry_length: maximum number of decoding steps per round.
        top_p: unused (nucleus filtering is not applied here).
        temperature: logits are divided by it when it is > 0.
        stop_token: accepted for API compatibility but NOT used; decoding stops
            on the hard-coded id 50256 (GPT-2's ``<|endoftext|>``).
        verbose: show a progress bar over the rounds.
        past_key_values: optional cache; when given, only the newest token
            embedding is fed at each step.
        media: forwarded to ``model.gpt`` as ``media=`` when not None.
        repetition_penalty: see ``enforce_repetition_penalty``.
        no_repeat_ngram_size: n-grams of this size are never repeated.
        history_tokens: earlier token ids that also count for the n-gram ban.

    Returns:
        The decoded text of the first round (``'.'`` if decoding failed).
    """
    model.eval()
    generated_list = []
    # stop_token_index = tokenizer.encode(stop_token)[0]
    stop_token_index = 50256
    device = next(model.parameters()).device

    for entry_idx in trange(entry_count, disable=not verbose):
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt)).long()
                tokens = tokens.unsqueeze(0).to(device)

            generated = model.gpt.transformer.wte(tokens)

        for step in range(entry_length):
            outputs = _call_gpt(model, generated, attention_mask, past_key_values, media)
            if past_key_values is not None:
                past_key_values = outputs.past_key_values
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = enforce_repetition_penalty(F.log_softmax(logits, dim=-1), tokens,
                                                repetition_penalty=repetition_penalty)

            # Entire generated paragraph (history + current tokens) feeds the n-gram ban.
            if (history_tokens is None) or (history_tokens.numel() == 0):
                h_tokens = tokens
            elif tokens is None:
                h_tokens = history_tokens
            else:
                h_tokens = torch.cat((history_tokens, tokens), dim=1)
            cul_len = 0 if h_tokens is None else h_tokens.shape[1]
            banned_batch_tokens = calc_banned_ngram_tokens(
                h_tokens, 1, no_repeat_ngram_size, cul_len
            )

            for row, banned_tokens in enumerate(banned_batch_tokens):
                logits[row, banned_tokens] = -float("inf")

            # Normalisation kept from the original pipeline; the arg-max below
            # picks the highest-scoring token either way.
            logits = F.softmax(logits, dim=-1)

            next_token = torch.argmax(logits, -1).unsqueeze(0)
            next_token_embed = model.gpt.transformer.wte(next_token)
            if tokens is None:
                tokens = next_token
            else:
                tokens = torch.cat((tokens, next_token), dim=1)

            if past_key_values is not None:
                generated = next_token_embed
            else:
                generated = torch.cat((generated, next_token_embed), dim=1)

            if attention_mask is not None:
                attention_mask = _append_mask_column(attention_mask)
            if stop_token_index == next_token.item():
                break

        generated_list.append(_decode_or_period(tokenizer, tokens, skip_special_tokens=True))

    return generated_list[0]


@torch.no_grad()
def generate_top_k_top_p(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        attention_mask=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        top_k=3,
        temperature=1.,
        stop_token: str = '.',
        verbose=False,
        past_key_values=None,
        media=None,
):
    """Sample text with top-k / nucleus (top-p) filtering.

    modified from https://github.com/JasonBenn/duet/blob/master/generate.py

    Args:
        model: wrapper exposing ``model.gpt`` and ``model.gpt.transformer.wte``.
        tokenizer: object with ``encode`` / ``decode``.
        tokens: optional ids to start from (used when ``embed`` is None).
        prompt: text encoded into ``tokens`` when both ``embed`` and ``tokens``
            are None.
        embed: prefix embeddings to start from.
        attention_mask: optional mask, grown by one column per generated token.
        entry_count: number of generation rounds; only the first result is
            returned.
        entry_length: maximum number of decoding steps per round.
        top_p: nucleus threshold passed to ``top_k_top_p_filtering``.
        top_k: top-k size passed to ``top_k_top_p_filtering``.
        temperature: logits are divided by it when it is > 0.
        stop_token: text whose first encoded id ends the generation.
        verbose: show a progress bar over the rounds.
        past_key_values: optional cache; when given, only the newest token
            embedding is fed at each step.
        media: forwarded to ``model.gpt`` as ``media=`` when not None.

    Returns:
        The decoded text of the first round (``'.'`` if decoding failed).
    """
    model.eval()
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    device = next(model.parameters()).device

    for entry_idx in trange(entry_count, disable=not verbose):
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt)).long()
                tokens = tokens.unsqueeze(0).to(device)

            generated = model.gpt.transformer.wte(tokens)

        for step in range(entry_length):
            outputs = _call_gpt(model, generated, attention_mask, past_key_values, media)
            if past_key_values is not None:
                past_key_values = outputs.past_key_values
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token_embed = model.gpt.transformer.wte(next_token)
            if tokens is None:
                tokens = next_token
            else:
                tokens = torch.cat((tokens, next_token), dim=1)

            if past_key_values is not None:
                generated = next_token_embed
            else:
                generated = torch.cat((generated, next_token_embed), dim=1)

            if attention_mask is not None:
                attention_mask = _append_mask_column(attention_mask)
            if stop_token_index == next_token.item():
                break

        generated_list.append(_decode_or_period(tokenizer, tokens))

    return generated_list[0]


@torch.no_grad()
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    ``logits`` is modified in place and also returned.

    Args:
        logits: logits distribution shape (..., vocabulary size)
        top_k >0: keep only top k tokens with highest probability (top-k filtering).
        top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
        filter_value: value written over every removed entry.
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs >= top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order (works per row).
        indices_to_remove = torch.zeros_like(logits, dtype=torch.long).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove.long()).bool()
        logits[indices_to_remove] = filter_value
    return logits
[[PDF]](https://www.robots.ox.ac.uk/~vgg/publications/2023/Han23/han23.pdf) 6 | 7 | [[project page]](https://www.robots.ox.ac.uk/~vgg/research/autoad/) 8 | 9 | ### News :mega: 10 | * 2024.04.22: AutoAD-III paper released. Model weights and example AD outputs are [available here](autoad_iii/). More code and datasets coming soon. 11 | 12 | 13 | 14 | 15 | 16 | ### Details 17 | * AutoAD-III: [autoad_iii/](autoad_iii/) 18 | * AutoAD-II: [autoad_ii/](autoad_ii/) 19 | * AutoAD-I: [autoad_i/](autoad_i/) 20 | 21 | 22 | ### Reference 23 | ```bibtex 24 | @InProceedings{han2024autoad3, 25 | title={{AutoAD III: The Prequel} - Back to the Pixels}, 26 | author={Tengda Han and Max Bain and Arsha Nagrani and G\"ul Varol and Weidi Xie and Andrew Zisserman}, 27 | booktitle={CVPR}, 28 | year={2024}} 29 | 30 | @InProceedings{han2023autoad2, 31 | title={{AutoAD II: The Sequel} - Who, When, and What in Movie Audio Description}, 32 | author={Tengda Han and Max Bain and Arsha Nagrani and G\"ul Varol and Weidi Xie and Andrew Zisserman}, 33 | booktitle={ICCV}, 34 | year={2023}} 35 | 36 | @InProceedings{han2023autoad1, 37 | title={{AutoAD}: Movie Description in Context}, 38 | author={Tengda Han and Max Bain and Arsha Nagrani and G\"ul Varol and Weidi Xie and Andrew Zisserman}, 39 | booktitle={CVPR}, 40 | year={2023}} 41 | ``` 42 | 43 | --------------------------------------------------------------------------------