├── .gitignore ├── LICENSE ├── README.md ├── data └── .gitkeep ├── dataloader ├── __init__.py ├── base_dataset.py ├── dramaqa.py ├── egoschema.py ├── how2qa.py ├── nextqa.py ├── perception.py ├── star.py ├── textvid.py ├── tvqa.py ├── vlep.py └── webvid.py ├── demos ├── 13B_msrvtt_results.json ├── 7B_msrvtt_results.json ├── Eval_Cap_MSRVTT.ipynb ├── Eval_Cap_VATEX.ipynb ├── mvbench.ipynb ├── upload_leaderboard_13B_zero_shot.json ├── upload_leaderboard_7B_ZS.json └── video_transforms.py ├── engine.py ├── llama ├── __init__.py ├── generation.py ├── model.py ├── model_llama3.py ├── tokenizer.py └── tokenizer_llama3.py ├── llama_vqa.py ├── pics └── topa_framework.jpg ├── pretrained └── .gitkeep ├── scripts ├── baseline │ ├── llama2_13b.sh │ ├── llama2_7b.sh │ └── llama3_8b.sh ├── eval │ ├── zeroshot_eval_egos.sh │ ├── zeroshot_eval_nextqa.sh │ ├── zeroshot_eval_star.sh │ └── zeroshot_eval_tvqa.sh ├── finetune │ ├── LLama2_13b_finetune.sh │ ├── LLama2_7b_finetune.sh │ └── LLama3_finetune.sh └── pretrain │ ├── llama2_13b.sh │ ├── llama2_7b.sh │ └── llama3_8b.sh ├── setup.sh ├── train.py ├── util ├── lr_sched.py └── misc.py └── vqa_checkpoint └── .gitkeep /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore 2 | data/* 3 | !data/.gitkeep 4 | pretrained/* 5 | !pretrained/.gitkeep 6 | vqa_checkpoint/* 7 | !vqa_checkpoint/.gitkeep 8 | /python 9 | results/* 10 | 11 | *ipynb_checkpoints 12 | *.pyc 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MLV Lab (Machine Learning and Vision Lab at Korea University) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

TOPA: Extend Large Language Models for Video Understanding via Text-Only Pre-Alignment 3 |

NeurIPS 2024 Spollight 4 |

5 | 6 |
7 | 8 |
9 | 10 |
11 | 12 | ## Data Preparation: 13 | Prepare the data as follows. 14 | 15 | **TextVID**: download TextVID at [TextVid only](https://drive.google.com/file/d/12xocihCDYocHVtsdzymii3BnTJmlh430/view?usp=sharing). download TextVID and preprocessed features at [TextVid and features](https://drive.google.com/file/d/1hfMIlABeAl9D_qhG5EcLUVM2HcZArTUx/view?usp=sharing) 16 | 17 | **NeXTQA, STAR and TVQA**: 18 | The prepocessed feautures are available at [here](https://github.com/mlvlab/Flipped-VQA). 19 | 20 | **EgoScehma**: 21 | Download raw videos from [EgoSchema](https://github.com/egoschema/EgoSchema). We provide prepocessed feature [here](https://drive.google.com/file/d/1yCAw101BZvOtSntDToNX7TR8ErdnfSdk/view?usp=sharing) 22 | 23 | **MVBench**: 24 | Download raw videos from [Hugging Face](https://huggingface.co/datasets/OpenGVLab/MVBench). 25 | 26 | **MSRVTT**: 27 | Download raw videos from [MSRVTT](https://github.com/crux82/msr-vtt-it). 28 | 29 | ``` 30 | ./data 31 | |─ nextqa 32 | | |─ train.csv 33 | | |─ val.csv 34 | | └─ clipvitl14.pth 35 | |─ star 36 | | : 37 | |─ tvqa 38 | | : 39 | └─ egos 40 | : 41 | ``` 42 | ## Model Preparation: 43 | Prepare the model as follows. 44 | 45 | **LLMs**: Download the pretrained Llama models from [Llama2](https://github.com/meta-llama/llama) and [Llama3](https://github.com/meta-llama/llama3). 46 | 47 | **TOPA Checkpoints**: Download our [pretrained models](https://drive.google.com/file/d/1-Ce6LC-1TeKvUbg_BeCWzsps6XBf-dlG/view?usp=sharing) 48 | ``` 49 | ./pretrained 50 | └─ llama2 51 | | |─ 7B 52 | | | |─ consolidated.00.pth 53 | | | └─ params.json 54 | | |─ 13B 55 | | | : 56 | | | : 57 | | └─ tokenizer.model 58 | └─ llama3 59 | |─ 8B 60 | | |─ consolidated.00.pth 61 | | └─ params.json 62 | └─ tokenizer.model 63 | 64 | ./vqa_checkpoint 65 | └─ checkpoint_pretrain 66 | |─ llama2_7b 67 | |─ llama2_13b 68 | └─ llama3_8b 69 | ``` 70 | 71 | ## Training & Evaluation 72 | ### Text-only Pre-alignment 73 | ``` 74 | ./scripts/pretrain/llama2_7b.sh 75 | ``` 76 | ### Zero-shot inference 77 | ``` 78 | ./scripts/eval/zeroshot_eval_egos.sh 79 | ./scripts/eval/zeroshot_eval_nextqa.sh 80 | ./scripts/eval/zeroshot_eval_star.sh 81 | ./scripts/eval/zeroshot_eval_tvqa.sh 82 | ``` 83 | ### Evaluate on MVBench 84 | [mvbench.ipynb](demos/mvbench.ipynb) 85 | 86 | ### Evaluate on video captioning benchmarks 87 | [MSRVTT.ipynb](demos/Eval_Cap_MSRVTT.ipynb) 88 | 89 | [VATEX.ipynb](demos/Eval_Cap_VATEX.ipynb) 90 | 91 | ## Acknowledgements 92 | This repo is built upon [Flipped-VQA](https://github.com/mlvlab/Flipped-VQA) and benefits from [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [DeCap](https://github.com/dhg-wei/DeCap), [MVBench](https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/MVBENCH.md), [Llama2](https://github.com/meta-llama/llama) and [Llama3](https://github.com/meta-llama/llama3). 93 | 94 | 95 | ## Citations 96 | 97 | ``` 98 | @article{li2024topa, 99 | title={TOPA: Extend Large Language Models for Video Understanding via Text-Only Pre-Alignment}, 100 | author={Li, Wei and Fan, Hehe and Wong, Yongkang and Kankanhalli, Mohan and Yang, Yi}, 101 | journal={arXiv preprint arXiv:2405.13911}, 102 | year={2024} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/data/.gitkeep -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util import misc 3 | from .nextqa import NextQA 4 | from .dramaqa import DramaQA 5 | from .star import STAR 6 | from .vlep import VLEP 7 | from .tvqa import TVQA 8 | from .textvid import TextVid 9 | from .egoschema import EgoSchema 10 | from .perception import PerceptionTest 11 | from .how2qa import How2qa 12 | from .webvid import Webvid 13 | 14 | from torch.utils.data import DataLoader, ConcatDataset 15 | 16 | dataset_mapping = {'nextqa': NextQA, 'star': STAR, 'dramaqa': DramaQA, 'vlep': VLEP, 'tvqa': TVQA,'textvid':TextVid,'egos':EgoSchema,'perc':PerceptionTest,'how2qa':How2qa,'webvid':Webvid} 17 | num_options_mapping = {'nextqa': 5, 'star': 4, 'dramaqa': 5, 'vlep': 2, 'tvqa': 5,'textvid': 5,'egos':5,'perc':3,'how2qa':4,'webvid':5} 18 | 19 | def load_data(args, tokenizer, split='train'): 20 | if split=='train' and args.textvid: 21 | args.num_options = num_options_mapping['textvid'] 22 | dataset = dataset_mapping['textvid'](args=args, tokenizer=tokenizer, split=split) 23 | else: 24 | args.num_options = num_options_mapping[args.dataset] 25 | dataset = dataset_mapping[args.dataset](args=args, tokenizer=tokenizer, split=split) 26 | 27 | num_tasks = misc.get_world_size() 28 | global_rank = misc.get_rank() 29 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True) 30 | data_loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=batch_collate, 31 | pin_memory=args.pin_mem, drop_last=False) 32 | 33 | return data_loader 34 | 35 | 36 | 37 | def load_data_instruct(args, tokenizer, split='train'): 38 | 39 | instruct_datasets = ['nextqa','star','tvqa'] 40 | data_loader_instruct=[] 41 | for dataset_name in instruct_datasets: 42 | args.dataset=dataset_name 43 | args.num_options = num_options_mapping[args.dataset] 44 | dataset = dataset_mapping[args.dataset](args=args, tokenizer=tokenizer, split=split) 45 | 46 | data_loader_instruct.append(dataset) 47 | 48 | dataset = ConcatDataset(data_loader_instruct) 49 | 50 | num_tasks = misc.get_world_size() 51 | global_rank = misc.get_rank() 52 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True) 53 | data_loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=batch_collate, 54 | pin_memory=args.pin_mem, drop_last=False) 55 | 56 | return data_loader 57 | 58 | def batch_collate(batch): 59 | q_index = None 60 | bs = len(batch) 61 | vid = [batch[i]["vid"] for i in range(bs)] 62 | video = torch.stack([batch[i]["video"] for i in range(bs)]) 63 | video_len = torch.tensor([batch[i]["video_len"] for i in range(bs)], dtype=torch.long) 64 | text = [batch[i]["text"] for i in range(bs)] 65 | qid = [batch[i]["qid"] for i in range(bs)] 66 | qtype = torch.tensor([batch[i]['qtype'] for i in range(bs)]) 67 | 68 | vqa_id = torch.stack([batch[i]['text_id']['vqa'] for i in range(bs)]) 69 | vaq_id = torch.stack([batch[i]['text_id']['vaq'] for i in range(bs)]) 70 | qav_id = torch.stack([batch[i]['text_id']['qav'] for i in range(bs)]) 71 | text_id = {'vqa': vqa_id, 'vaq': vaq_id, 'qav': qav_id} 72 | 73 | vqa_label = torch.stack([batch[i]['label']['vqa'] for i in range(bs)]) 74 | vaq_label = torch.stack([batch[i]['label']['vaq'] for i in range(bs)]) 75 | qav_label = torch.stack([batch[i]['label']['qav'] for i in range(bs)]) 76 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label} 77 | 78 | vqa_video_start = [batch[i]["video_start"]['vqa'] for i in range(bs)] 79 | vaq_video_start = [batch[i]["video_start"]['vaq'] for i in range(bs)] 80 | qav_video_start = [batch[i]["video_start"]['qav'] for i in range(bs)] 81 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_video_start} 82 | # q_index = [batch[i]["q_index"] for i in range(bs)] 83 | 84 | vqa_video_index = torch.stack([batch[i]["video_index"]['vqa'] for i in range(bs)]) 85 | vaq_video_index = torch.stack([batch[i]["video_index"]['vaq'] for i in range(bs)]) 86 | qav_video_index = torch.stack([batch[i]["video_index"]['qav'] for i in range(bs)]) 87 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index} 88 | 89 | vqa_label_mask = torch.stack([batch[i]["label_mask"]['vqa'] for i in range(bs)]) 90 | vaq_label_mask = torch.stack([batch[i]["label_mask"]['vaq'] for i in range(bs)]) 91 | qav_label_mask = torch.stack([batch[i]["label_mask"]['qav'] for i in range(bs)]) 92 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask} 93 | 94 | answer = torch.tensor([batch[i]["answer"] for i in range(bs)]) 95 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 96 | "video_index": video_index, "label_mask": label_mask, "qid": qid, "answer": answer, "qtype": qtype,"q_index":q_index} 97 | -------------------------------------------------------------------------------- /dataloader/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import copy 4 | 5 | class BaseDataset(Dataset): 6 | def __init__(self, args, tokenizer, split): 7 | self.args = args 8 | self.max_feats = args.max_feats 9 | self.tokenizer = tokenizer 10 | self.max_seq_len = args.max_seq_len 11 | self.split = split 12 | 13 | self.features_dim = 768 14 | 15 | def _get_padding_id(self, text_id): 16 | padding_text_id = torch.zeros((len(text_id), self.max_seq_len), dtype=torch.int64) - 1 17 | 18 | for i, tid in enumerate(text_id): 19 | # print(len(tid)) 20 | padding = self.max_seq_len - len(tid) 21 | # print(padding) 22 | if padding >= 0: 23 | padding_text_id[i, :len(tid)] = tid 24 | else: 25 | padding_text_id[i] = tid[:self.max_seq_len] 26 | # raise Exception 27 | # print('max sequence length overflow') 28 | return padding_text_id 29 | 30 | def _get_text_token(self, text, answer): 31 | vqa_id, vqa_prefix_index, vqa_video_start = self.tokenizer.encode_vqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 32 | vaq_id, vaq_prefix_index, vaq_video_start = self.tokenizer.encode_vaq(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 33 | qav_id, qav_prefix_index = self.tokenizer.encode_qav(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 34 | 35 | vqa_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vqa_id] 36 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id] 37 | qav_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in qav_id] 38 | 39 | vqa_padding_text_id = self._get_padding_id(vqa_id) 40 | vaq_padding_text_id = self._get_padding_id(vaq_id) 41 | qav_padding_text_id = self._get_padding_id(qav_id) 42 | 43 | # label 44 | vqa_label = copy.deepcopy(vqa_padding_text_id) 45 | vqa_label[:, :vqa_prefix_index] = -1 46 | vqa_label_mask = vqa_label.ge(0) 47 | vqa_label[~vqa_label_mask] = 0 48 | vqa_label_mask = vqa_label_mask.float() 49 | 50 | vaq_label = copy.deepcopy(vaq_padding_text_id) 51 | vaq_label[:, :vaq_prefix_index] = -1 52 | vaq_label_mask = vaq_label.ge(0) 53 | vaq_label[~vaq_label_mask] = 0 54 | vaq_label_mask = vaq_label_mask.float() 55 | 56 | qav_label = torch.ones_like(qav_padding_text_id) * -1 57 | qav_label[:, qav_prefix_index:qav_prefix_index+self.max_feats] = torch.arange(self.max_feats) 58 | qav_label_mask = torch.zeros_like(qav_padding_text_id) 59 | qav_label_mask[:, qav_prefix_index] = 1 60 | qav_label_mask = qav_label_mask.float() 61 | 62 | # text mask 63 | vqa_text_mask = vqa_padding_text_id.ge(0) 64 | vqa_padding_text_id[~vqa_text_mask] = 0 65 | vaq_text_mask = vaq_padding_text_id.ge(0) 66 | vaq_padding_text_id[~vaq_text_mask] = 0 67 | qav_text_mask = qav_padding_text_id.ge(0) 68 | qav_padding_text_id[~qav_text_mask] = 0 69 | 70 | # video index 71 | vqa_video_index = torch.arange(vqa_prefix_index, vqa_prefix_index + self.max_feats) 72 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats) 73 | qav_video_index = torch.arange(qav_prefix_index, qav_prefix_index + self.max_feats) 74 | 75 | 76 | text_id = {'vqa': vqa_padding_text_id, 'vaq': vaq_padding_text_id, 'qav': qav_padding_text_id} 77 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label} 78 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_prefix_index} 79 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index} 80 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask} 81 | return text_id, label, video_start, video_index, label_mask 82 | 83 | 84 | def _get_caption_token(self, text, answer): 85 | 86 | vaq_id, vaq_prefix_index, vaq_video_start = self.tokenizer.encode_videocap(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 87 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id] 88 | 89 | vaq_padding_text_id = self._get_padding_id(vaq_id) 90 | 91 | # label 92 | 93 | 94 | vaq_label = copy.deepcopy(vaq_padding_text_id) 95 | vaq_label[:, :vaq_prefix_index] = -1 96 | vaq_label_mask = vaq_label.ge(0) 97 | vaq_label[~vaq_label_mask] = 0 98 | vaq_label_mask = vaq_label_mask.float() 99 | 100 | # text mask 101 | 102 | vaq_text_mask = vaq_padding_text_id.ge(0) 103 | vaq_padding_text_id[~vaq_text_mask] = 0 104 | 105 | # video index 106 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats) 107 | 108 | 109 | text_id = {'vaq': vaq_padding_text_id} 110 | label = {'vaq': vaq_label} 111 | video_start = {'vaq': vaq_video_start} 112 | video_index = {'vaq': vaq_video_index} 113 | label_mask = {'vaq': vaq_label_mask} 114 | return text_id, label, video_start, video_index, label_mask 115 | 116 | def _get_openvqa_token(self, text, answer): 117 | 118 | vaq_id, vaq_prefix_index, vaq_video_start = self.tokenizer.encode_openvqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 119 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id] 120 | 121 | vaq_padding_text_id = self._get_padding_id(vaq_id) 122 | 123 | vaq_label = copy.deepcopy(vaq_padding_text_id) 124 | if type(vaq_prefix_index)==list: 125 | for i,index_ in enumerate(vaq_prefix_index): 126 | vaq_label[i, :index_] = -1 127 | else: 128 | vaq_label[:, :vaq_prefix_index] = -1 129 | vaq_label_mask = vaq_label.ge(0) 130 | vaq_label[~vaq_label_mask] = 0 131 | vaq_label_mask = vaq_label_mask.float() 132 | 133 | # text mask 134 | 135 | vaq_text_mask = vaq_padding_text_id.ge(0) 136 | vaq_padding_text_id[~vaq_text_mask] = 0 137 | 138 | # video index 139 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats) 140 | 141 | 142 | text_id = {'vaq': vaq_padding_text_id} 143 | label = {'vaq': vaq_label} 144 | video_start = {'vaq': vaq_video_start} 145 | video_index = {'vaq': vaq_video_index} 146 | label_mask = {'vaq': vaq_label_mask} 147 | return text_id, label, video_start, video_index, label_mask -------------------------------------------------------------------------------- /dataloader/dramaqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import json 4 | 5 | class DramaQA(BaseDataset): 6 | def __init__(self, args=None, tokenizer=None, split='train'): 7 | super().__init__(args, tokenizer, split) 8 | self.data = json.load(open(f'./data/dramaqa/AnotherMissOhQA_{split}_set.json', "r")) 9 | self.features = torch.load(f'./data/dramaqa/clipvitl14.pth') 10 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'} 11 | self.num_options = 5 12 | print(f"Num {split} data: {len(self.data)}") 13 | 14 | def _get_text(self, idx): 15 | question = self.data[idx]["que"].capitalize().strip() 16 | if question[-1] != "?": 17 | question = str(question) + "?" 18 | 19 | options = self.data[idx]['answers'] 20 | 21 | q_text = f"Question: {question}\n" 22 | o_text = "Choices: \n" 23 | for i in range(self.num_options): 24 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 25 | a_text = "Answer: The answer is " 26 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options} 27 | return text 28 | 29 | def _get_video(self, video_id , idx): 30 | 31 | scene = True 32 | # Scene 33 | if video_id[-4:] == '0000': 34 | shots = self.data[idx]['shot_contained'] 35 | start, end = shots[0], shots[1] 36 | 37 | for i in range(start, end+1): 38 | v_name = video_id[:-4] + f'{i:04}' 39 | 40 | if v_name not in self.features.keys(): 41 | print(v_name, " Not in features") 42 | nxt_vid = torch.zeros(1, self.features_dim) 43 | else: nxt_vid = self.features[v_name].float() 44 | 45 | if i == start: video = nxt_vid 46 | else: video = torch.concat((video, nxt_vid), dim = 0) 47 | # Shot 48 | else: 49 | scene = False 50 | if video_id not in self.features.keys(): 51 | print(video_id, "Not in freatures") 52 | video = torch.zeros(1, self.features_dim) 53 | else: 54 | video = self.features[video_id].float() 55 | 56 | if len(video) > self.max_feats: 57 | sampled = [] 58 | for j in range(self.max_feats): 59 | sampled.append(video[(j * len(video)) // self.max_feats]) 60 | video = torch.stack(sampled) 61 | video_len = self.max_feats 62 | elif len(video) < self.max_feats: 63 | video_len = len(video) 64 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0) 65 | else: 66 | video_len = self.max_feats 67 | 68 | return video, video_len, scene 69 | 70 | def __getitem__(self, idx): 71 | vid = self.data[idx]['vid'] 72 | qtype = -1 73 | answer = self.data[idx]['correct_idx'] 74 | text = self._get_text(idx) 75 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 76 | video, video_len, scene = self._get_video(f'{vid}', idx) 77 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 78 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype} 79 | 80 | def __len__(self): 81 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/egoschema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import pandas as pd 4 | import pickle 5 | import numpy as np 6 | import numpy 7 | import random 8 | 9 | class EgoSchema(BaseDataset): 10 | def __init__(self, args=None, tokenizer=None, split='train'): 11 | super().__init__(args, tokenizer, split) 12 | self.data = pd.read_csv(f'./data/egos/{split}.csv') 13 | 14 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'} 15 | self.num_options = 5 16 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 17 | 18 | self.features_path = './data/egos/features/' 19 | 20 | print(f"Num {split} data: {len(self.data)}") 21 | 22 | def _get_text(self, idx): 23 | question = self.data["question"].values[idx].capitalize().strip() 24 | if question[-1] != "?" and question[-1] != ".": 25 | question = str(question) + "?" 26 | 27 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)] 28 | 29 | q_text = f"Question: {question}\n" 30 | o_text = "Choices: \n" 31 | for i in range(self.num_options): 32 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 33 | 34 | a_text = "Answer: The correct choice is " 35 | open_options = [f"Answer: {option}" for option in options] 36 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options,'open_options':open_options} 37 | return text 38 | 39 | 40 | def _get_video(self, video): 41 | video=torch.from_numpy(video).float() 42 | video = video/video.norm(dim=-1,keepdim=True) 43 | 44 | if len(video) > self.max_feats: 45 | sampled = [] 46 | for j in range(self.max_feats): 47 | sampled.append(video[(j * len(video)) // self.max_feats]) 48 | video = torch.stack(sampled) 49 | video_len = self.max_feats 50 | elif len(video) < self.max_feats: 51 | video_len = len(video) 52 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0) 53 | else: 54 | video_len = self.max_feats 55 | if self.args.single_frame: 56 | video = video[::2] 57 | video=video.repeat_interleave(2,dim=0) 58 | return video, video_len 59 | 60 | def __getitem__(self, idx): 61 | while True: 62 | try: 63 | vid = self.data['uid'].values[idx] 64 | 65 | qtype = 1 66 | answer = self.data['answer'].values[idx] 67 | 68 | text = self._get_text(idx) 69 | 70 | 71 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 72 | q_index = len(self.tokenizer.sp_model.encode(text['q_text']))+1+video_start['vqa']+self.max_feats 73 | if self.args.openvqa_eval: 74 | text['oa_text'] = f"_" 75 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer) 76 | text_id.update(text_id_c) 77 | label.update(label_c) 78 | video_start.update(video_start_c) 79 | video_index.update(video_index_c) 80 | label_mask.update(label_mask_c) 81 | 82 | v_path = f'{self.features_path}{vid}.npy' 83 | with open(v_path,'rb') as f: 84 | video = numpy.load(f) 85 | video, video_len = self._get_video(video) 86 | 87 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, "video_index": video_index, "label_mask": label_mask, "qid": vid, "answer": answer, "qtype": qtype,"q_index": q_index} 88 | except: 89 | print(idx) 90 | idx = np.random.randint(0, len(self)-1) 91 | 92 | def __len__(self): 93 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/how2qa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import pandas as pd 4 | import pickle 5 | import numpy as np 6 | import numpy 7 | class How2qa(BaseDataset): 8 | def __init__(self, args=None, tokenizer=None, split='train'): 9 | super().__init__(args, tokenizer, split) 10 | self.data = pd.read_csv(f'./data/how2qa/{split}.csv') 11 | 12 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'} 13 | self.num_options = 4 14 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 15 | 16 | self.features_path = '../../EgoSchema/benchmarking/FrozenBilm/features_how2qa/' 17 | 18 | print(f"Num {split} data: {len(self.data)}") 19 | 20 | def _get_text(self, idx): 21 | question = self.data["question"].values[idx].capitalize().strip() 22 | if question[-1] != "?": 23 | question = str(question) + "?" 24 | 25 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)] 26 | 27 | q_text = f"Question: {question}\n" 28 | o_text = "Choices: \n" 29 | for i in range(self.num_options): 30 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 31 | 32 | a_text = "Answer: The correct choice is " 33 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options} 34 | return text 35 | 36 | def _get_video(self, video): 37 | video=torch.from_numpy(video).float() 38 | video = video/video.norm(dim=-1,keepdim=True) 39 | # video = torch.zeros(1, self.features_dim) 40 | if len(video) > self.max_feats: 41 | sampled = [] 42 | for j in range(self.max_feats): 43 | sampled.append(video[(j * len(video)) // self.max_feats]) 44 | video = torch.stack(sampled) 45 | video_len = self.max_feats 46 | elif len(video) < self.max_feats: 47 | video_len = len(video) 48 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0) 49 | else: 50 | video_len = self.max_feats 51 | 52 | return video, video_len 53 | 54 | def __getitem__(self, idx): 55 | while True: 56 | try: 57 | # if True: 58 | vid = self.data['uid'].values[idx] 59 | video_id = self.data['video_id'].values[idx] 60 | qtype = 1 61 | answer = self.data['answer'].values[idx] 62 | text = self._get_text(idx) 63 | # print(text) 64 | # print(answer) 65 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 66 | 67 | v_path = f'{self.features_path}{video_id}_{vid}.npy' 68 | # v_path = f'{self.features_path}507441ee-3eb4-4dc6-bac2-26bec2b66380.npy' 69 | with open(v_path,'rb') as f: 70 | video = numpy.load(f) 71 | video, video_len = self._get_video(video) 72 | 73 | # print(label_mask) 74 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 75 | "video_index": video_index, "label_mask": label_mask, "qid": vid, "answer": answer, "qtype": qtype} 76 | except: 77 | print(idx) 78 | idx = np.random.randint(0, len(self)-1) 79 | 80 | def __len__(self): 81 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/nextqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import pandas as pd 4 | import pickle 5 | import json 6 | import numpy 7 | 8 | class NextQA(BaseDataset): 9 | def __init__(self, args=None, tokenizer=None, split='train'): 10 | super().__init__(args, tokenizer, split) 11 | self.split =split 12 | self.data = pd.read_csv(f'./data/nextqa/{split}.csv') 13 | 14 | 15 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'} 16 | self.num_options = 5 17 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 18 | 19 | 20 | self.features = torch.load(f'./data/{args.dataset}/clipvitl14.pth') 21 | print(f"Num {split} data: {len(self.data)}") 22 | if self.split=='train': 23 | self.train_ratio = int(len(self.data)*self.args.data_ratio) 24 | else: 25 | self.train_ratio = int(len(self.data)*1) 26 | 27 | def _get_text(self, idx): 28 | question = self.data["question"].values[idx].capitalize().strip() 29 | if question[-1] != "?": 30 | question = str(question) + "?" 31 | 32 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)] 33 | 34 | q_text = f"Question: {question}\n" 35 | o_text = "Choices: \n" 36 | for i in range(self.num_options): 37 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 38 | 39 | a_text = "Answer: The correct choice is " 40 | open_options = [f"\nAnswer: {option}" for option in options] 41 | 42 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options,'open_options':open_options} 43 | return text 44 | 45 | 46 | def _get_video(self, video): 47 | video = video/video.norm(dim=-1,keepdim=True) 48 | # video = torch.zeros(1, self.features_dim) 49 | # video = video.repeat(,1) 50 | 51 | if len(video) > self.max_feats: 52 | sampled = [] 53 | for j in range(self.max_feats): 54 | sampled.append(video[(j * len(video)) // self.max_feats]) 55 | video = torch.stack(sampled) 56 | video_len = self.max_feats 57 | elif len(video) < self.max_feats: 58 | video_len = len(video) 59 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0) 60 | else: 61 | video_len = self.max_feats 62 | if self.args.single_frame: 63 | video = video[::2] 64 | video=video.repeat_interleave(2,dim=0) 65 | return video, video_len 66 | 67 | def __getitem__(self, idx): 68 | idx = idx%self.train_ratio 69 | vid = self.data['video'].values[idx] 70 | qid = self.data['qid'].values[idx] 71 | # print(vid) 72 | qtype = self.qtype_mapping[self.data['type'].values[idx]] 73 | answer = self.data['answer'].values[idx] 74 | text = self._get_text(idx) 75 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 76 | 77 | 78 | video = self.features[f'{vid}'].float() 79 | video, video_len = self._get_video(video) 80 | 81 | if self.args.openvqa_eval: 82 | text['oa_text'] = f"_" 83 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer) 84 | text_id.update(text_id_c) 85 | label.update(label_c) 86 | video_start.update(video_start_c) 87 | video_index.update(video_index_c) 88 | label_mask.update(label_mask_c) 89 | 90 | # print(label_mask) 91 | return {"vid": str(vid), "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 92 | "video_index": video_index, "label_mask": label_mask, "qid": qid, "answer": answer, "qtype": qtype} 93 | 94 | def __len__(self): 95 | if self.args.debug: 96 | return len(self.data[:2000]) 97 | return len(self.data[:]) 98 | -------------------------------------------------------------------------------- /dataloader/perception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import pandas as pd 4 | import pickle 5 | import numpy as np 6 | import numpy 7 | 8 | class PerceptionTest(BaseDataset): 9 | def __init__(self, args=None, tokenizer=None, split='train'): 10 | super().__init__(args, tokenizer, split) 11 | self.data = pd.read_csv(f'./data/perception/{split}.csv') 12 | 13 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)'} 14 | self.num_options = 3 15 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 16 | self.features_path = '../../EgoSchema/benchmarking/FrozenBilm/features_perception_test_2fps/' 17 | 18 | 19 | print(f"Num {split} data: {len(self.data)}") 20 | 21 | def _get_text(self, idx): 22 | question = self.data["question"].values[idx].capitalize().strip() 23 | if question[-1] != "?": 24 | question = str(question) + "?" 25 | 26 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)] 27 | 28 | q_text = f"Question: {question}\n" 29 | o_text = "Choices: \n" 30 | for i in range(self.num_options): 31 | 32 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 33 | 34 | a_text = "Answer: The correct choice is " 35 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options} 36 | return text 37 | 38 | def _get_video(self, video): 39 | video=torch.from_numpy(video).float() 40 | 41 | 42 | video = video/video.norm(dim=-1,keepdim=True) 43 | 44 | if len(video) > self.max_feats: 45 | sampled = [] 46 | for j in range(self.max_feats): 47 | sampled.append(video[(j * len(video)) // self.max_feats]) 48 | video = torch.stack(sampled) 49 | video_len = self.max_feats 50 | elif len(video) < self.max_feats: 51 | video_len = len(video) 52 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0) 53 | else: 54 | video_len = self.max_feats 55 | 56 | return video, video_len 57 | 58 | def __getitem__(self, idx): 59 | while True: 60 | try: 61 | vid = self.data['uid'].values[idx] 62 | # print(vid) 63 | qtype = 1 64 | answer = self.data['answer'].values[idx] 65 | text = self._get_text(idx) 66 | # print(text) 67 | # print(answer) 68 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 69 | 70 | v_path = f'{self.features_path}{vid}.npy' 71 | 72 | with open(v_path,'rb') as f: 73 | video = numpy.load(f) 74 | video, video_len = self._get_video(video) 75 | 76 | # print(label_mask) 77 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 78 | "video_index": video_index, "label_mask": label_mask, "qid": vid, "answer": answer, "qtype": qtype} 79 | except: 80 | print(idx) 81 | idx = np.random.randint(0, len(self)-1) 82 | 83 | def __len__(self): 84 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/star.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import json 4 | 5 | class STAR(BaseDataset): 6 | def __init__(self, args=None, tokenizer=None, split='train'): 7 | super().__init__(args, tokenizer, split) 8 | self.data = json.load(open(f'./data/star/STAR_{split}.json', 'r')) 9 | self.features = torch.load(f'./data/star/clipvitl14.pth') 10 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)'} 11 | self.qtype_mapping = {'Interaction': 1, 'Sequence': 2, 'Prediction': 3, 'Feasibility': 4} 12 | self.num_options = 4 13 | self.split = split 14 | print(f"Num {split} data: {len(self.data)}") 15 | if split=='train': 16 | self.train_ratio = int(len(self.data)*self.args.data_ratio) 17 | else: 18 | self.train_ratio = int(len(self.data)) 19 | 20 | def _get_text(self, idx): 21 | question = self.data[idx]["question"].capitalize().strip() 22 | if question[-1] != "?": 23 | question = str(question) + "?" 24 | 25 | options = {x['choice_id']: x['choice'] for x in self.data[idx]['choices']} 26 | options = [options[i] for i in range(self.num_options)] 27 | answer = options.index(self.data[idx]['answer']) 28 | 29 | q_text = f"Question: {question}\n" 30 | o_text = "Choices: \n" 31 | for i in range(self.num_options): 32 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 33 | a_text = "Answer: The answer is " 34 | open_options = [f"Answer: {option}" for option in options] 35 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options,'open_options':open_options} 36 | return text, answer 37 | 38 | def _get_video(self, video_id, start, end): 39 | if video_id not in self.features: 40 | print(video_id) 41 | video = torch.zeros(1, self.features_dim) 42 | else: 43 | video = self.features[video_id][start: end +1, :].float() # ts 44 | video = video/video.norm(dim=-1,keepdim=True) 45 | if len(video) > self.max_feats: 46 | sampled = [] 47 | for j in range(self.max_feats): 48 | sampled.append(video[(j * len(video)) // self.max_feats]) 49 | video = torch.stack(sampled) 50 | video_len = self.max_feats 51 | elif len(video) < self.max_feats: 52 | video_len = len(video) 53 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0) 54 | else: 55 | video_len = self.max_feats 56 | 57 | return video, video_len 58 | 59 | def __getitem__(self, idx): 60 | idx = idx%self.train_ratio 61 | vid = self.data[idx]['video_id'] 62 | qtype = self.qtype_mapping[self.data[idx]['question_id'].split('_')[0]] 63 | text, answer = self._get_text(idx) 64 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 65 | start, end = round(self.data[idx]['start']), round(self.data[idx]['end']) 66 | video, video_len = self._get_video(f'{vid}', start, end) 67 | 68 | if self.args.openvqa_eval: 69 | text['oa_text'] = f"_" 70 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer) 71 | text_id.update(text_id_c) 72 | label.update(label_c) 73 | video_start.update(video_start_c) 74 | video_index.update(video_index_c) 75 | label_mask.update(label_mask_c) 76 | 77 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 78 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype} 79 | 80 | 81 | def __len__(self): 82 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/textvid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import pandas as pd 4 | import json 5 | import random 6 | import pickle 7 | import numpy as np 8 | import math 9 | import pickle 10 | def noise_injection(x, variance=0.001, modality_offset=None, uniform_noise=False, dont_norm=False): 11 | if variance == 0.0: 12 | return x 13 | std = math.sqrt(variance) 14 | if not dont_norm: 15 | x = torch.nn.functional.normalize(x, dim=1) 16 | 17 | x = x + (torch.randn(x.shape, device=x.device) * std) # todo by some conventions multivraiance noise should be devided by sqrt of dim 18 | noise = (torch.randn(x.shape, device=x.device) * std) 19 | noise/=noise.norm(dim=-1,keepdim=True) 20 | x = x+noise*std 21 | if modality_offset is not None: 22 | x = x + modality_offset 23 | return torch.nn.functional.normalize(x, dim=-1) 24 | 25 | class TextVid(BaseDataset): 26 | def __init__(self, args=None, tokenizer=None, split='train'): 27 | super().__init__(args, tokenizer, split) 28 | 29 | with open('./data/textvid/textvid.json','r') as f: 30 | self.data = json.load(f) 31 | self.feature_path = './data/textvid/features' 32 | # with open('./data/textvid/feature_small.pkl','rb') as f: 33 | # self.feature = pickle.load(f) 34 | self.letter2number = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4} 35 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'} 36 | self.num_options = 5 37 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 38 | print(f"Num {split} data: {len(self.data)}") 39 | 40 | def _get_text(self, question,options): 41 | question = question.capitalize().strip() 42 | if question[-1] != "?": 43 | question = str(question) + "?" 44 | 45 | # options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)] 46 | 47 | q_text = f"Question: {question}\n" 48 | o_text = "Choices: \n" 49 | for i in range(len(options)): 50 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 51 | 52 | a_text = "Answer: The correct choice is " 53 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options} 54 | return text 55 | 56 | def _get_video(self, video): 57 | video=video.float() 58 | if len(video) > self.max_feats: 59 | sampled = [] 60 | for j in range(self.max_feats): 61 | sampled.append(video[(j * len(video)) // self.max_feats]) 62 | video = torch.stack(sampled) 63 | video_len = self.max_feats 64 | elif len(video) < self.max_feats: 65 | video_len = len(video) 66 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0) 67 | 68 | else: 69 | video_len = self.max_feats 70 | 71 | return video, video_len 72 | 73 | def _get_frame_feature(self, features,frames): 74 | start = 0 75 | video = [] 76 | features = features/features.norm(dim=-1,keepdim=True) 77 | 78 | for frame in frames: 79 | feature = features[start:start+len(list(frame.keys()))] 80 | feature = feature.mean(dim=0,keepdim=True) 81 | feature = feature/feature.norm(dim=-1,keepdim=True) 82 | video.append(feature) 83 | start+=len(list(frame.keys())) 84 | video = torch.cat(video,dim=0) 85 | video = noise_injection(video, variance=self.args.variance) 86 | return video 87 | 88 | 89 | def __getitem__(self, idx): 90 | while True: 91 | try: 92 | 93 | video_meta = self.data[idx] 94 | 95 | vid = video_meta['idx'] 96 | with open(f'{self.feature_path}/{vid}.pkl','rb') as f: 97 | features = pickle.load(f) 98 | # features = self.feature[vid] 99 | video = self._get_frame_feature(features,video_meta['frames']) 100 | 101 | qtype = 1 102 | qa = random.choice(video_meta['QAs']) 103 | question = qa['question'] 104 | 105 | # if question.strip().lower().split(' ')[0] =='what': 106 | # if random.random()>0.5: 107 | # qa = random.choice(video_meta['QAs']) 108 | # question = qa['question'] 109 | # if question.strip().lower().split(' ')[0] =='what': 110 | # raise 111 | 112 | answer = qa['answer'] 113 | answer = self.letter2number[answer] 114 | options = qa['options'] 115 | answer_text = options[answer] 116 | if self.args.answer_balance: 117 | if not (("both" in answer_text.lower()) or ("all" in answer_text.lower()) or ('none' in answer_text.lower()) or ('and' in answer_text.lower())): 118 | random.shuffle(options) 119 | answer = options.index(answer_text) 120 | 121 | text = self._get_text(question,options) 122 | 123 | 124 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 125 | video, video_len = self._get_video(video) 126 | 127 | if self.args.video_caption: 128 | caption = video_meta['global_video_caption'] 129 | text['c_text'] = f"Description: {caption}\n" 130 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_caption_token(text, answer) 131 | text_id.update(text_id_c) 132 | label.update(label_c) 133 | video_start.update(video_start_c) 134 | video_index.update(video_index_c) 135 | label_mask.update(label_mask_c) 136 | 137 | if self.args.openvqa and (random.random()>0.5): 138 | if 'answer_open_ended' in qa.keys(): 139 | if ('both' not in answer_text.lower()) and ('all' not in answer_text.lower()): 140 | answer_open_ended = qa['answer_open_ended'].strip() 141 | text['oa_text'] = f"Answer: {answer_open_ended}\n" 142 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer) 143 | text_id.update(text_id_c) 144 | label.update(label_c) 145 | video_start.update(video_start_c) 146 | video_index.update(video_index_c) 147 | label_mask.update(label_mask_c) 148 | 149 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 150 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype} 151 | 152 | except: 153 | idx = np.random.randint(0, len(self)-1) 154 | # print(f'Error reading {idx}') 155 | 156 | def __len__(self): 157 | if self.args.debug: 158 | return len(self.data[:10000]) 159 | else: 160 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/tvqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import json 4 | import copy 5 | import pysrt 6 | import numpy as np 7 | class TVQA(BaseDataset): 8 | def __init__(self, args=None, tokenizer=None, split='train'): 9 | super().__init__(args, tokenizer, split) 10 | self.split=split 11 | json_path = f'./data/tvqa/tvqa_{split}.jsonl' 12 | feature_path = f'./data/tvqa/clipvitl14.pth' 13 | 14 | with open(json_path, "r") as f: 15 | data_list = list(f) 16 | self.data = [json.loads(x) for x in data_list] 17 | self.features = torch.load(feature_path) 18 | self.subtitle_path = f'./data/tvqa/tvqa_subtitles/' # provided as castle_s01e01_seg02_clip_00.srt 19 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3 : '(D)', 4: '(E)'} 20 | self.num_options = 5 21 | self.sub = args.sub 22 | print(f"Num {split} data: {len(self.data)}") 23 | if split=='train': 24 | self.train_ratio = int(len(self.data)*self.args.data_ratio) 25 | else: 26 | self.train_ratio = int(len(self.data)) 27 | 28 | def _get_text(self, idx, choices, vid, start, end): 29 | question = self.data[idx]["q"].capitalize().strip() 30 | if question[-1] != "?": 31 | question = str(question) + "?" 32 | 33 | if self.sub: 34 | dialogue = '' 35 | 36 | for t in pysrt.open(self.subtitle_path+f'{vid}'+'.srt'): 37 | txt = t.text.replace('\n', ' ') 38 | st = t.start.minutes * 60 + t.start.seconds 39 | et = t.end.minutes * 60 + t.end.seconds 40 | if (st >= start and et <= end) or (st <= start and et <= end and start <= et): 41 | dialogue += ' ' + txt 42 | 43 | if dialogue != '': d_text = f"Dialogue: {dialogue}\n" 44 | else: d_text = '' 45 | 46 | else: 47 | d_text = "" 48 | 49 | q_text = f"Question: {question}\n" 50 | o_text = f"Choices: \n" 51 | 52 | assert len(choices) == self.num_options, "Double check number of choices" 53 | for i, option in enumerate(choices): 54 | o_text += f"{self.answer_mapping[i]} {option}\n" 55 | 56 | a_text = f"Answer: The correct choice is " 57 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'd_text': d_text} 58 | return text 59 | 60 | def _get_video(self, video_id, start, end): 61 | if video_id not in self.features: 62 | print(video_id) 63 | video = torch.zeros(1, self.features_dim) 64 | else: 65 | video = self.features[video_id][start * 3: (end + 1) * 3, :].float() # 3fps 66 | video = video/video.norm(dim=-1,keepdim=True) 67 | if len(video) > self.max_feats: 68 | sampled = [] 69 | for j in range(self.max_feats): 70 | sampled.append(video[(j * len(video)) // self.max_feats]) 71 | video = torch.stack(sampled) 72 | video_len = self.max_feats 73 | elif len(video) < self.max_feats: 74 | video_len = len(video) 75 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0) 76 | else: 77 | video_len = self.max_feats 78 | 79 | return video, video_len 80 | 81 | def _get_padding_id(self, text_id, prefix_index, prefix_i, prefix_main, type): 82 | padding_text_id = torch.zeros((len(text_id), self.max_seq_len), dtype=torch.int64) - 1 83 | 84 | prefix = prefix_index 85 | for i, tid in enumerate(text_id): 86 | padding = self.max_seq_len - len(tid) 87 | # print(padding) 88 | if padding >= 0: 89 | padding_text_id[i, :len(tid)] = tid 90 | prefix = prefix_index 91 | else: 92 | if self.sub and prefix_i != prefix_main: 93 | pad = self.max_seq_len - ((prefix_i) + (len(tid) - prefix_main)) 94 | padding_text_id[i, :prefix_i] = tid[:prefix_i] 95 | padding_text_id[i, prefix_i: prefix_i + pad] = tid[prefix_i: prefix_i + pad] 96 | padding_text_id[i, prefix_i + pad :] = tid[prefix_main:] 97 | 98 | if type == "vqa": 99 | prefix = len(padding_text_id[i]) - 4 100 | elif type == "vaq": 101 | if self.split == "train": 102 | try: 103 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1].item() + 2 104 | except: 105 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1][0].item() + 2 106 | else: 107 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1][0].item() + 2 108 | else: 109 | prefix = len(padding_text_id[i]) - self.max_feats - 1 110 | else: 111 | padding_text_id[i] = tid[:self.max_seq_len] 112 | prefix = prefix_index 113 | # print('max sequence length overflow') 114 | 115 | return padding_text_id, prefix 116 | 117 | def _get_text_token(self, text, answer): 118 | vqa_id, vqa_prefix_index, vqa_video_start, vqa_prefix_i, vqa_prefix_q = self.tokenizer.encode_dvqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 119 | vaq_id, vaq_prefix_index, vaq_video_start, vaq_prefix_i, vaq_prefix_q = self.tokenizer.encode_dvaq(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 120 | qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q = self.tokenizer.encode_dqav(text=text, max_feats=self.max_feats, max_seq_len=self.max_seq_len, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 121 | 122 | vqa_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vqa_id] 123 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id] 124 | qav_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in qav_id] 125 | 126 | vqa_padding_text_id, vqa_prefix_index = self._get_padding_id(vqa_id, vqa_prefix_index, vqa_prefix_i, vqa_prefix_q, "vqa") 127 | vaq_padding_text_id, vaq_prefix_index = self._get_padding_id(vaq_id, vaq_prefix_index, vaq_prefix_i, vaq_prefix_q, "vaq") 128 | qav_padding_text_id, qav_prefix_index = self._get_padding_id(qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q, "qav") 129 | 130 | # label 131 | vqa_label = copy.deepcopy(vqa_padding_text_id) 132 | vqa_label[:, :vqa_prefix_index] = -1 133 | vqa_label_mask = vqa_label.ge(0) 134 | vqa_label[~vqa_label_mask] = 0 135 | vqa_label_mask = vqa_label_mask.float() 136 | 137 | vaq_label = copy.deepcopy(vaq_padding_text_id) 138 | vaq_label[:, :vaq_prefix_index] = -1 139 | vaq_label_mask = vaq_label.ge(0) 140 | vaq_label[~vaq_label_mask] = 0 141 | vaq_label_mask = vaq_label_mask.float() 142 | 143 | qav_label = torch.ones_like(qav_padding_text_id) * -1 144 | qav_label[:, qav_prefix_index:qav_prefix_index+self.max_feats] = torch.arange(self.max_feats) 145 | qav_label_mask = torch.zeros_like(qav_padding_text_id) 146 | qav_label_mask[:, qav_prefix_index] = 1 147 | qav_label_mask = qav_label_mask.float() 148 | 149 | # text mask 150 | vqa_text_mask = vqa_padding_text_id.ge(0) 151 | vqa_padding_text_id[~vqa_text_mask] = 0 152 | vaq_text_mask = vaq_padding_text_id.ge(0) 153 | vaq_padding_text_id[~vaq_text_mask] = 0 154 | qav_text_mask = qav_padding_text_id.ge(0) 155 | qav_padding_text_id[~qav_text_mask] = 0 156 | 157 | # video index 158 | vqa_video_index = torch.arange(vqa_prefix_index, vqa_prefix_index + self.max_feats) 159 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats) 160 | qav_video_index = torch.arange(qav_prefix_index, qav_prefix_index + self.max_feats) 161 | 162 | text_id = {'vqa': vqa_padding_text_id, 'vaq': vaq_padding_text_id, 'qav': qav_padding_text_id} 163 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label} 164 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_prefix_index} 165 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index} 166 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask} 167 | return text_id, label, video_start, video_index, label_mask 168 | 169 | def __getitem__(self, idx): 170 | while True: 171 | try: 172 | idx = idx%self.train_ratio 173 | vid = self.data[idx]['vid_name'] 174 | qtype = -1 175 | choices = [ self.data[idx][f'a{i}'] for i in range(self.num_options)] 176 | answer = self.data[idx]['answer_idx'] 177 | 178 | start, end = map(float, self.data[idx]['ts'].split('-')) 179 | try: 180 | start, end = round(start), round(end) 181 | except: 182 | start, end = -1000, 1000 183 | 184 | video, video_len = self._get_video(f'{vid}', start, end) 185 | text = self._get_text(idx, choices, f'{vid}', start, end) 186 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 187 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 188 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype} 189 | except: 190 | idx = np.random.randint(0, len(self)-1) 191 | 192 | 193 | def __len__(self): 194 | 195 | return len(self.data[:]) 196 | -------------------------------------------------------------------------------- /dataloader/vlep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import json 4 | import copy 5 | import numpy as np 6 | class VLEP(BaseDataset): 7 | def __init__(self, args=None, tokenizer=None, split='train'): 8 | super().__init__(args, tokenizer, split) 9 | if split == 'val': 10 | json_path = f'./data/vlep/vlep_dev_release.jsonl' 11 | else: 12 | json_path = f'./data/vlep/vlep_{split}_release.jsonl' 13 | feature_path = f'./data/vlep/clipvitl14.pth' 14 | sub_path = f'./data/vlep/vlep_subtitles.jsonl' 15 | 16 | with open(json_path, "r") as f: 17 | data_list = list(f) 18 | with open(sub_path, "r") as s: 19 | sub_list = list(s) 20 | self.data = [json.loads(x) for x in data_list] 21 | self.subtitle = [json.loads(x) for x in sub_list] 22 | self.features = torch.load(feature_path) 23 | self.answer_mapping = {0: '(A)', 1: '(B)'} 24 | self.num_options = 2 25 | self.sub = args.sub 26 | print(f"Num {split} data: {len(self.data)}") 27 | 28 | def _get_text(self, choices, vid, start, end): 29 | question = "Which event is more likely to happen right after?".capitalize().strip() 30 | 31 | if self.sub: 32 | text = [x['sub'] for x in self.subtitle if x['vid_name'] == vid][0] 33 | dialogue = '' 34 | for txt in text: 35 | s, e, t = round(int(txt['start'])), int(txt['end']), txt['text'].replace('-', '') 36 | if (s >= start and e <= end) or (s <= start and e <= end and start <= e): 37 | dialogue+= t 38 | d_text = f"Dialogue: {dialogue}\n" # subtitles 39 | else: 40 | d_text = "" 41 | 42 | q_text = f"Question: {question}\n" 43 | o_text = f"Choices: \n" 44 | 45 | assert len(choices) == self.num_options, "Double check number of choices" 46 | for i, option in enumerate(choices): 47 | o_text += f"{self.answer_mapping[i]} {option}\n" 48 | 49 | a_text = f"Answer: The answer is " 50 | 51 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'd_text': d_text} 52 | return text 53 | 54 | def _get_video(self, video_id, start, end): 55 | if video_id not in self.features: 56 | print(video_id) 57 | video = torch.zeros(1, self.features_dim) 58 | else: 59 | video = self.features[video_id][start: end +1, :].float() 60 | video = video/video.norm(dim=-1,keepdim=True) 61 | 62 | if len(video) > self.max_feats: 63 | sampled = [] 64 | for j in range(self.max_feats): 65 | sampled.append(video[(j * len(video)) // self.max_feats]) 66 | video = torch.stack(sampled) 67 | video_len = self.max_feats 68 | elif len(video) < self.max_feats: 69 | video_len = len(video) 70 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0) 71 | else: 72 | video_len = self.max_feats 73 | return video, video_len 74 | 75 | def _get_padding_id(self, text_id, prefix_index, prefix_i, prefix_main, type): 76 | padding_text_id = torch.zeros((len(text_id), self.max_seq_len), dtype=torch.int64) - 1 77 | 78 | prefix = prefix_index 79 | for i, tid in enumerate(text_id): 80 | padding = self.max_seq_len - len(tid) 81 | if padding >= 0: 82 | padding_text_id[i, :len(tid)] = tid 83 | prefix = prefix_index 84 | else: 85 | if self.sub and prefix_i != prefix_main: 86 | pad = self.max_seq_len - ((prefix_i) + (len(tid) - prefix_main)) 87 | padding_text_id[i, :prefix_i] = tid[:prefix_i] 88 | padding_text_id[i, prefix_i: prefix_i + pad] = tid[prefix_i: prefix_i + pad] 89 | padding_text_id[i, prefix_i + pad :] = tid[prefix_main:] 90 | 91 | if type == "vqa": 92 | prefix = len(padding_text_id[i]) - 4 93 | elif type == "vaq": 94 | if self.split == "train": 95 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1].item() + 2 96 | else: 97 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1][0].item() + 2 98 | else: 99 | prefix = len(padding_text_id[i]) - self.max_feats - 1 100 | else: 101 | padding_text_id[i] = tid[:self.max_seq_len] 102 | prefix = prefix_index 103 | return padding_text_id, prefix 104 | 105 | 106 | def _get_text_token(self, text, answer): 107 | vqa_id, vqa_prefix_index, vqa_video_start, vqa_prefix_i, vqa_prefix_q = self.tokenizer.encode_dvqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 108 | vaq_id, vaq_prefix_index, vaq_video_start, vaq_prefix_i, vaq_prefix_q = self.tokenizer.encode_dvaq(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 109 | qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q = self.tokenizer.encode_dqav(text=text, max_feats=self.max_feats, max_seq_len=self.max_seq_len, split=self.split, answer_mapping=self.answer_mapping, answer=answer) 110 | 111 | vqa_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vqa_id] 112 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id] 113 | qav_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in qav_id] 114 | 115 | vqa_padding_text_id, vqa_prefix_index = self._get_padding_id(vqa_id, vqa_prefix_index, vqa_prefix_i, vqa_prefix_q, "vqa") 116 | vaq_padding_text_id, vaq_prefix_index = self._get_padding_id(vaq_id, vaq_prefix_index, vaq_prefix_i, vaq_prefix_q, "vaq") 117 | qav_padding_text_id, qav_prefix_index = self._get_padding_id(qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q, "qav") 118 | 119 | # label 120 | vqa_label = copy.deepcopy(vqa_padding_text_id) 121 | vqa_label[:, :vqa_prefix_index] = -1 122 | vqa_label_mask = vqa_label.ge(0) 123 | vqa_label[~vqa_label_mask] = 0 124 | vqa_label_mask = vqa_label_mask.float() 125 | 126 | vaq_label = copy.deepcopy(vaq_padding_text_id) 127 | vaq_label[:, :vaq_prefix_index] = -1 128 | vaq_label_mask = vaq_label.ge(0) 129 | vaq_label[~vaq_label_mask] = 0 130 | vaq_label_mask = vaq_label_mask.float() 131 | 132 | qav_label = torch.ones_like(qav_padding_text_id) * -1 133 | qav_label[:, qav_prefix_index:qav_prefix_index+self.max_feats] = torch.arange(self.max_feats) 134 | qav_label_mask = torch.zeros_like(qav_padding_text_id) 135 | qav_label_mask[:, qav_prefix_index] = 1 136 | qav_label_mask = qav_label_mask.float() 137 | 138 | # text mask 139 | vqa_text_mask = vqa_padding_text_id.ge(0) 140 | vqa_padding_text_id[~vqa_text_mask] = 0 141 | vaq_text_mask = vaq_padding_text_id.ge(0) 142 | vaq_padding_text_id[~vaq_text_mask] = 0 143 | qav_text_mask = qav_padding_text_id.ge(0) 144 | qav_padding_text_id[~qav_text_mask] = 0 145 | 146 | # video index 147 | vqa_video_index = torch.arange(vqa_prefix_index, vqa_prefix_index + self.max_feats) 148 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats) 149 | qav_video_index = torch.arange(qav_prefix_index, qav_prefix_index + self.max_feats) 150 | 151 | text_id = {'vqa': vqa_padding_text_id, 'vaq': vaq_padding_text_id, 'qav': qav_padding_text_id} 152 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label} 153 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_prefix_index} 154 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index} 155 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask} 156 | return text_id, label, video_start, video_index, label_mask 157 | 158 | 159 | def __getitem__(self, idx): 160 | while True: 161 | try: 162 | vid = self.data[idx]['vid_name'] 163 | qtype = -1 164 | choices = self.data[idx]['events'] 165 | answer = self.data[idx]['answer'] 166 | ts = self.data[idx]['ts'] 167 | start, end = round(ts[0]), round(ts[1]) 168 | video, video_len = self._get_video(f'{vid}', start, end) 169 | text = self._get_text(choices, f'{vid}', start, end) 170 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 171 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 172 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype} 173 | except: 174 | print(idx) 175 | idx = np.random.randint(0, len(self)-1) 176 | 177 | def __len__(self): 178 | return len(self.data) -------------------------------------------------------------------------------- /dataloader/webvid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_dataset import BaseDataset 3 | import pandas as pd 4 | import json 5 | import random 6 | import pickle 7 | import numpy as np 8 | import math 9 | import pickle 10 | def noise_injection(x, variance=0.001, modality_offset=None, uniform_noise=False, dont_norm=False): 11 | if variance == 0.0: 12 | return x 13 | std = math.sqrt(variance) 14 | if not dont_norm: 15 | x = torch.nn.functional.normalize(x, dim=1) 16 | # if uniform_noise: 17 | # x = x + get_uniform_ball_noise(x.shape, radius=std) 18 | # else: 19 | x = x + (torch.randn(x.shape, device=x.device) * std) # todo by some conventions multivraiance noise should be devided by sqrt of dim 20 | noise = (torch.randn(x.shape, device=x.device) * std) 21 | noise/=noise.norm(dim=-1,keepdim=True) 22 | x = x+noise*std 23 | if modality_offset is not None: 24 | x = x + modality_offset 25 | return torch.nn.functional.normalize(x, dim=-1) 26 | 27 | class Webvid(BaseDataset): 28 | def __init__(self, args=None, tokenizer=None, split='train'): 29 | super().__init__(args, tokenizer, split) 30 | 31 | with open('./data/webvid/train.json','r') as f: 32 | self.data = json.load(f) 33 | with open('./data/webvid/features.pkl','rb') as f: 34 | self.feature = pickle.load(f) 35 | self.letter2number = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4} 36 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'} 37 | self.num_options = 5 38 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 39 | print(f"Num {split} data: {len(self.data)}") 40 | 41 | def _get_text(self, question,options): 42 | question = question.capitalize().strip() 43 | if question[-1] != "?": 44 | question = str(question) + "?" 45 | 46 | # options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)] 47 | 48 | q_text = f"Question: {question}\n" 49 | o_text = "Choices: \n" 50 | for i in range(self.num_options): 51 | o_text += f"{self.answer_mapping[i]} {options[i]}\n" 52 | 53 | a_text = "Answer: The correct choice is " 54 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options} 55 | return text 56 | 57 | def _get_video(self, video): 58 | video=video.float() 59 | if len(video) > self.max_feats: 60 | sampled = [] 61 | for j in range(self.max_feats): 62 | sampled.append(video[(j * len(video)) // self.max_feats]) 63 | video = torch.stack(sampled) 64 | video_len = self.max_feats 65 | elif len(video) < self.max_feats: 66 | video_len = len(video) 67 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0) 68 | 69 | # elif len(video) < self.max_feats: 70 | # video_len = len(video) 71 | # filled_tensor = torch.zeros(self.max_feats, self.features_dim) 72 | 73 | # indices = random.sample(range(self.max_feats),video_len) 74 | # indices.sort() 75 | # for i,indice in enumerate(indices): 76 | # filled_tensor[indice] = video[i] 77 | # video_len = self.max_feats 78 | # video = filled_tensor 79 | else: 80 | video_len = self.max_feats 81 | # print(video_len) 82 | return video, video_len 83 | 84 | def _get_frame_feature(self, features,frames): 85 | start = 0 86 | video = [] 87 | features = features/features.norm(dim=-1,keepdim=True) 88 | 89 | for frame in frames: 90 | feature = features[start:start+len(list(frame.keys()))] 91 | feature = feature.mean(dim=0,keepdim=True) 92 | features = features/features.norm(dim=-1,keepdim=True) 93 | video.append(feature) 94 | start+=len(list(frame.keys())) 95 | video = torch.cat(video,dim=0) 96 | video = noise_injection(video, variance=self.args.variance) 97 | return video 98 | 99 | 100 | def __getitem__(self, idx): 101 | # while True: 102 | # try: 103 | 104 | video_meta = self.data[idx] 105 | 106 | vid = video_meta['feature_idx'] 107 | video = self.feature[vid] 108 | # video = self._get_frame_feature(features,video_meta['frames']) 109 | qtype = 1 110 | 111 | video, video_len = self._get_video(video) 112 | 113 | 114 | caption = video_meta['caption'] 115 | question = "What does this video show?" 116 | options=['1','1','1','1','1'] 117 | answer=0 118 | text = self._get_text(question,options) 119 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer) 120 | 121 | text['c_text'] = f"Description: {caption}." 122 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_caption_token(text, answer) 123 | text_id.update(text_id_c) 124 | label.update(label_c) 125 | video_start.update(video_start_c) 126 | video_index.update(video_index_c) 127 | label_mask.update(label_mask_c) 128 | 129 | 130 | 131 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, 132 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype} 133 | 134 | # except: 135 | # idx = np.random.randint(0, len(self)-1) 136 | # print(f'Error reading {idx}') 137 | 138 | def __len__(self): 139 | if self.args.debug: 140 | return len(self.data[:10000]) 141 | else: 142 | return len(self.data) -------------------------------------------------------------------------------- /demos/13B_msrvtt_results.json: -------------------------------------------------------------------------------- 1 | {"Bleu_1": 0.6634161036124164, "Bleu_2": 0.4809019909504734, "Bleu_3": 0.3141303218433441, "Bleu_4": 0.2110267573252259, "METEOR": 0.2179849241732958, "ROUGE_L": 0.4981662383625836, "CIDEr": 0.33443181238316033, "SPICE": 0.05282364436867422} -------------------------------------------------------------------------------- /demos/7B_msrvtt_results.json: -------------------------------------------------------------------------------- 1 | {"Bleu_1": 0.654746090605047, "Bleu_2": 0.4820976690148554, "Bleu_3": 0.3151958750789693, "Bleu_4": 0.2078218649347474, "METEOR": 0.21875795472104748, "ROUGE_L": 0.5030358497611536, "CIDEr": 0.32867167801902625, "SPICE": 0.050667287735789934} -------------------------------------------------------------------------------- /demos/Eval_Cap_MSRVTT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "ebdbbd25-038d-4987-ac84-d5adbdcd40ad", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import argparse\n", 11 | "import json\n", 12 | "import torch\n", 13 | "import json\n", 14 | "import os" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "460fe558-fca6-4906-93f3-facd215d9ab4", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def save_arguments(args, filepath):\n", 25 | " with open(filepath, 'w') as file:\n", 26 | " json.dump(vars(args), file)\n", 27 | "\n", 28 | "def load_arguments(filepath):\n", 29 | " with open(filepath, 'r') as file:\n", 30 | " args_dict = json.load(file)\n", 31 | " return args_dict\n", 32 | "\n", 33 | "# Optionally, repopulate argparse.ArgumentParser with these arguments\n", 34 | "def repopulate_arguments(args_dict):\n", 35 | " parser = argparse.ArgumentParser(description=\"Example script\")\n", 36 | " for key, value in args_dict.items():\n", 37 | " parser.add_argument(f'--{key}', type=type(value),default=value)\n", 38 | " return parser.parse_args([])" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "c9c50017-a3cf-4fa3-8a7f-2e24c501a1a0", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "path = '../vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips'\n", 49 | "\n", 50 | "loaded_args = load_arguments(path+'/args.json')\n", 51 | "\n", 52 | "args = repopulate_arguments(loaded_args)\n", 53 | "args.llama_model_path = '.' +args.llama_model_path\n", 54 | "args.resume='../vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth'" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "074499b9-b268-4a4e-b05f-9e0f6d5a189f", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# path = '../vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips'\n", 65 | "\n", 66 | "# loaded_args = load_arguments(path+'/argsa.json')\n", 67 | "\n", 68 | "# args = repopulate_arguments(loaded_args)\n", 69 | "# args.llama_model_path = '.' +args.llama_model_path\n", 70 | "# args.resume='../vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth'" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "id": "98596d9e-0760-4eef-995f-7e2d0e47d33f", 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stderr", 81 | "output_type": "stream", 82 | "text": [ 83 | "/home/users/nus/idmwyk/scratch/anaconda3/envs/llama/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 84 | " from .autonotebook import tqdm as notebook_tqdm\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "import sys\n", 90 | "sys.path.append('../')\n", 91 | "from llama import Tokenizer\n", 92 | "from llama_vqa import LLaMA_VQA\n", 93 | "from dataloader import load_data" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 6, 99 | "id": "9272202e-d5ed-4645-8411-98463087b6ca", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "checkpoint = torch.load(args.resume, map_location='cpu')\n", 104 | "model.load_state_dict(checkpoint['model'], strict=False)\n", 105 | "tokenizer = Tokenizer(model_path=f'{args.llama_model_path}./tokenizer.model')" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "c2b12971-d47f-4aef-ba77-4d1d163160d8", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "1e125471-d745-4f57-b51a-b311df9de478", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "id": "80ba3180-0ce1-4f7c-817f-79aa02c0ccc5", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "def decoding(model, tokenizer, prompt1,prompt2,video=None):\n", 132 | " adapter = model.adapter_query.weight.reshape(-1, model.adapter_len, model.params.dim).unsqueeze(1)\n", 133 | " freqs= model.freqs_cis.cuda()\n", 134 | " \n", 135 | " tokens = [tokenizer.bos_id] + tokenizer.sp_model.encode(prompt1)\n", 136 | " query = torch.tensor(tokens, dtype=torch.int64).cuda()\n", 137 | " input_embedding = model.tok_embeddings(query)\n", 138 | "\n", 139 | " tokens_2 = tokenizer.sp_model.encode(prompt2)\n", 140 | " query_2 = torch.tensor(tokens_2, dtype=torch.int64).cuda()\n", 141 | " input_embedding_2 = model.tok_embeddings(query_2)\n", 142 | " tokens.extend(tokens_2)\n", 143 | " video = video.cuda().float()\n", 144 | " video/=video.norm(dim=-1,keepdim=True)\n", 145 | " if True:\n", 146 | " sim = video@model.memory.T\n", 147 | "\n", 148 | " sim = (sim*100).softmax(dim=-1)\n", 149 | "\n", 150 | " video = sim@model.memory\n", 151 | " video = video/video.norm(dim=-1,keepdim=True)\n", 152 | " \n", 153 | " video_feature = model.visual_proj(video)\n", 154 | " video_feature = (video_feature + model.temporal_emb.weight[:, :]).type(model.llamatype)\n", 155 | " vqa_video_start=input_embedding.shape[0]\n", 156 | " # print(video_feature.shape)\n", 157 | " input_embedding = torch.cat([input_embedding,video_feature,input_embedding_2])\n", 158 | " start_pos=0\n", 159 | " for j in range(10):\n", 160 | " vqa_h = input_embedding.unsqueeze(0)\n", 161 | " seqlen = vqa_h.shape[-2]\n", 162 | " freqs_cis = freqs[:seqlen]\n", 163 | " mask = None\n", 164 | " mask = torch.full((1, 1, seqlen, seqlen), float(\"-inf\"), device=vqa_h.device)\n", 165 | " mask = torch.triu(mask, diagonal=0 + 1).type_as(vqa_h)\n", 166 | "\n", 167 | " for i, layer in enumerate(model.layers[-1 * model.adapter_layer:]):\n", 168 | " vqa_h = layer(vqa_h, start_pos, freqs_cis, mask, adapter[i].type(model.llamatype), vqa_video_start)\n", 169 | " vqa_h = model.norm(vqa_h)\n", 170 | " vqa_output = model.output(vqa_h)\n", 171 | " vqa_output = vqa_output.reshape(-1, model.vocab_size)\n", 172 | " vqa_output[-1,920]=-100\n", 173 | " vqa_output[-1,1128]=-100\n", 174 | " next_token = vqa_output[-1,:].argmax()\n", 175 | " tokens.append(next_token.item())\n", 176 | " token_emb = model.tok_embeddings(next_token.unsqueeze(0))\n", 177 | " input_embedding = torch.cat([input_embedding,token_emb],dim=0)\n", 178 | " return tokens" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 8, 184 | "id": "7d043520-913a-40f1-a392-1b432196b78e", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "import os\n", 189 | "dataset_path = '../data/videos/msrvtt/'\n", 190 | "test_file = os.path.join(dataset_path,'test_videodatainfo.json')\n", 191 | "\n", 192 | "with open(test_file,'r') as f:\n", 193 | " test_info = json.load(f)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 9, 199 | "id": "a6a2f83d-843d-4983-9a4e-cb1eeb6e9227", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "annotations_gt = {}\n", 204 | "annotations_gt['images'] = test_info['videos']\n", 205 | "\n", 206 | "for sentence in test_info['sentences']:\n", 207 | " sentence['image_id'] = int(sentence['video_id'].replace('video',''))\n", 208 | " sentence['id'] = sentence['video_id']\n", 209 | "annotations_gt['annotations'] =test_info['sentences']" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "id": "8b766fa8-7843-4627-b8fd-5ebaebef8f03", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "import cv2\n", 220 | "import numpy as np\n", 221 | "from PIL import Image\n", 222 | "def sample_images_from_video(video_path, num_samples=10):\n", 223 | " # Open the video file\n", 224 | " cap = cv2.VideoCapture(video_path)\n", 225 | "\n", 226 | " # Get the total number of frames in the video\n", 227 | " total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", 228 | " frame_rate = cap.get(cv2.CAP_PROP_FPS)\n", 229 | "\n", 230 | " # Calculate total duration in seconds\n", 231 | " total_duration = total_frames / frame_rate\n", 232 | " # print(total_duration)\n", 233 | " # Check if the video opened successfully\n", 234 | " if not cap.isOpened():\n", 235 | " print(\"Error opening video file.\")\n", 236 | " return []\n", 237 | "\n", 238 | " # Calculate the interval for sampling\n", 239 | " interval = total_frames // num_samples\n", 240 | "\n", 241 | " # Initialize a list to store the sampled images\n", 242 | " sampled_images = []\n", 243 | "\n", 244 | " for i in range(num_samples):\n", 245 | " # Set the frame position\n", 246 | " frame_id = i * interval\n", 247 | " cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)\n", 248 | "\n", 249 | " # Read the frame\n", 250 | " ret, frame = cap.read()\n", 251 | "\n", 252 | " # If frame reading was successful, save the frame\n", 253 | " if ret:\n", 254 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", 255 | " pil_image = Image.fromarray(frame)\n", 256 | " sampled_images.append(pil_image)\n", 257 | " \n", 258 | " else:\n", 259 | " print(f\"Error reading frame at position {frame_id}\")\n", 260 | "\n", 261 | " # Release the video capture object\n", 262 | " cap.release()\n", 263 | "\n", 264 | " return sampled_images, total_frames\n" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 11, 270 | "id": "d8f3cd5d-6b6a-4b75-8e01-3502967c71a5", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "import clip" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 12, 280 | "id": "7d2bf58b-c2f0-4bda-9c55-d25b53d1863d", 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "clip_model, preprocess = clip.load(\"ViT-L/14\")\n", 285 | "clip_model.eval()\n", 286 | "clip_model = clip_model.cuda()" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "c8012c00-b980-4cea-8eb2-5d5f4624a14b", 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 13, 300 | "id": "05527846-984b-43a5-8136-15f5267ce783", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "from pycocotools.coco import COCO\n", 305 | "from pycocoevalcap.eval import COCOEvalCap\n", 306 | "\n", 307 | "from json import encoder\n", 308 | "encoder.FLOAT_REPR = lambda o: format(o, '.3f')\n", 309 | "import sys" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 14, 315 | "id": "c299f86f-e4ee-404d-8ad4-32b625e33c9e", 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "loading annotations into memory...\n", 323 | "Done (t=0.04s)\n", 324 | "creating index...\n", 325 | "index created!\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "annotation_file = os.path.join(dataset_path,'annotation_file')\n", 331 | "with open(annotation_file,'w') as f:\n", 332 | " json.dump(annotations_gt,f)\n", 333 | "coco = COCO(annotation_file)\n", 334 | "video_ids = coco.getImgIds()" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 15, 340 | "id": "6898eb1a-7233-46f3-8060-fea7fd076a67", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "# prompt = \"Instruction: Predict the answer based on the video and question.\\nVideo:\"\n", 345 | "# prompt2 = \"\\nQuestion: Summarize the video.\\nAnswer: It is a video showing\" #26.7\n", 346 | "\n", 347 | "prompt = \"Instruction: Predict the answer based on the video and question.\\nVideo:\"\n", 348 | "prompt2 = \"\\nQuestion: Can you describe this video?\\nAnswer: It is a video of\"" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "id": "af77f161-4b9b-4079-b585-595b27900be3", 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stderr", 359 | "output_type": "stream", 360 | "text": [ 361 | " 5%|▌ | 151/2990 [02:28<36:09, 1.31it/s] " 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "from tqdm import tqdm\n", 367 | "results = []\n", 368 | " \n", 369 | "for video_id in tqdm(video_ids[:]):\n", 370 | " with torch.no_grad():\n", 371 | " try:\n", 372 | " video_path = os.path.join(dataset_path,'TestVideo','video'+str(video_id)+'.mp4')\n", 373 | " sampled_images,_ = sample_images_from_video(video_path)\n", 374 | "\n", 375 | " image_features = [preprocess(image) for image in sampled_images]\n", 376 | " image_features = torch.stack(image_features,dim=0).cuda()\n", 377 | " image_features = clip_model.encode_image(image_features)\n", 378 | " image_features/=image_features.norm(dim=-1,keepdim=True)\n", 379 | " tokens = decoding(model,tokenizer,prompt,prompt2,image_features)\n", 380 | " generate_text = tokenizer.decode(tokens[:])\n", 381 | " generate_text = generate_text.split('It is a video of')[1].strip().split('.')[0]\n", 382 | " results.append({'image_id':video_id,'caption': generate_text})\n", 383 | " except:\n", 384 | " results.append({'image_id':video_id,'caption': 'A video'})" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "id": "602b0e4f-462e-41c1-9fcc-4a59630f0c4f", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "coco_result = coco.loadRes(results) \n", 395 | "\n", 396 | "coco_eval = COCOEvalCap(coco, coco_result)\n", 397 | "coco_eval.evaluate()\n", 398 | "# print output evaluation s|cores\n", 399 | "scores = {}\n", 400 | "for metric, score in coco_eval.eval.items():\n", 401 | " print(f\"{metric}: {score:.3f}\")\n", 402 | " scores[metric] = score\n", 403 | "with open('13B_msrvtt_results.json','w') as f:\n", 404 | " json.dump(scores,f)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "10d8bf4e-8ba2-4565-b080-feec9bf406c4", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [] 414 | } 415 | ], 416 | "metadata": { 417 | "kernelspec": { 418 | "display_name": "llama", 419 | "language": "python", 420 | "name": "llama" 421 | }, 422 | "language_info": { 423 | "codemirror_mode": { 424 | "name": "ipython", 425 | "version": 3 426 | }, 427 | "file_extension": ".py", 428 | "mimetype": "text/x-python", 429 | "name": "python", 430 | "nbconvert_exporter": "python", 431 | "pygments_lexer": "ipython3", 432 | "version": "3.11.5" 433 | } 434 | }, 435 | "nbformat": 4, 436 | "nbformat_minor": 5 437 | } 438 | -------------------------------------------------------------------------------- /demos/upload_leaderboard_13B_zero_shot.json: -------------------------------------------------------------------------------- 1 | {"Action Sequence": 38.0, "Action Prediction": 40.0, "Action Antonym": 42.5, "Fine-grained Action": 35.0, "Unexpected Action": 69.0, "Object Existence": 52.5, "Object Interaction": 58.5, "Object Shuffle": 29.5, "Moving Direction": 22.5, "Action Localization": 43.5, "Scene Transition": 80.5, "Action Count": 38.0, "Moving Count": 25.5, "Moving Attribute": 43.0, "State Change": 43.0, "Fine-grained Pose": 29.5, "Character Order": 37.5, "Egocentric Navigation": 38.5, "Episodic Reasoning": 50.0, "Counterfactual Inference": 32.5, "Avg": 42.449999999999996} -------------------------------------------------------------------------------- /demos/upload_leaderboard_7B_ZS.json: -------------------------------------------------------------------------------- 1 | {"Action Sequence": 42.0, "Action Prediction": 38.5, "Action Antonym": 35.0, "Fine-grained Action": 34.5, "Unexpected Action": 66.0, "Object Existence": 52.5, "Object Interaction": 47.5, "Object Shuffle": 28.000000000000004, "Moving Direction": 22.0, "Action Localization": 37.5, "Scene Transition": 81.0, "Action Count": 38.0, "Moving Count": 24.0, "Moving Attribute": 42.5, "State Change": 41.5, "Fine-grained Pose": 28.499999999999996, "Character Order": 34.0, "Egocentric Navigation": 23.5, "Episodic Reasoning": 49.0, "Counterfactual Inference": 30.5, "Avg": 39.800000000000004} -------------------------------------------------------------------------------- /demos/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class MultiGroupRandomCrop(object): 38 | def __init__(self, size, groups=1): 39 | if isinstance(size, numbers.Number): 40 | self.size = (int(size), int(size)) 41 | else: 42 | self.size = size 43 | self.groups = groups 44 | 45 | def __call__(self, img_group): 46 | 47 | w, h = img_group[0].size 48 | th, tw = self.size 49 | 50 | out_images = list() 51 | 52 | for i in range(self.groups): 53 | x1 = random.randint(0, w - tw) 54 | y1 = random.randint(0, h - th) 55 | 56 | for img in img_group: 57 | assert(img.size[0] == w and img.size[1] == h) 58 | if w == tw and h == th: 59 | out_images.append(img) 60 | else: 61 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 62 | 63 | return out_images 64 | 65 | 66 | class GroupCenterCrop(object): 67 | def __init__(self, size): 68 | self.worker = torchvision.transforms.CenterCrop(size) 69 | 70 | def __call__(self, img_group): 71 | return [self.worker(img) for img in img_group] 72 | 73 | 74 | class GroupRandomHorizontalFlip(object): 75 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 76 | """ 77 | 78 | def __init__(self, is_flow=False): 79 | self.is_flow = is_flow 80 | 81 | def __call__(self, img_group, is_flow=False): 82 | v = random.random() 83 | if v < 0.5: 84 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 85 | if self.is_flow: 86 | for i in range(0, len(ret), 2): 87 | # invert flow pixel values when flipping 88 | ret[i] = ImageOps.invert(ret[i]) 89 | return ret 90 | else: 91 | return img_group 92 | 93 | 94 | class GroupNormalize(object): 95 | def __init__(self, mean, std): 96 | self.mean = mean 97 | self.std = std 98 | 99 | def __call__(self, tensor): 100 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 101 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 102 | 103 | # TODO: make efficient 104 | for t, m, s in zip(tensor, rep_mean, rep_std): 105 | t.sub_(m).div_(s) 106 | 107 | return tensor 108 | 109 | 110 | class GroupScale(object): 111 | """ Rescales the input PIL.Image to the given 'size'. 112 | 'size' will be the size of the smaller edge. 113 | For example, if height > width, then image will be 114 | rescaled to (size * height / width, size) 115 | size: size of the smaller edge 116 | interpolation: Default: PIL.Image.BILINEAR 117 | """ 118 | 119 | def __init__(self, size, interpolation=Image.BILINEAR): 120 | self.worker = torchvision.transforms.Resize(size, interpolation) 121 | 122 | def __call__(self, img_group): 123 | return [self.worker(img) for img in img_group] 124 | 125 | 126 | class GroupOverSample(object): 127 | def __init__(self, crop_size, scale_size=None, flip=True): 128 | self.crop_size = crop_size if not isinstance( 129 | crop_size, int) else (crop_size, crop_size) 130 | 131 | if scale_size is not None: 132 | self.scale_worker = GroupScale(scale_size) 133 | else: 134 | self.scale_worker = None 135 | self.flip = flip 136 | 137 | def __call__(self, img_group): 138 | 139 | if self.scale_worker is not None: 140 | img_group = self.scale_worker(img_group) 141 | 142 | image_w, image_h = img_group[0].size 143 | crop_w, crop_h = self.crop_size 144 | 145 | offsets = GroupMultiScaleCrop.fill_fix_offset( 146 | False, image_w, image_h, crop_w, crop_h) 147 | oversample_group = list() 148 | for o_w, o_h in offsets: 149 | normal_group = list() 150 | flip_group = list() 151 | for i, img in enumerate(img_group): 152 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 153 | normal_group.append(crop) 154 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 155 | 156 | if img.mode == 'L' and i % 2 == 0: 157 | flip_group.append(ImageOps.invert(flip_crop)) 158 | else: 159 | flip_group.append(flip_crop) 160 | 161 | oversample_group.extend(normal_group) 162 | if self.flip: 163 | oversample_group.extend(flip_group) 164 | return oversample_group 165 | 166 | 167 | class GroupFullResSample(object): 168 | def __init__(self, crop_size, scale_size=None, flip=True): 169 | self.crop_size = crop_size if not isinstance( 170 | crop_size, int) else (crop_size, crop_size) 171 | 172 | if scale_size is not None: 173 | self.scale_worker = GroupScale(scale_size) 174 | else: 175 | self.scale_worker = None 176 | self.flip = flip 177 | 178 | def __call__(self, img_group): 179 | 180 | if self.scale_worker is not None: 181 | img_group = self.scale_worker(img_group) 182 | 183 | image_w, image_h = img_group[0].size 184 | crop_w, crop_h = self.crop_size 185 | 186 | w_step = (image_w - crop_w) // 4 187 | h_step = (image_h - crop_h) // 4 188 | 189 | offsets = list() 190 | offsets.append((0 * w_step, 2 * h_step)) # left 191 | offsets.append((4 * w_step, 2 * h_step)) # right 192 | offsets.append((2 * w_step, 2 * h_step)) # center 193 | 194 | oversample_group = list() 195 | for o_w, o_h in offsets: 196 | normal_group = list() 197 | flip_group = list() 198 | for i, img in enumerate(img_group): 199 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 200 | normal_group.append(crop) 201 | if self.flip: 202 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 203 | 204 | if img.mode == 'L' and i % 2 == 0: 205 | flip_group.append(ImageOps.invert(flip_crop)) 206 | else: 207 | flip_group.append(flip_crop) 208 | 209 | oversample_group.extend(normal_group) 210 | oversample_group.extend(flip_group) 211 | return oversample_group 212 | 213 | 214 | class GroupMultiScaleCrop(object): 215 | 216 | def __init__(self, input_size, scales=None, max_distort=1, 217 | fix_crop=True, more_fix_crop=True): 218 | self.scales = scales if scales is not None else [1, .875, .75, .66] 219 | self.max_distort = max_distort 220 | self.fix_crop = fix_crop 221 | self.more_fix_crop = more_fix_crop 222 | self.input_size = input_size if not isinstance(input_size, int) else [ 223 | input_size, input_size] 224 | self.interpolation = Image.BILINEAR 225 | 226 | def __call__(self, img_group): 227 | 228 | im_size = img_group[0].size 229 | 230 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 231 | crop_img_group = [ 232 | img.crop( 233 | (offset_w, 234 | offset_h, 235 | offset_w + 236 | crop_w, 237 | offset_h + 238 | crop_h)) for img in img_group] 239 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 240 | for img in crop_img_group] 241 | return ret_img_group 242 | 243 | def _sample_crop_size(self, im_size): 244 | image_w, image_h = im_size[0], im_size[1] 245 | 246 | # find a crop size 247 | base_size = min(image_w, image_h) 248 | crop_sizes = [int(base_size * x) for x in self.scales] 249 | crop_h = [ 250 | self.input_size[1] if abs( 251 | x - self.input_size[1]) < 3 else x for x in crop_sizes] 252 | crop_w = [ 253 | self.input_size[0] if abs( 254 | x - self.input_size[0]) < 3 else x for x in crop_sizes] 255 | 256 | pairs = [] 257 | for i, h in enumerate(crop_h): 258 | for j, w in enumerate(crop_w): 259 | if abs(i - j) <= self.max_distort: 260 | pairs.append((w, h)) 261 | 262 | crop_pair = random.choice(pairs) 263 | if not self.fix_crop: 264 | w_offset = random.randint(0, image_w - crop_pair[0]) 265 | h_offset = random.randint(0, image_h - crop_pair[1]) 266 | else: 267 | w_offset, h_offset = self._sample_fix_offset( 268 | image_w, image_h, crop_pair[0], crop_pair[1]) 269 | 270 | return crop_pair[0], crop_pair[1], w_offset, h_offset 271 | 272 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 273 | offsets = self.fill_fix_offset( 274 | self.more_fix_crop, image_w, image_h, crop_w, crop_h) 275 | return random.choice(offsets) 276 | 277 | @staticmethod 278 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 279 | w_step = (image_w - crop_w) // 4 280 | h_step = (image_h - crop_h) // 4 281 | 282 | ret = list() 283 | ret.append((0, 0)) # upper left 284 | ret.append((4 * w_step, 0)) # upper right 285 | ret.append((0, 4 * h_step)) # lower left 286 | ret.append((4 * w_step, 4 * h_step)) # lower right 287 | ret.append((2 * w_step, 2 * h_step)) # center 288 | 289 | if more_fix_crop: 290 | ret.append((0, 2 * h_step)) # center left 291 | ret.append((4 * w_step, 2 * h_step)) # center right 292 | ret.append((2 * w_step, 4 * h_step)) # lower center 293 | ret.append((2 * w_step, 0 * h_step)) # upper center 294 | 295 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 296 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 297 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 298 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 299 | 300 | return ret 301 | 302 | 303 | class GroupRandomSizedCrop(object): 304 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 305 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 306 | This is popularly used to train the Inception networks 307 | size: size of the smaller edge 308 | interpolation: Default: PIL.Image.BILINEAR 309 | """ 310 | 311 | def __init__(self, size, interpolation=Image.BILINEAR): 312 | self.size = size 313 | self.interpolation = interpolation 314 | 315 | def __call__(self, img_group): 316 | for attempt in range(10): 317 | area = img_group[0].size[0] * img_group[0].size[1] 318 | target_area = random.uniform(0.08, 1.0) * area 319 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 320 | 321 | w = int(round(math.sqrt(target_area * aspect_ratio))) 322 | h = int(round(math.sqrt(target_area / aspect_ratio))) 323 | 324 | if random.random() < 0.5: 325 | w, h = h, w 326 | 327 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 328 | x1 = random.randint(0, img_group[0].size[0] - w) 329 | y1 = random.randint(0, img_group[0].size[1] - h) 330 | found = True 331 | break 332 | else: 333 | found = False 334 | x1 = 0 335 | y1 = 0 336 | 337 | if found: 338 | out_group = list() 339 | for img in img_group: 340 | img = img.crop((x1, y1, x1 + w, y1 + h)) 341 | assert(img.size == (w, h)) 342 | out_group.append( 343 | img.resize( 344 | (self.size, self.size), self.interpolation)) 345 | return out_group 346 | else: 347 | # Fallback 348 | scale = GroupScale(self.size, interpolation=self.interpolation) 349 | crop = GroupRandomCrop(self.size) 350 | return crop(scale(img_group)) 351 | 352 | 353 | class ConvertDataFormat(object): 354 | def __init__(self, model_type): 355 | self.model_type = model_type 356 | 357 | def __call__(self, images): 358 | if self.model_type == '2D': 359 | return images 360 | tc, h, w = images.size() 361 | t = tc // 3 362 | images = images.view(t, 3, h, w) 363 | images = images.permute(1, 0, 2, 3) 364 | return images 365 | 366 | 367 | class Stack(object): 368 | 369 | def __init__(self, roll=False): 370 | self.roll = roll 371 | 372 | def __call__(self, img_group): 373 | if img_group[0].mode == 'L': 374 | return np.concatenate([np.expand_dims(x, 2) 375 | for x in img_group], axis=2) 376 | elif img_group[0].mode == 'RGB': 377 | if self.roll: 378 | return np.concatenate([np.array(x)[:, :, ::-1] 379 | for x in img_group], axis=2) 380 | else: 381 | #print(np.concatenate(img_group, axis=2).shape) 382 | # print(img_group[0].shape) 383 | return np.concatenate(img_group, axis=2) 384 | 385 | 386 | class ToTorchFormatTensor(object): 387 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 388 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 389 | 390 | def __init__(self, div=True): 391 | self.div = div 392 | 393 | def __call__(self, pic): 394 | if isinstance(pic, np.ndarray): 395 | # handle numpy array 396 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 397 | else: 398 | # handle PIL Image 399 | img = torch.ByteTensor( 400 | torch.ByteStorage.from_buffer( 401 | pic.tobytes())) 402 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 403 | # put it from HWC to CHW format 404 | # yikes, this transpose takes 80% of the loading time/CPU 405 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 406 | return img.float().div(255) if self.div else img.float() 407 | 408 | 409 | class IdentityTransform(object): 410 | 411 | def __call__(self, data): 412 | return data 413 | 414 | 415 | if __name__ == "__main__": 416 | trans = torchvision.transforms.Compose([ 417 | GroupScale(256), 418 | GroupRandomCrop(224), 419 | Stack(), 420 | ToTorchFormatTensor(), 421 | GroupNormalize( 422 | mean=[.485, .456, .406], 423 | std=[.229, .224, .225] 424 | )] 425 | ) 426 | 427 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 428 | 429 | color_group = [im] * 3 430 | rst = trans(color_group) 431 | 432 | gray_group = [im.convert('L')] * 9 433 | gray_rst = trans(gray_group) 434 | 435 | trans2 = torchvision.transforms.Compose([ 436 | GroupRandomSizedCrop(256), 437 | Stack(), 438 | ToTorchFormatTensor(), 439 | GroupNormalize( 440 | mean=[.485, .456, .406], 441 | std=[.229, .224, .225]) 442 | ]) 443 | print(trans2(color_group)) 444 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import sys 4 | from typing import Iterable 5 | import util.misc as misc 6 | import util.lr_sched as lr_sched 7 | import json 8 | import random 9 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, epoch: int, loss_scaler, args=None): 10 | model.train(True) 11 | metric_logger = misc.MetricLogger(delimiter=" ") 12 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 13 | header = 'Epoch: [{}]'.format(epoch) 14 | print_freq = int(len(data_loader) / 50) 15 | accum_iter = args.accum_iter 16 | 17 | optimizer.zero_grad() 18 | 19 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 20 | 21 | if data_iter_step % accum_iter == 0: 22 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 23 | mode='vqa' 24 | if args.video_caption: 25 | if random.random()>0.5: 26 | mode='caption' 27 | if args.webvid: 28 | mode = 'caption' 29 | vqa_loss, vaq_loss, qav_loss = model(data,mode=mode) 30 | # print(vqa_loss, vaq_loss, qav_loss) 31 | loss = vqa_loss + vaq_loss*args.weight_captioning + qav_loss 32 | loss_value = loss.item() 33 | vqa_loss_value = vqa_loss.item() 34 | vaq_loss_value = vaq_loss.item() 35 | qav_loss_value = qav_loss.item() 36 | 37 | 38 | loss = loss / accum_iter 39 | 40 | loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0) 41 | if (data_iter_step + 1) % accum_iter == 0: 42 | optimizer.zero_grad() 43 | 44 | torch.cuda.synchronize() 45 | 46 | metric_logger.update(loss=loss_value) 47 | metric_logger.update(vqa_loss=vqa_loss_value) 48 | metric_logger.update(vaq_loss=vaq_loss_value) 49 | metric_logger.update(qav_loss=qav_loss_value) 50 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 51 | 52 | # gather the stats from all processes 53 | metric_logger.synchronize_between_processes() 54 | print("Averaged stats:", metric_logger) 55 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 56 | 57 | 58 | def val_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, epoch: int, args=None): 59 | model.eval() 60 | 61 | metric_logger = misc.MetricLogger(delimiter=" ") 62 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 63 | header = 'Epoch: [{}]'.format(epoch) 64 | print_freq = int(len(data_loader) / 10) 65 | results = {} 66 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 67 | answer = data['answer'].cuda() 68 | bsz = answer.shape[0] 69 | if args.openvqa_eval: 70 | mode = 'caption' 71 | else: 72 | mode = 'vqa' 73 | with torch.no_grad(): 74 | logits = model(data, inference=True,mode=mode) 75 | count = (logits != 0).sum(-1) 76 | prediction = (logits.sum(-1) / count).argmin(-1) 77 | 78 | eval = (answer == prediction) 79 | acc = eval.sum().item() / bsz 80 | 81 | misc.log_qtype(data, eval, metric_logger, args) 82 | 83 | lr = optimizer.param_groups[0]["lr"] 84 | metric_logger.update(lr=lr) 85 | metric_logger.update(n=bsz, acc=acc) 86 | 87 | metric_logger.synchronize_between_processes() 88 | print("Averaged stats:", metric_logger) 89 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 90 | 91 | 92 | def test_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, epoch: int, args=None): 93 | model.eval() 94 | metric_logger = misc.MetricLogger(delimiter=" ") 95 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 96 | header = 'Epoch: [{}]'.format(epoch) 97 | print_freq = int(len(data_loader) /10) 98 | results = {} 99 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 100 | answer = data['answer'].cuda() 101 | bsz = answer.shape[0] 102 | if args.openvqa_eval: 103 | mode = 'caption' 104 | else: 105 | mode = 'vqa' 106 | with torch.no_grad(): 107 | logits = model(data, inference=True,mode=mode) 108 | 109 | count = (logits != 0).sum(-1) 110 | prediction = (logits.sum(-1) / count).argmin(-1) 111 | 112 | results[data['qid'][0]]=prediction.item() 113 | 114 | eval = (answer == prediction) 115 | acc = eval.sum().item() / bsz 116 | 117 | misc.log_qtype(data, eval, metric_logger, args) 118 | 119 | lr = optimizer.param_groups[0]["lr"] 120 | metric_logger.update(lr=lr) 121 | metric_logger.update(n=bsz, acc=acc) 122 | 123 | # gather the stats from all processes 124 | with open(f'{args.output_dir}/egos.json','w') as f: 125 | json.dump(results,f) 126 | metric_logger.synchronize_between_processes() 127 | print("Averaged stats:", metric_logger) 128 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 129 | -------------------------------------------------------------------------------- /llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from .generation import LLaMA 5 | from .model import ModelArgs, Transformer 6 | from .tokenizer import Tokenizer 7 | from .tokenizer_llama3 import Tokenizer_llama3 8 | from .model_llama3 import Transformer_llama3,ModelArgs_llama3 9 | -------------------------------------------------------------------------------- /llama/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import List 5 | 6 | import torch 7 | 8 | from llama.tokenizer import Tokenizer 9 | from llama.model import Transformer 10 | 11 | 12 | class LLaMA: 13 | def __init__(self, model: Transformer, tokenizer: Tokenizer): 14 | self.model = model 15 | self.tokenizer = tokenizer 16 | 17 | def generate(self, prompts: List[str], max_gen_len: int, temperature: float = 0.8, top_p: float = 0.95,) -> List[str]: 18 | bsz = len(prompts) 19 | params = self.model.params 20 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 21 | prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] 22 | 23 | min_prompt_size = min([len(t) for t in prompt_tokens]) 24 | max_prompt_size = max([len(t) for t in prompt_tokens]) 25 | 26 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) 27 | 28 | tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long() 29 | for k, t in enumerate(prompt_tokens): 30 | tokens[k, : len(t)] = torch.tensor(t).long() 31 | input_text_mask = tokens != self.tokenizer.pad_id 32 | start_pos = min_prompt_size 33 | prev_pos = 0 34 | for cur_pos in range(start_pos, total_len): 35 | logits = self.model.inference(None, tokens[:, prev_pos:cur_pos], prev_pos) 36 | if temperature > 0: 37 | probs = torch.softmax(logits / temperature, dim=-1) 38 | next_token = sample_top_p(probs, top_p) 39 | else: 40 | next_token = torch.argmax(logits, dim=-1) 41 | next_token = next_token.reshape(-1) 42 | # only replace token if prompt has already been generated 43 | next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) 44 | tokens[:, cur_pos] = next_token 45 | prev_pos = cur_pos 46 | 47 | decoded = [] 48 | for i, t in enumerate(tokens.tolist()): 49 | # cut to max gen len 50 | t = t[: len(prompt_tokens[i]) + max_gen_len] 51 | # cut to eos tok if any 52 | try: 53 | t = t[: t.index(self.tokenizer.eos_id)] 54 | except ValueError: 55 | pass 56 | decoded.append(self.tokenizer.decode(t)) 57 | return decoded 58 | 59 | 60 | def sample_top_p(probs, p): 61 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 62 | probs_sum = torch.cumsum(probs_sort, dim=-1) 63 | mask = probs_sum - probs_sort > p 64 | probs_sort[mask] = 0.0 65 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 66 | next_token = torch.multinomial(probs_sort, num_samples=1) 67 | next_token = torch.gather(probs_idx, -1, next_token) 68 | return next_token 69 | -------------------------------------------------------------------------------- /llama/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Optional, Tuple 5 | from dataclasses import dataclass 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from torch.nn import Embedding, Linear 13 | import torch 14 | import pickle 15 | import timm.models.hub as timm_hub 16 | 17 | 18 | @dataclass 19 | class ModelArgs: 20 | dim: int = 512 21 | n_layers: int = 8 22 | n_heads: int = 8 23 | vocab_size: int = -1 # defined later by tokenizer 24 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 25 | norm_eps: float = 1e-5 26 | n_kv_heads: int=1 27 | max_batch_size: int = 32 28 | max_seq_len: int = 2048 29 | adapter_len: int=10 30 | adapter_layer: int=30 31 | 32 | 33 | class RMSNorm(torch.nn.Module): 34 | def __init__(self, dim: int, eps: float = 1e-6): 35 | super().__init__() 36 | self.eps = eps 37 | self.weight = nn.Parameter(torch.ones(dim)) 38 | 39 | def _norm(self, x): 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | output = self._norm(x.float()).type_as(x) 44 | return output * self.weight 45 | 46 | 47 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 49 | t = torch.arange(end, device=freqs.device) # type: ignore 50 | freqs = torch.outer(t, freqs).float() # type: ignore 51 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 52 | return freqs_cis 53 | 54 | 55 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 56 | ndim = x.ndim 57 | assert 0 <= 1 < ndim 58 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 59 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 60 | return freqs_cis.view(*shape) 61 | 62 | 63 | def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 64 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 65 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 66 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 67 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 68 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 69 | return xq_out.type_as(xq), xk_out.type_as(xk) 70 | 71 | 72 | class Attention(nn.Module): 73 | def __init__(self, args: ModelArgs): 74 | super().__init__() 75 | self.n_local_heads = args.n_heads 76 | self.head_dim = args.dim // args.n_heads 77 | self.max_feats = args.max_feats 78 | 79 | self.wq = Linear(args.dim, args.n_heads * self.head_dim, bias=False) 80 | self.wk = Linear(args.dim, args.n_heads * self.head_dim, bias=False) 81 | self.wv = Linear(args.dim, args.n_heads * self.head_dim, bias=False) 82 | self.wo = Linear(args.n_heads * self.head_dim, args.dim, bias=False) 83 | 84 | self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda() 85 | self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda() 86 | self.gate1 = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1)) 87 | self.gate2 = torch.nn.Parameter(torch.ones(1, self.n_local_heads, 1, 1) * -args.bias) 88 | self.llamatype = args.llamatype 89 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None): 90 | bsz, seqlen, _ = x.shape 91 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 92 | 93 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 94 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) 95 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) 96 | 97 | 98 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 99 | if adapter is not None: 100 | adapter_len = adapter.shape[1] 101 | adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1) 102 | adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1) 103 | xk = torch.cat([adapter_k, xk], dim=1) 104 | xv = torch.cat([adapter_v, xv], dim=1) 105 | extra_mask = torch.zeros(1, 1, seqlen, adapter_len).to(mask) 106 | mask = torch.cat([extra_mask, mask], dim=-1) 107 | keys = xk 108 | values = xv 109 | 110 | xq = xq.transpose(1, 2) 111 | keys = keys.transpose(1, 2) 112 | values = values.transpose(1, 2) 113 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 114 | if mask is not None: 115 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) 116 | if adapter is not None: 117 | adapter_scores = F.softmax(scores[..., :adapter_len].float(), dim=-1).type_as(xq) * self.gate1.tanh().type(self.llamatype) 118 | if video_start is not None: 119 | vt_scores = scores[..., adapter_len:].clone() 120 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] = \ 121 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] + self.gate2.type(self.llamatype) 122 | vt_scores = F.softmax(vt_scores.float(), dim=-1).type_as(xq) 123 | else: 124 | vt_scores = F.softmax(scores[..., adapter_len:], dim=-1) 125 | scores = torch.cat([adapter_scores, vt_scores], dim=-1) 126 | else: 127 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 128 | output = torch.matmul(scores, values) 129 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 130 | return self.wo(output) 131 | 132 | 133 | class FeedForward(nn.Module): 134 | def __init__(self, dim: int, hidden_dim: int, multiple_of: int): 135 | super().__init__() 136 | hidden_dim = int(2 * hidden_dim / 3) 137 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 138 | 139 | self.w1 = Linear(dim, hidden_dim, bias=False) 140 | self.w2 = Linear(hidden_dim, dim, bias=False) 141 | self.w3 = Linear(dim, hidden_dim, bias=False) 142 | 143 | def forward(self, x): 144 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 145 | 146 | 147 | class TransformerBlock(nn.Module): 148 | def __init__(self, layer_id: int, args: ModelArgs): 149 | super().__init__() 150 | self.n_heads = args.n_heads 151 | self.dim = args.dim 152 | self.head_dim = args.dim // args.n_heads 153 | self.attention = Attention(args) 154 | self.feed_forward = FeedForward(dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of) 155 | self.layer_id = layer_id 156 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 157 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 158 | 159 | 160 | 161 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None): 162 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter, video_start) 163 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 164 | return out 165 | 166 | 167 | class Transformer(nn.Module): 168 | def __init__(self, params: ModelArgs, args): 169 | super().__init__() 170 | params.max_feats = args.max_feats 171 | params.bias = args.bias 172 | self.args = args 173 | 174 | if self.args.llama2: 175 | self.llamatype = torch.bfloat16 176 | params.llamatype = torch.bfloat16 177 | else: 178 | self.llamatype = torch.half 179 | params.llamatype = torch.half 180 | 181 | self.params = params 182 | self.vocab_size = params.vocab_size 183 | self.n_layers = params.n_layers 184 | self.max_feats = args.max_feats 185 | 186 | 187 | self.tok_embeddings = Embedding(params.vocab_size, params.dim) 188 | 189 | self.adapter_query = Embedding(params.adapter_len * params.adapter_layer, params.dim) 190 | 191 | clip_feature_dim=768 192 | self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False) 193 | self.temporal_emb = Embedding(self.max_feats, params.dim) 194 | self.adapter_len = params.adapter_len 195 | self.adapter_layer = params.adapter_layer 196 | 197 | self.vqa_criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 198 | self.vaq_criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 199 | 200 | self.inference_criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none') 201 | 202 | self.layers = torch.nn.ModuleList() 203 | for layer_id in range(params.n_layers): 204 | self.layers.append(TransformerBlock(layer_id, params)) 205 | 206 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 207 | self.output = Linear(params.dim, params.vocab_size, bias=False) 208 | 209 | self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) 210 | 211 | self.video_label = torch.arange(1, self.max_feats) 212 | self.tau = args.tau 213 | if self.args.memory: 214 | with open('./data/textvid/memory.pkl','rb') as f: 215 | self.memory = pickle.load(f).float()[:1000000].cuda() 216 | 217 | self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False) 218 | 219 | 220 | def re_init_freqs(self,max_seq_len): 221 | self.params.max_seq_len=max_seq_len 222 | self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) 223 | 224 | def forward(self, data, inference=False,mode='vqa'): 225 | 226 | video = data['video'].cuda() 227 | # video = video/video.norm(dim=-1,keepdim=True) 228 | if self.args.memory and inference==True: 229 | sim = video@self.memory.T 230 | 231 | sim = (sim*100).softmax(dim=-1) 232 | video = sim@self.memory 233 | video = video/video.norm(dim=-1,keepdim=True) 234 | if self.args.onlyqa: 235 | video=video*0 236 | 237 | 238 | vqa_id, vaq_id, qav_id = data['text_id']['vqa'].cuda(), data['text_id']['vaq'].cuda(), data['text_id']['qav'].cuda() 239 | vqa_label, vaq_label, qav_label = data['label']['vqa'].cuda(), data['label']['vaq'].cuda(), data['label']['qav'].cuda() 240 | vqa_video_start, vaq_video_start, qav_video_index = data['video_start']['vqa'][0], data['video_start']['vaq'][0], data['video_index']['qav'].cuda() 241 | 242 | bsz, n_options, seqlen = vqa_id.shape 243 | vqa_id, vaq_id = vqa_id.reshape(-1, seqlen), vaq_id.reshape(-1, seqlen) 244 | vqa_label, vaq_label = vqa_label.reshape(-1, seqlen), vaq_label.reshape(-1, seqlen) 245 | vqa_label, vaq_label = vqa_label[:, 1:].flatten(), vaq_label[:, 1:].flatten() 246 | 247 | qav_id = qav_id.reshape(-1, seqlen) 248 | qav_label = qav_label.reshape(-1, seqlen) 249 | qav_video_mask = qav_label.ge(0) 250 | qav_label = qav_label[:, 1:].flatten() 251 | 252 | 253 | with torch.no_grad(): 254 | vqa_h = self.tok_embeddings(vqa_id) 255 | 256 | if self.args.vaq and not inference: 257 | vaq_h = self.tok_embeddings(vaq_id) 258 | if self.args.openvqa_eval: 259 | vaq_h = self.tok_embeddings(vaq_id) 260 | if self.args.qav and not inference: 261 | qav_h = self.tok_embeddings(qav_id) 262 | 263 | freqs_cis = self.freqs_cis.to(vqa_h.device) 264 | freqs_cis = freqs_cis[:seqlen] 265 | mask = None 266 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=vqa_h.device) 267 | mask = torch.triu(mask, diagonal=0 + 1).type_as(vqa_h) 268 | start_pos = 0 269 | vqa_loss, vaq_loss, qav_loss = torch.tensor([0]).cuda(),torch.tensor([0]).cuda(), torch.tensor([0]).cuda() 270 | 271 | adapter = self.adapter_query.weight.reshape(-1, self.adapter_len, self.params.dim).unsqueeze(1) 272 | 273 | _video_feature = self.visual_proj(video) 274 | 275 | if inference: 276 | _video_feature = _video_feature.unsqueeze(1).repeat(1, n_options, 1, 1).view(-1, _video_feature.shape[-2], _video_feature.shape[-1]) 277 | 278 | video_feature = (_video_feature + self.temporal_emb.weight[None, :, :]).type(self.llamatype) 279 | 280 | 281 | if mode == 'vqa': 282 | 283 | vqa_h = vqa_h.clone() 284 | vqa_h[:, vqa_video_start:vqa_video_start+self.max_feats] = video_feature 285 | 286 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]): 287 | vqa_h = layer(vqa_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vqa_video_start) 288 | 289 | vqa_h = self.norm(vqa_h) 290 | vqa_output = self.output(vqa_h) 291 | vqa_output = vqa_output[:, :-1, :].reshape(-1, self.vocab_size) 292 | vqa_loss = self.vqa_criterion(vqa_output, vqa_label) 293 | 294 | if mode =='caption': 295 | vaq_h = vaq_h.clone() 296 | vaq_h[:, vaq_video_start:vaq_video_start+self.max_feats] = video_feature 297 | 298 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]): 299 | vaq_h = layer(vaq_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vaq_video_start) 300 | vaq_h = self.norm(vaq_h) 301 | vaq_output = self.output(vaq_h) 302 | vaq_output = vaq_output[:, :-1, :].reshape(-1, self.vocab_size) 303 | vaq_loss = self.vaq_criterion(vaq_output, vaq_label) 304 | 305 | if inference: 306 | if self.args.openvqa_eval: 307 | logits = self.inference_criterion(vaq_output, vaq_label) 308 | logits = logits.reshape(bsz, n_options, -1) 309 | else: 310 | logits = self.inference_criterion(vqa_output, vqa_label) 311 | logits = logits.reshape(bsz, n_options, -1) 312 | return logits 313 | else: 314 | return vqa_loss, vaq_loss, qav_loss 315 | -------------------------------------------------------------------------------- /llama/model_llama3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Optional, Tuple 5 | from dataclasses import dataclass 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from torch.nn import Embedding, Linear 13 | import torch 14 | import pickle 15 | import timm.models.hub as timm_hub 16 | from fairscale.nn.model_parallel.layers import ( 17 | ColumnParallelLinear, 18 | RowParallelLinear, 19 | VocabParallelEmbedding, 20 | ) 21 | 22 | 23 | @dataclass 24 | class ModelArgs_llama3: 25 | dim: int = 4096 26 | n_layers: int = 32 27 | n_heads: int = 32 28 | n_kv_heads: Optional[int] = None 29 | vocab_size: int = -1 30 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 31 | ffn_dim_multiplier: Optional[float] = None 32 | norm_eps: float = 1e-5 33 | rope_theta: float = 500000 34 | 35 | max_batch_size: int = 32 36 | max_seq_len: int = 2048 37 | 38 | adapter_len: int=10 39 | adapter_layer: int=30 40 | 41 | 42 | class RMSNorm(torch.nn.Module): 43 | def __init__(self, dim: int, eps: float = 1e-6): 44 | super().__init__() 45 | self.eps = eps 46 | self.weight = nn.Parameter(torch.ones(dim)) 47 | 48 | def _norm(self, x): 49 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 50 | 51 | def forward(self, x): 52 | output = self._norm(x.float()).type_as(x) 53 | return output * self.weight 54 | 55 | 56 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 57 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 58 | t = torch.arange(end, device=freqs.device, dtype=torch.float32) 59 | freqs = torch.outer(t, freqs) 60 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 61 | return freqs_cis 62 | 63 | 64 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 65 | ndim = x.ndim 66 | assert 0 <= 1 < ndim 67 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 68 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 69 | return freqs_cis.view(*shape) 70 | 71 | 72 | def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 73 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 74 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 75 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 76 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 77 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 78 | return xq_out.type_as(xq), xk_out.type_as(xk) 79 | 80 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 81 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 82 | bs, slen, n_kv_heads, head_dim = x.shape 83 | if n_rep == 1: 84 | return x 85 | return ( 86 | x[:, :, :, None, :] 87 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 88 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 89 | ) 90 | 91 | 92 | class Attention(nn.Module): 93 | def __init__(self, args: ModelArgs_llama3): 94 | super().__init__() 95 | self.n_local_heads = args.n_heads 96 | self.head_dim = args.dim // args.n_heads 97 | self.max_feats = args.max_feats 98 | self.n_kv_heads = args.n_kv_heads 99 | model_parallel_size=1 100 | self.n_local_heads = args.n_heads // model_parallel_size 101 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 102 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 103 | 104 | self.wq = Linear(args.dim, args.n_heads * self.head_dim, bias=False) 105 | self.wk = Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) 106 | self.wv = Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) 107 | self.wo = Linear(args.n_heads * self.head_dim, args.dim, bias=False) 108 | 109 | self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda() 110 | self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda() 111 | self.gate1 = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1)) 112 | self.gate2 = torch.nn.Parameter(torch.ones(1, self.n_local_heads, 1, 1) * -args.bias) 113 | self.llamatype = args.llamatype 114 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None): 115 | bsz, seqlen, _ = x.shape 116 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 117 | 118 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 119 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 120 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 121 | 122 | 123 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 124 | if adapter is not None: 125 | adapter_len = adapter.shape[1] 126 | adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_kv_heads, self.head_dim).repeat(bsz, 1, 1, 1) 127 | adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_kv_heads, self.head_dim).repeat(bsz, 1, 1, 1) 128 | xk = torch.cat([adapter_k, xk], dim=1) 129 | xv = torch.cat([adapter_v, xv], dim=1) 130 | extra_mask = torch.zeros(1, 1, seqlen, adapter_len).to(mask) 131 | mask = torch.cat([extra_mask, mask], dim=-1) 132 | keys = xk 133 | values = xv 134 | keys = repeat_kv( 135 | keys, self.n_rep 136 | ) # (bs, cache_len + seqlen, n_local_heads, head_dim) 137 | values = repeat_kv( 138 | values, self.n_rep 139 | ) # (bs, cache_len + seqlen, n_local_heads, head_dim) 140 | 141 | xq = xq.transpose(1, 2) 142 | keys = keys.transpose(1, 2) 143 | values = values.transpose(1, 2) 144 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 145 | if mask is not None: 146 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) 147 | if adapter is not None: 148 | adapter_scores = F.softmax(scores[..., :adapter_len].float(), dim=-1).type_as(xq) * self.gate1.tanh().type(self.llamatype) 149 | if video_start is not None: 150 | vt_scores = scores[..., adapter_len:].clone() 151 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] = \ 152 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] + self.gate2.type(self.llamatype) 153 | vt_scores = F.softmax(vt_scores.float(), dim=-1).type_as(xq) 154 | else: 155 | vt_scores = F.softmax(scores[..., adapter_len:], dim=-1) 156 | scores = torch.cat([adapter_scores, vt_scores], dim=-1) 157 | else: 158 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 159 | output = torch.matmul(scores, values) 160 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 161 | return self.wo(output) 162 | 163 | 164 | class FeedForward(nn.Module): 165 | def __init__( 166 | self, 167 | dim: int, 168 | hidden_dim: int, 169 | multiple_of: int, 170 | ffn_dim_multiplier: Optional[float], 171 | ): 172 | super().__init__() 173 | hidden_dim = int(2 * hidden_dim / 3) 174 | # custom dim factor multiplier 175 | if ffn_dim_multiplier is not None: 176 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 177 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 178 | 179 | self.w1 = Linear(dim, hidden_dim, bias=False) 180 | self.w2 = Linear(hidden_dim, dim, bias=False) 181 | self.w3 = Linear(dim, hidden_dim, bias=False) 182 | 183 | def forward(self, x): 184 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 185 | 186 | 187 | class TransformerBlock(nn.Module): 188 | def __init__(self, layer_id: int, args: ModelArgs_llama3): 189 | super().__init__() 190 | self.n_heads = args.n_heads 191 | self.dim = args.dim 192 | self.head_dim = args.dim // args.n_heads 193 | self.attention = Attention(args) 194 | self.feed_forward = FeedForward( 195 | dim=args.dim, 196 | hidden_dim=4 * args.dim, 197 | multiple_of=args.multiple_of, 198 | ffn_dim_multiplier=args.ffn_dim_multiplier, 199 | ) 200 | self.layer_id = layer_id 201 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 202 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 203 | 204 | 205 | 206 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None): 207 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter, video_start) 208 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 209 | return out 210 | 211 | 212 | class Transformer_llama3(nn.Module): 213 | def __init__(self, params: ModelArgs_llama3, args): 214 | super().__init__() 215 | params.max_feats = args.max_feats 216 | params.bias = args.bias 217 | self.args = args 218 | 219 | if self.args.llama2: 220 | self.llamatype = torch.bfloat16 221 | params.llamatype = torch.bfloat16 222 | else: 223 | self.llamatype = torch.half 224 | params.llamatype = torch.half 225 | 226 | self.params = params 227 | self.vocab_size = params.vocab_size 228 | self.n_layers = params.n_layers 229 | self.max_feats = args.max_feats 230 | 231 | 232 | self.tok_embeddings = Embedding(params.vocab_size, params.dim) 233 | self.adapter_query = Embedding(params.adapter_len * params.adapter_layer, params.dim) 234 | 235 | clip_feature_dim=768 236 | self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False) 237 | self.temporal_emb = Embedding(self.max_feats, params.dim) 238 | self.adapter_len = params.adapter_len 239 | self.adapter_layer = params.adapter_layer 240 | 241 | self.vqa_criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 242 | self.vaq_criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 243 | self.qav_criterion = torch.nn.CrossEntropyLoss(ignore_index=-1) 244 | self.inference_criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none') 245 | 246 | self.layers = torch.nn.ModuleList() 247 | for layer_id in range(params.n_layers): 248 | self.layers.append(TransformerBlock(layer_id, params)) 249 | 250 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 251 | self.output = Linear( 252 | params.dim, params.vocab_size, bias=False 253 | ) 254 | 255 | self.freqs_cis = precompute_freqs_cis( 256 | params.dim // params.n_heads, 257 | params.max_seq_len * 2, 258 | params.rope_theta, 259 | ) 260 | 261 | self.video_label = torch.arange(1, self.max_feats) 262 | self.tau = args.tau 263 | if self.args.memory: 264 | with open('./data/textvid/memory.pkl','rb') as f: 265 | self.memory = pickle.load(f).float()[:1000000].cuda() 266 | 267 | self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False) 268 | 269 | 270 | 271 | 272 | def re_init_freqs(self,max_seq_len): 273 | self.params.max_seq_len=max_seq_len 274 | self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) 275 | 276 | def forward(self, data, inference=False,mode='vqa'): 277 | 278 | video = data['video'].cuda() 279 | # video = video/video.norm(dim=-1,keepdim=True) 280 | if self.args.memory and inference==True: 281 | sim = video@self.memory.T 282 | 283 | sim = (sim*100).softmax(dim=-1) 284 | video = sim@self.memory 285 | video = video/video.norm(dim=-1,keepdim=True) 286 | if self.args.onlyqa: 287 | video=video*0 288 | 289 | vqa_id, vaq_id, qav_id = data['text_id']['vqa'].cuda(), data['text_id']['vaq'].cuda(), data['text_id']['qav'].cuda() 290 | vqa_label, vaq_label, qav_label = data['label']['vqa'].cuda(), data['label']['vaq'].cuda(), data['label']['qav'].cuda() 291 | vqa_video_start, vaq_video_start, qav_video_index = data['video_start']['vqa'][0], data['video_start']['vaq'][0], data['video_index']['qav'].cuda() 292 | 293 | bsz, n_options, seqlen = vqa_id.shape 294 | vqa_id, vaq_id = vqa_id.reshape(-1, seqlen), vaq_id.reshape(-1, seqlen) 295 | vqa_label, vaq_label = vqa_label.reshape(-1, seqlen), vaq_label.reshape(-1, seqlen) 296 | vqa_label, vaq_label = vqa_label[:, 1:].flatten(), vaq_label[:, 1:].flatten() 297 | 298 | qav_id = qav_id.reshape(-1, seqlen) 299 | qav_label = qav_label.reshape(-1, seqlen) 300 | qav_video_mask = qav_label.ge(0) 301 | qav_label = qav_label[:, 1:].flatten() 302 | 303 | 304 | with torch.no_grad(): 305 | vqa_h = self.tok_embeddings(vqa_id) 306 | 307 | if self.args.vaq and not inference: 308 | vaq_h = self.tok_embeddings(vaq_id) 309 | if self.args.openvqa_eval: 310 | vaq_h = self.tok_embeddings(vaq_id) 311 | if self.args.qav and not inference: 312 | qav_h = self.tok_embeddings(qav_id) 313 | 314 | freqs_cis = self.freqs_cis.to(vqa_h.device) 315 | freqs_cis = freqs_cis[:seqlen] 316 | mask = None 317 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=vqa_h.device) 318 | mask = torch.triu(mask, diagonal=0 + 1).type_as(vqa_h) 319 | start_pos = 0 320 | vqa_loss, vaq_loss, qav_loss = torch.tensor([0]).cuda(),torch.tensor([0]).cuda(), torch.tensor([0]).cuda() 321 | 322 | adapter = self.adapter_query.weight.reshape(-1, self.adapter_len, self.params.dim).unsqueeze(1) 323 | 324 | _video_feature = self.visual_proj(video) 325 | 326 | if inference: 327 | _video_feature = _video_feature.unsqueeze(1).repeat(1, n_options, 1, 1).view(-1, _video_feature.shape[-2], _video_feature.shape[-1]) 328 | video_feature = (_video_feature + self.temporal_emb.weight[None, :, :]).type(self.llamatype) 329 | 330 | 331 | if mode == 'vqa': 332 | 333 | vqa_h = vqa_h.clone() 334 | vqa_h[:, vqa_video_start:vqa_video_start+self.max_feats] = video_feature 335 | 336 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]): 337 | vqa_h = layer(vqa_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vqa_video_start) 338 | 339 | vqa_h = self.norm(vqa_h) 340 | vqa_output = self.output(vqa_h) 341 | vqa_output = vqa_output[:, :-1, :].reshape(-1, self.vocab_size) 342 | vqa_loss = self.vqa_criterion(vqa_output, vqa_label) 343 | 344 | if mode =='caption': 345 | vaq_h = vaq_h.clone() 346 | vaq_h[:, vaq_video_start:vaq_video_start+self.max_feats] = video_feature 347 | 348 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]): 349 | vaq_h = layer(vaq_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vaq_video_start) 350 | vaq_h = self.norm(vaq_h) 351 | vaq_output = self.output(vaq_h) 352 | vaq_output = vaq_output[:, :-1, :].reshape(-1, self.vocab_size) 353 | vaq_loss = self.vaq_criterion(vaq_output, vaq_label) 354 | 355 | if inference: 356 | if self.args.openvqa_eval: 357 | logits = self.inference_criterion(vaq_output, vaq_label) 358 | logits = logits.reshape(bsz, n_options, -1) 359 | else: 360 | logits = self.inference_criterion(vqa_output, vqa_label) 361 | logits = logits.reshape(bsz, n_options, -1) 362 | return logits 363 | else: 364 | return vqa_loss, vaq_loss, qav_loss 365 | -------------------------------------------------------------------------------- /llama/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from sentencepiece import SentencePieceProcessor 5 | from logging import getLogger 6 | from typing import List 7 | import os 8 | import torch 9 | 10 | logger = getLogger() 11 | 12 | 13 | class Tokenizer: 14 | def __init__(self, model_path: str): 15 | # reload tokenizer 16 | assert os.path.isfile(model_path), model_path 17 | self.sp_model = SentencePieceProcessor(model_file=model_path) 18 | logger.info(f"Reloaded SentencePiece model from {model_path}") 19 | 20 | # BOS / EOS token IDs 21 | self.n_words: int = self.sp_model.vocab_size() 22 | self.bos_id: int = self.sp_model.bos_id() 23 | self.eos_id: int = self.sp_model.eos_id() 24 | self.pad_id: int = self.sp_model.pad_id() 25 | 26 | self.v_token_id = 15167 27 | self.q_token_id = 16492 28 | self.a_token_id = 22550 29 | self.c_token_id = 9868 30 | self.nl_id = 13 31 | logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") 32 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 33 | 34 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 35 | assert type(s) is str 36 | t = self.sp_model.encode(s) 37 | if bos: 38 | t = [self.bos_id] + t 39 | if eos: 40 | t = t + [self.eos_id] 41 | return t 42 | 43 | def encode_vqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 44 | i_text = "Instruction: Choose the correct answer based on the video and question.\n" 45 | q_text = text['q_text'] 46 | o_text = text['o_text'] 47 | a_text = text['a_text'] 48 | 49 | s1 = i_text + 'Video:' 50 | t1 = [self.bos_id] + self.sp_model.encode(s1) 51 | video_start = len(t1) 52 | 53 | s2 = q_text + o_text + a_text 54 | 55 | if split == 'train': 56 | s2 = s2 + answer_mapping[answer] 57 | t2 = self.sp_model.encode(s2) + [self.eos_id] 58 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 59 | prefix_index = t[0].index(self.a_token_id) + 4 60 | else: 61 | t = [] 62 | for k, v in answer_mapping.items(): 63 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id] 64 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 65 | prefix_index = t[answer].index(self.a_token_id) + 4 66 | return t, prefix_index, video_start 67 | 68 | def encode_vaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 69 | i_text = "Instruction: Predict the question based on the video and answer.\n" 70 | q_text = text['q_text'].strip() 71 | o_text = text['o_text'] 72 | a_text = text['a_text'] 73 | 74 | s1 = i_text + 'Video:' 75 | t1 = [self.bos_id] + self.sp_model.encode(s1) 76 | video_start = len(t1) 77 | 78 | s2 = o_text + a_text 79 | 80 | if split == 'train': 81 | s2 = s2 + answer_mapping[answer] + "\n" + q_text 82 | t2 = self.sp_model.encode(s2) + [self.eos_id] 83 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 84 | prefix_index = t[0].index(self.q_token_id) + 1 85 | else: 86 | t = [] 87 | for k, v in answer_mapping.items(): 88 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id] 89 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 90 | prefix_index = t[answer].index(self.q_token_id) + 1 91 | return t, prefix_index, video_start 92 | 93 | 94 | def encode_qav(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 95 | i_text = "Instruction: Predict the video based on the question and answer.\n" 96 | q_text = text['q_text'] 97 | o_text = text['o_text'] 98 | a_text = text['a_text'] 99 | 100 | s1 = i_text + q_text + o_text + a_text 101 | 102 | if split == 'train': 103 | s1 = s1 + answer_mapping[answer] + "\n" + "Video:" 104 | t1 = [self.bos_id] + self.sp_model.encode(s1) 105 | t = [t1 + [-2 for _ in range(max_feats)] + [self.eos_id]] 106 | prefix_index = t[0].index(self.v_token_id) + 1 107 | else: 108 | t = [] 109 | for k, v in answer_mapping.items(): 110 | t1 = [self.bos_id] + self.sp_model.encode(s1 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id] 111 | t.append(t1) 112 | prefix_index = t[answer].index(self.v_token_id) + 1 113 | return t, prefix_index 114 | 115 | def decode(self, t: List[int]) -> str: 116 | return self.sp_model.decode(t) 117 | 118 | def encode_dvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 119 | i_text = "Instruction: Predict the answer based on the dialogue, video and question.\n" 120 | q_text = text['q_text'] 121 | o_text = text['o_text'] 122 | a_text = text['a_text'] 123 | d_text = text['d_text'] 124 | 125 | s1 = i_text + 'Video:' 126 | t1 = [self.bos_id] + self.sp_model.encode(s1) 127 | video_start = len(t1) 128 | 129 | prefix_i = video_start + max_feats + 1 130 | d1 = self.sp_model.encode(d_text) 131 | prefix_main = prefix_i + len(d1) 132 | 133 | s2 = q_text + o_text + a_text 134 | 135 | if split == 'train': 136 | s2 = s2 + answer_mapping[answer] 137 | t2 = self.sp_model.encode(s2) + [self.eos_id] 138 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2] 139 | else: 140 | t = [] 141 | for k, v in answer_mapping.items(): 142 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id] 143 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2) 144 | 145 | prefix_index = len(t[0]) - 4 146 | 147 | return t, prefix_index, video_start, prefix_i, prefix_main 148 | 149 | def encode_dvaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 150 | i_text = "Instruction: Predict the question based on the dialogue, video and answer.\n" 151 | q_text = text['q_text'].strip() 152 | o_text = text['o_text'] 153 | a_text = text['a_text'] 154 | d_text = text['d_text'] 155 | 156 | s1 = i_text + 'Video:' 157 | t1 = [self.bos_id] + self.sp_model.encode(s1) 158 | video_start = len(t1) 159 | 160 | prefix_i = video_start + max_feats + 1 161 | d1 = self.sp_model.encode(d_text) 162 | prefix_main = prefix_i + len(d1) 163 | 164 | s2 = o_text + a_text 165 | 166 | if split == 'train': 167 | s2 = s2 + answer_mapping[answer] + "\n" + q_text 168 | t2 = self.sp_model.encode(s2) + [self.eos_id] 169 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2] 170 | else: 171 | t = [] 172 | for k, v in answer_mapping.items(): 173 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id] 174 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2) 175 | 176 | prefix_index = t[0].index(self.q_token_id) + 1 177 | 178 | return t, prefix_index, video_start, prefix_i, prefix_main 179 | 180 | def encode_dqav(self, text=None, max_feats=10, max_seq_len=128, split='train', answer_mapping=None, answer=None) -> List[int]: 181 | i_text = "Instruction: Predict the video based on the dialogue, question and answer.\n" 182 | d_text = text['d_text'] 183 | q_text = text['q_text'] 184 | o_text = text['o_text'] 185 | a_text = text['a_text'] 186 | s1, s2, s3 = i_text, d_text, q_text + o_text + a_text 187 | 188 | t1 = [self.bos_id] + self.sp_model.encode(s1) 189 | t2 = self.sp_model.encode(s2) 190 | prefix_i, prefix_q = len(t1), len(t1) + len(t2) 191 | 192 | if split == 'train': 193 | t3 = self.sp_model.encode(s3 + answer_mapping[answer] + "\n" + "Video:") 194 | t = [t1 + t2 + t3 + [-2 for _ in range(max_feats)] + [self.eos_id]] 195 | else: 196 | t = [] 197 | for k, v in answer_mapping.items(): 198 | t3 = self.sp_model.encode(s3 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id] 199 | t.append(t1 + t2 + t3) 200 | 201 | prefix_index = len(t[0]) - max_feats - 1 202 | 203 | return t, prefix_index, prefix_i, prefix_q 204 | 205 | 206 | 207 | def encode_videocap(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 208 | i_text = "Instruction: Generate a summary for the video.\n" 209 | # q_text = text['q_text'].strip() 210 | # o_text = text['o_text'] 211 | # a_text = text['a_text'] 212 | 213 | s1 = i_text + 'Video:' 214 | t1 = [self.bos_id] + self.sp_model.encode(s1) 215 | video_start = len(t1) 216 | 217 | s2 = text['c_text'] 218 | 219 | if split == 'train': 220 | # s2 = s2 + answer_mapping[answer] + "\n" + q_text 221 | s2 = "\n"+s2 222 | t2 = self.sp_model.encode(s2) + [self.eos_id] 223 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 224 | prefix_index = t[0].index(self.c_token_id) + 1 225 | else: 226 | t = [] 227 | for k, v in answer_mapping.items(): 228 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id] 229 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 230 | prefix_index = t[answer].index(self.c_token_id) + 1 231 | return t, prefix_index, video_start 232 | 233 | 234 | 235 | def encode_openvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 236 | i_text = "Instruction: Predict the answer based on the video and question.\n" 237 | q_text = text['q_text'] 238 | oa_text = text['oa_text'] 239 | 240 | s1 = i_text + 'Video:' 241 | t1 = [self.bos_id] + self.sp_model.encode(s1) 242 | video_start = len(t1) 243 | 244 | if split == 'train': 245 | s2 = q_text + oa_text 246 | t2 = self.sp_model.encode(s2) + [self.eos_id] 247 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 248 | prefix_index = t[0].index(self.a_token_id)+1 249 | else: 250 | t = [] 251 | for open_option in text['open_options']: 252 | s2 = q_text + open_option 253 | t2 = self.sp_model.encode(s2) + [self.eos_id] 254 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 255 | prefix_index = t[0].index(self.a_token_id)+1 256 | # else: 257 | # t = [] 258 | # prefix_index = [] 259 | # for open_option in text['open_options']: 260 | # s2 = q_text + open_option+'. For this video, this answer is correct.' 261 | # t2 = self.sp_model.encode(s2) + [self.eos_id] 262 | # t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 263 | # prefix_index.append(len(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)-50) 264 | 265 | return t, prefix_index, video_start -------------------------------------------------------------------------------- /llama/tokenizer_llama3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 3 | 4 | import os 5 | from logging import getLogger 6 | from pathlib import Path 7 | from typing import ( 8 | AbstractSet, 9 | cast, 10 | Collection, 11 | Dict, 12 | Iterator, 13 | List, 14 | Literal, 15 | Sequence, 16 | TypedDict, 17 | Union, 18 | ) 19 | 20 | import tiktoken 21 | from tiktoken.load import load_tiktoken_bpe 22 | 23 | 24 | logger = getLogger(__name__) 25 | 26 | 27 | Role = Literal["system", "user", "assistant"] 28 | 29 | 30 | class Message(TypedDict): 31 | role: Role 32 | content: str 33 | 34 | 35 | Dialog = Sequence[Message] 36 | 37 | 38 | class Tokenizer_llama3: 39 | """ 40 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 41 | """ 42 | 43 | special_tokens: Dict[str, int] 44 | 45 | num_reserved_special_tokens = 256 46 | 47 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 48 | 49 | def __init__(self, model_path: str): 50 | """ 51 | Initializes the Tokenizer with a Tiktoken model. 52 | 53 | Args: 54 | model_path (str): The path to the Tiktoken model file. 55 | """ 56 | assert os.path.isfile(model_path), model_path 57 | 58 | mergeable_ranks = load_tiktoken_bpe(model_path) 59 | num_base_tokens = len(mergeable_ranks) 60 | special_tokens = [ 61 | "<|begin_of_text|>", 62 | "<|end_of_text|>", 63 | "<|reserved_special_token_0|>", 64 | "<|reserved_special_token_1|>", 65 | "<|reserved_special_token_2|>", 66 | "<|reserved_special_token_3|>", 67 | "<|start_header_id|>", 68 | "<|end_header_id|>", 69 | "<|reserved_special_token_4|>", 70 | "<|eot_id|>", # end of turn 71 | ] + [ 72 | f"<|reserved_special_token_{i}|>" 73 | for i in range(5, self.num_reserved_special_tokens - 5) 74 | ] 75 | self.special_tokens = { 76 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 77 | } 78 | self.sp_model = tiktoken.Encoding( 79 | name=Path(model_path).name, 80 | pat_str=self.pat_str, 81 | mergeable_ranks=mergeable_ranks, 82 | special_tokens=self.special_tokens, 83 | ) 84 | logger.info(f"Reloaded tiktoken model from {model_path}") 85 | 86 | self.n_words: int = self.sp_model.n_vocab 87 | # BOS / EOS token IDs 88 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"] 89 | self.eos_id: int = self.special_tokens["<|end_of_text|>"] 90 | self.pad_id: int = -1 91 | self.stop_tokens = { 92 | self.special_tokens["<|end_of_text|>"], 93 | self.special_tokens["<|eot_id|>"], 94 | } 95 | logger.info( 96 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 97 | ) 98 | 99 | self.v_token_id = 10955 100 | self.q_token_id = 14924 101 | self.a_token_id = 16533 102 | self.c_token_id = 5116 103 | self.nl_id = 627 104 | 105 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 106 | assert type(s) is str 107 | t = self.sp_model.encode(s) 108 | if bos: 109 | t = [self.bos_id] + t 110 | if eos: 111 | t = t + [self.eos_id] 112 | return t 113 | 114 | def encode_vqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 115 | i_text = "Instruction: Choose the correct answer based on the video and question.\n" 116 | q_text = text['q_text'] 117 | o_text = text['o_text'] 118 | a_text = text['a_text'] 119 | 120 | s1 = i_text + 'Video:' 121 | t1 = [self.bos_id] + self.sp_model.encode(s1) 122 | video_start = len(t1) 123 | 124 | s2 = q_text + o_text + a_text 125 | 126 | if split == 'train': 127 | s2 = s2 + answer_mapping[answer] 128 | t2 = self.sp_model.encode(s2) + [self.eos_id] 129 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 130 | prefix_index = t[0].index(self.a_token_id) + 5 131 | else: 132 | t = [] 133 | for k, v in answer_mapping.items(): 134 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id] 135 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 136 | prefix_index = t[answer].index(self.a_token_id) + 5 137 | return t, prefix_index, video_start 138 | 139 | def encode_vaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 140 | i_text = "Instruction: Predict the question based on the video and answer.\n" 141 | q_text = text['q_text'].strip() 142 | o_text = text['o_text'] 143 | a_text = text['a_text'] 144 | 145 | s1 = i_text + 'Video:' 146 | t1 = [self.bos_id] + self.sp_model.encode(s1) 147 | video_start = len(t1) 148 | 149 | s2 = o_text + a_text 150 | 151 | if split == 'train': 152 | s2 = s2 + answer_mapping[answer] + "\n" + q_text 153 | t2 = self.sp_model.encode(s2) + [self.eos_id] 154 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 155 | prefix_index = t[0].index(self.q_token_id) + 1 156 | else: 157 | t = [] 158 | for k, v in answer_mapping.items(): 159 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id] 160 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 161 | prefix_index = t[answer].index(self.q_token_id) + 1 162 | return t, prefix_index, video_start 163 | 164 | 165 | def encode_qav(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 166 | i_text = "Instruction: Predict the video based on the question and answer.\n" 167 | q_text = text['q_text'] 168 | o_text = text['o_text'] 169 | a_text = text['a_text'] 170 | 171 | s1 = i_text + q_text + o_text + a_text 172 | 173 | if split == 'train': 174 | s1 = s1 + answer_mapping[answer] + "\n" + "Video:" 175 | t1 = [self.bos_id] + self.sp_model.encode(s1) 176 | t = [t1 + [-2 for _ in range(max_feats)] + [self.eos_id]] 177 | prefix_index = t[0].index(self.v_token_id) + 1 178 | else: 179 | t = [] 180 | for k, v in answer_mapping.items(): 181 | t1 = [self.bos_id] + self.sp_model.encode(s1 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id] 182 | t.append(t1) 183 | prefix_index = t[answer].index(self.v_token_id) + 1 184 | return t, prefix_index 185 | 186 | def decode(self, t: List[int]) -> str: 187 | return self.sp_model.decode(t) 188 | 189 | def encode_dvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 190 | i_text = "Instruction: Predict the answer based on the dialogue, video and question.\n" 191 | q_text = text['q_text'] 192 | o_text = text['o_text'] 193 | a_text = text['a_text'] 194 | d_text = text['d_text'] 195 | 196 | s1 = i_text + 'Video:' 197 | t1 = [self.bos_id] + self.sp_model.encode(s1) 198 | video_start = len(t1) 199 | 200 | prefix_i = video_start + max_feats + 1 201 | d1 = self.sp_model.encode(d_text) 202 | prefix_main = prefix_i + len(d1) 203 | 204 | s2 = q_text + o_text + a_text 205 | 206 | if split == 'train': 207 | s2 = s2 + answer_mapping[answer] 208 | t2 = self.sp_model.encode(s2) + [self.eos_id] 209 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2] 210 | else: 211 | t = [] 212 | for k, v in answer_mapping.items(): 213 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id] 214 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2) 215 | 216 | prefix_index = len(t[0]) - 4 217 | 218 | return t, prefix_index, video_start, prefix_i, prefix_main 219 | 220 | def encode_dvaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 221 | i_text = "Instruction: Predict the question based on the dialogue, video and answer.\n" 222 | q_text = text['q_text'].strip() 223 | o_text = text['o_text'] 224 | a_text = text['a_text'] 225 | d_text = text['d_text'] 226 | 227 | s1 = i_text + 'Video:' 228 | t1 = [self.bos_id] + self.sp_model.encode(s1) 229 | video_start = len(t1) 230 | 231 | prefix_i = video_start + max_feats + 1 232 | d1 = self.sp_model.encode(d_text) 233 | prefix_main = prefix_i + len(d1) 234 | 235 | s2 = o_text + a_text 236 | 237 | if split == 'train': 238 | s2 = s2 + answer_mapping[answer] + "\n" + q_text 239 | t2 = self.sp_model.encode(s2) + [self.eos_id] 240 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2] 241 | else: 242 | t = [] 243 | for k, v in answer_mapping.items(): 244 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id] 245 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2) 246 | 247 | prefix_index = t[0].index(self.q_token_id) + 1 248 | 249 | return t, prefix_index, video_start, prefix_i, prefix_main 250 | 251 | def encode_dqav(self, text=None, max_feats=10, max_seq_len=128, split='train', answer_mapping=None, answer=None) -> List[int]: 252 | i_text = "Instruction: Predict the video based on the dialogue, question and answer.\n" 253 | d_text = text['d_text'] 254 | q_text = text['q_text'] 255 | o_text = text['o_text'] 256 | a_text = text['a_text'] 257 | s1, s2, s3 = i_text, d_text, q_text + o_text + a_text 258 | 259 | t1 = [self.bos_id] + self.sp_model.encode(s1) 260 | t2 = self.sp_model.encode(s2) 261 | prefix_i, prefix_q = len(t1), len(t1) + len(t2) 262 | 263 | if split == 'train': 264 | t3 = self.sp_model.encode(s3 + answer_mapping[answer] + "\n" + "Video:") 265 | t = [t1 + t2 + t3 + [-2 for _ in range(max_feats)] + [self.eos_id]] 266 | else: 267 | t = [] 268 | for k, v in answer_mapping.items(): 269 | t3 = self.sp_model.encode(s3 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id] 270 | t.append(t1 + t2 + t3) 271 | 272 | prefix_index = len(t[0]) - max_feats - 1 273 | 274 | return t, prefix_index, prefix_i, prefix_q 275 | 276 | 277 | 278 | def encode_videocap(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 279 | i_text = "Instruction: Generate a dense description for the video.\n" 280 | # q_text = text['q_text'].strip() 281 | # o_text = text['o_text'] 282 | # a_text = text['a_text'] 283 | 284 | s1 = i_text + 'Video:' 285 | t1 = [self.bos_id] + self.sp_model.encode(s1) 286 | video_start = len(t1) 287 | 288 | s2 = text['c_text'] 289 | 290 | if split == 'train': 291 | # s2 = s2 + answer_mapping[answer] + "\n" + q_text 292 | s2 = "\n"+s2 293 | t2 = self.sp_model.encode(s2) + [self.eos_id] 294 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 295 | prefix_index = t[0].index(self.c_token_id) + 1 296 | else: 297 | t = [] 298 | for k, v in answer_mapping.items(): 299 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id] 300 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 301 | prefix_index = t[answer].index(self.c_token_id) + 1 302 | return t, prefix_index, video_start 303 | 304 | 305 | 306 | def encode_openvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]: 307 | i_text = "Instruction: Predict the answer based on the video and question.\n" 308 | q_text = text['q_text'] 309 | oa_text = text['oa_text'] 310 | 311 | s1 = i_text + 'Video:' 312 | t1 = [self.bos_id] + self.sp_model.encode(s1) 313 | video_start = len(t1) 314 | 315 | if split == 'train': 316 | s2 = q_text + oa_text 317 | t2 = self.sp_model.encode(s2) + [self.eos_id] 318 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2] 319 | prefix_index = t[0].index(self.a_token_id)+1 320 | else: 321 | t = [] 322 | for open_option in text['open_options']: 323 | s2 = q_text + open_option 324 | t2 = self.sp_model.encode(s2) + [self.eos_id] 325 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2) 326 | prefix_index = t[0].index(self.a_token_id)+1 327 | return t, prefix_index, video_start -------------------------------------------------------------------------------- /llama_vqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from llama import ModelArgs, Tokenizer, Transformer 4 | from llama import Transformer_llama3,Tokenizer_llama3,ModelArgs_llama3 5 | from pathlib import Path 6 | 7 | def LLaMA_VQA(args, **kwargs): 8 | with open(f'{args.llama_model_path}{args.model}/params.json', "r") as f: 9 | params = json.loads(f.read()) 10 | 11 | if args.llama3: 12 | tokenizer = Tokenizer_llama3(model_path=f'{args.llama_model_path}/tokenizer.model') 13 | else: 14 | tokenizer = Tokenizer(model_path=f'{args.llama_model_path}/tokenizer.model') 15 | print(f"Using model: {args.model}") 16 | 17 | 18 | checkpoints = (Path(args.llama_model_path) / args.model).glob("*.pth") 19 | checkpoints = sorted(checkpoints) 20 | 21 | loaded = [] 22 | for x in checkpoints: 23 | print("loading from", x) 24 | loaded.append(torch.load(x, map_location="cpu")) 25 | 26 | if len(loaded) == 1: 27 | full_state_dict = loaded[0] 28 | else: 29 | full_state_dict = {} 30 | split_dims = {} 31 | 32 | def add_weight_with_split_dim(name, dim): 33 | if dim < 0: # bcast without split 34 | full_state_dict[name] = loaded[0][name].clone() 35 | else: 36 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim) 37 | for x in loaded: 38 | del x[name] 39 | split_dims[name] = dim 40 | 41 | add_weight_with_split_dim("tok_embeddings.weight", 1) 42 | add_weight_with_split_dim("norm.weight", -1) 43 | add_weight_with_split_dim("output.weight", 0) 44 | for i in range(params["n_layers"]): 45 | print("gathering layer %d of %d" % (i, params["n_layers"])) 46 | layer_prefix = f"layers.{i}." 47 | bcast_names = ["attention_norm.weight", "ffn_norm.weight"] 48 | column_parallel_names = ["attention.wq.weight", "attention.wk.weight", "attention.wv.weight", "feed_forward.w1.weight", "feed_forward.w3.weight"] 49 | row_parallel_names = ["attention.wo.weight", "feed_forward.w2.weight"] 50 | for key in bcast_names: 51 | add_weight_with_split_dim(layer_prefix + key, -1) 52 | for key in column_parallel_names: 53 | add_weight_with_split_dim(layer_prefix + key, 0) 54 | for key in row_parallel_names: 55 | add_weight_with_split_dim(layer_prefix + key, 1) 56 | 57 | if args.llama3: 58 | model_args: ModelArgs = ModelArgs_llama3(max_seq_len=args.max_seq_len, max_batch_size=32, adapter_len=args.adapter_len, adapter_layer=args.adapter_layer, **params) 59 | else: 60 | model_args: ModelArgs = ModelArgs(max_seq_len=args.max_seq_len, max_batch_size=32, adapter_len=args.adapter_len, adapter_layer=args.adapter_layer, **params) 61 | 62 | 63 | model_args.vocab_size = tokenizer.n_words 64 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 65 | if args.llama3: 66 | model_llama_vqa = Transformer_llama3(model_args, args) 67 | else: 68 | model_llama_vqa = Transformer(model_args, args) 69 | torch.set_default_tensor_type(torch.FloatTensor) 70 | missing_keys, unexpected_keys = model_llama_vqa.load_state_dict(full_state_dict, strict=False) 71 | 72 | for name, param in model_llama_vqa.named_parameters(): 73 | 74 | if ('gate' in name) or ('adapter' in name) or ('temporal_emb' in name) or ('visual_proj' in name) or ('query_tokens' in name): 75 | param.requires_grad = True 76 | param.data = param.data.float() 77 | 78 | else: 79 | # print(name) 80 | # print(param.data.dtype) 81 | if args.llama2: 82 | param.data = param.data.bfloat16() 83 | param.requires_grad = False 84 | 85 | 86 | return model_llama_vqa -------------------------------------------------------------------------------- /pics/topa_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/pics/topa_framework.jpg -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/pretrained/.gitkeep -------------------------------------------------------------------------------- /scripts/baseline/llama2_13b.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \ 2 | --max_seq_len 128 --batch_size 6 --epochs 10 --warmup_epochs 4 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \ 3 | --blr 1e-2 --weight_decay 0.1 --accum_iter 8 --output_dir vqa_checkpoint/baseline/llama2_13b_nextqa_1e2 --adapter_len 50 \ 4 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \ 5 | 6 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \ 7 | --max_seq_len 128 --batch_size 6 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset star \ 8 | --blr 2e-2 --weight_decay 0.1 --accum_iter 8 --output_dir vqa_checkpoint/baseline/llama2_13b_star_2e2 --adapter_len 50 \ 9 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40\ 10 | 11 | 12 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \ 13 | --max_seq_len 128 --batch_size 6 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \ 14 | --blr 2e-2 --weight_decay 0.1 --accum_iter 8 --output_dir vqa_checkpoint/baseline/llama2_13b_vtqa_2e2 --adapter_len 50 \ 15 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40\ -------------------------------------------------------------------------------- /scripts/baseline/llama2_7b.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 2 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \ 3 | --blr 1e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama2_7b_nextqa_2e2_acc4 --adapter_len 50 \ 4 | --llama2 --llama_model_path ./pretrained/llama2/ \ 5 | 6 | 7 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 8 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset star \ 9 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama2_7b_star_1e2_acc4 --adapter_len 50 \ 10 | --llama2 --llama_model_path ./pretrained/llama2/ \ 11 | 12 | 13 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 14 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \ 15 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama2_7b_star_1e2_acc4 --adapter_len 50 \ 16 | --llama2 --llama_model_path ./pretrained/llama2/ \ 17 | 18 | 19 | -------------------------------------------------------------------------------- /scripts/baseline/llama3_8b.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 2 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \ 3 | --blr 1e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama3_tvqa_1e2_acc4 --adapter_len 50 \ 4 | --llama3 --llama_model_path ./pretrained/llama3/ \ 5 | 6 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 7 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \ 8 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama3_nextqa_2e2_acc4 --adapter_len 50 \ 9 | --llama3 --llama_model_path ./pretrained/llama3/ \ 10 | 11 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 12 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset star \ 13 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama3_star_2e2_acc4 --adapter_len 50 \ 14 | --llama3 --llama_model_path ./pretrained/llama3/ \ 15 | 16 | -------------------------------------------------------------------------------- /scripts/eval/zeroshot_eval_egos.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 1 train.py --model 7B \ 2 | --max_seq_len 128 --batch_size 1 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset egos \ 3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 4 | --output_dir results/llama2_7B/egos \ 5 | --llama2 --llama_model_path ./pretrained/llama2/ \ 6 | --memory \ 7 | --test 8 | 9 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 1 train.py --model 13B --adapter_layer 40 \ 10 | # --max_seq_len 128 --batch_size 1 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset egos \ 11 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \ 12 | # --output_dir results/llama2_13B/egos \ 13 | # --llama2 --llama_model_path ./pretrained/llama2/ \ 14 | # --memory \ 15 | # --test 16 | 17 | 18 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 1 train.py --model 8B \ 19 | # --max_seq_len 150 --batch_size 1 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset egos \ 20 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 21 | # --output_dir results/llama3_8B/egos \ 22 | # --llama3 --llama_model_path ./pretrained/llama3/ \ 23 | # --memory \ 24 | # --test 25 | 26 | -------------------------------------------------------------------------------- /scripts/eval/zeroshot_eval_nextqa.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 2 | --max_seq_len 128 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset nextqa \ 3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 4 | --output_dir results/llama2_7B/nextqa \ 5 | --llama2 --llama_model_path ./pretrained/llama2/ \ 6 | --memory \ 7 | 8 | 9 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B --adapter_layer 40 \ 10 | # --max_seq_len 128 --batch_size 10 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset nextqa \ 11 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \ 12 | # --output_dir results/llama2_13B/nextqa \ 13 | # --llama2 --llama_model_path ./pretrained/llama2/ \ 14 | # --memory \ 15 | 16 | 17 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 18 | # --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset nextqa \ 19 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 20 | # --output_dir results/llama3_8B/nextqa \ 21 | # --llama3 --llama_model_path ./pretrained/llama3/ \ 22 | # --memory \ 23 | 24 | 25 | -------------------------------------------------------------------------------- /scripts/eval/zeroshot_eval_star.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 2 | --max_seq_len 128 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset star \ 3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 4 | --output_dir results/llama2_7B/star \ 5 | --llama2 --llama_model_path ./pretrained/llama2/ \ 6 | --memory \ 7 | 8 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B --adapter_layer 40 \ 9 | # --max_seq_len 128 --batch_size 10 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset star \ 10 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \ 11 | # --output_dir results/llama2_13B/star \ 12 | # --llama2 --llama_model_path ./pretrained/llama2/ \ 13 | # --memory \ 14 | 15 | 16 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 17 | # --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset star \ 18 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 19 | # --output_dir results/llama3_8B/star \ 20 | # --llama3 --llama_model_path ./pretrained/llama3/ \ 21 | # --memory \ 22 | 23 | 24 | -------------------------------------------------------------------------------- /scripts/eval/zeroshot_eval_tvqa.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 2 | --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset tvqa \ 3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 4 | --output_dir results/llama2_7B/tvqa \ 5 | --llama2 --llama_model_path ./pretrained/llama2/ \ 6 | --memory \ 7 | 8 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B --adapter_layer 40 \ 9 | # --max_seq_len 150 --batch_size 10 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset tvqa \ 10 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \ 11 | # --output_dir results/llama2_13B/tvqa \ 12 | # --llama2 --llama_model_path ./pretrained/llama2/ \ 13 | # --memory \ 14 | 15 | 16 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 17 | # --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset tvqa \ 18 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \ 19 | # --output_dir results/llama3_8B/tvqa \ 20 | # --llama3 --llama_model_path ./pretrained/llama3/ \ 21 | # --memory \ 22 | 23 | 24 | -------------------------------------------------------------------------------- /scripts/finetune/LLama2_13b_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \ 2 | --max_seq_len 128 --batch_size 6 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \ 3 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs6/checkpoint_18.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnip_llama2_13b_finetune_nextqa_5e3_acc16 --accum_iter 16 --adapter_len 50 --finetune \ 4 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \ 5 | 6 | 7 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \ 8 | --max_seq_len 128 --batch_size 6 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset star \ 9 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs6/checkpoint_18.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnip_llama2_13b_finetune_star_5e3_acc16 --accum_iter 16 --adapter_len 50 --finetune \ 10 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \ 11 | 12 | 13 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \ 14 | --max_seq_len 128 --batch_size 6 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \ 15 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs6/checkpoint_18.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnip_llama2_13b_finetune_tvqa_5e3_acc16 --accum_iter 16 --adapter_len 50 --finetune \ 16 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \ 17 | -------------------------------------------------------------------------------- /scripts/finetune/LLama2_7b_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 2 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \ 3 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnips_7b_finetune_nextqa_1e2_acc4 --accum_iter 4 --adapter_len 50 --finetune \ 4 | --llama2 --llama_model_path ./pretrained/llama2/ 5 | 6 | 7 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 8 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset star \ 9 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnips_7b_finetune_realstar_1e2_acc4 --accum_iter 4 --adapter_len 50 --finetune \ 10 | --llama2 --llama_model_path ./pretrained/llama2/ 11 | 12 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \ 13 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \ 14 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnips_7b_finetune_tvqa_5e3_acc4 --accum_iter 4 --adapter_len 50 --finetune \ 15 | --llama2 --llama_model_path ./pretrained/llama2/ -------------------------------------------------------------------------------- /scripts/finetune/LLama3_finetune.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 3 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \ 4 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/llama3_finetune_nextqa_1e2_acc4 --accum_iter 4 --adapter_len 50 --finetune \ 5 | --llama3 --llama_model_path ./pretrained/llama3/ \ 6 | 7 | 8 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 9 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset star \ 10 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/llama3_finetune_star_1e2_acc4_ep10 --accum_iter 4 --adapter_len 50 --finetune \ 11 | --llama3 --llama_model_path ./pretrained/llama3/ \ 12 | 13 | 14 | 15 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \ 16 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \ 17 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/llama3_finetune_tvqa_5e3_acc4 --accum_iter 4 --adapter_len 50 --finetune \ 18 | --llama3 --llama_model_path ./pretrained/llama3/ \ 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /scripts/pretrain/llama2_13b.sh: -------------------------------------------------------------------------------- 1 | 2 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number 3 | torchrun --rdzv_endpoint 127.0.0.1:${randport} --nproc_per_node 4 train.py --model 13B \ 4 | --max_seq_len 150 --batch_size 4 --epochs 20 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset textvid \ 5 | --blr 8e-3 --weight_decay 0.1 --output_dir vqa_checkpoint/checkpoint_pretrain/llama2_13b_test --accum_iter 8 --textvid --variance 0.0 --memory --video_caption --vaq --openvqa --answer_balance --adapter_len 50 \ 6 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \ 7 | 8 | -------------------------------------------------------------------------------- /scripts/pretrain/llama2_7b.sh: -------------------------------------------------------------------------------- 1 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number 2 | torchrun --rdzv_endpoint 127.0.0.1:${randport} --nproc_per_node 4 train.py --model 7B \ 3 | --max_seq_len 150 --batch_size 18 --epochs 20 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset textvid \ 4 | --blr 5e-3 --weight_decay 0.1 --output_dir vqa_checkpoint/checkpoint_pretrain/llama2_test --accum_iter 4 --textvid --variance 0.0 --video_caption --vaq --openvqa --answer_balance --adapter_len 50 --memory \ 5 | --llama2 --llama_model_path ./pretrained/llama2/ \ 6 | -------------------------------------------------------------------------------- /scripts/pretrain/llama3_8b.sh: -------------------------------------------------------------------------------- 1 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number 2 | torchrun --rdzv_endpoint 127.0.0.1:${randport} --nproc_per_node 4 train.py --model 8B \ 3 | --max_seq_len 150 --batch_size 14 --epochs 20 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset textvid \ 4 | --blr 5e-3 --weight_decay 0.1 --output_dir ./vqa_checkpoint/checkpoint_pretrain/llama3_test --accum_iter 8 --textvid --variance 0.0 --memory --video_caption --vaq --openvqa --answer_balance --adapter_len 50 \ 5 | --llama3 --llama_model_path ./pretrained/llama3/ \ 6 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html 2 | pip install fairscale 3 | pip install fire 4 | pip install sentencepiece 5 | pip install transformers 6 | pip install timm 7 | pip install pandas 8 | pip install setuptools==59.5.0 9 | pip install pysrt -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import json 5 | import time 6 | import numpy as np 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | import timm 13 | import timm.optim.optim_factory as optim_factory 14 | 15 | import util.misc as misc 16 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 17 | from engine import train_one_epoch, val_one_epoch,test_one_epoch 18 | from llama import Tokenizer, Tokenizer_llama3 19 | from llama_vqa import LLaMA_VQA 20 | from dataloader import load_data, load_data_instruct 21 | from torch.utils.data import DataLoader, ConcatDataset 22 | 23 | def save_arguments(args, filepath): 24 | with open(filepath, 'w') as file: 25 | json.dump(vars(args), file) 26 | 27 | def load_arguments(filepath): 28 | with open(filepath, 'r') as file: 29 | args_dict = json.load(file) 30 | return args_dict 31 | 32 | # Optionally, repopulate argparse.ArgumentParser with these arguments 33 | def repopulate_arguments(args_dict): 34 | parser = argparse.ArgumentParser(description="Example script") 35 | for key, value in args_dict.items(): 36 | parser.add_argument(f'--{key}', type=type(value),default=value) 37 | return parser.parse_args([]) 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 41 | parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 42 | parser.add_argument('--epochs', default=400, type=int) 43 | parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 44 | 45 | # Model parameters 46 | parser.add_argument('--llama_model_path', default='./pretrained/llama/', type=str, help='path of llama model') 47 | parser.add_argument('--model', default='llama7B_adapter', type=str, metavar='MODEL', help='Name of model to train') 48 | parser.add_argument('--adapter_layer', type=int, default=32, metavar='LENGTH', help='the number of adapter layer') 49 | parser.add_argument('--adapter_len', type=int, default=10, metavar='LENGTH', help='the adapter length') 50 | parser.add_argument('--max_seq_len', type=int, default=512, metavar='LENGTH', help='the maximum sequence length') 51 | parser.add_argument('--max_feats', type=int, default=10, metavar='LENGTH', help='the maximum feature length') 52 | 53 | # Optimizer parameters 54 | parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)') 55 | parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') 56 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 57 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') 58 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') 59 | 60 | # Dataset parameters 61 | parser.add_argument('--dataset', default='nextqa', type=str, help='dataset') 62 | parser.add_argument('--output_dir', default='./output_dir', help='path where to save, empty for no saving') 63 | parser.add_argument('--device', default='cuda', help='device to use for training / testing') 64 | parser.add_argument('--seed', default=0, type=int) 65 | parser.add_argument('--resume', default='', help='resume from checkpoint') 66 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') 67 | parser.add_argument('--num_workers', default=2, type=int) 68 | parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 69 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 70 | parser.set_defaults(pin_mem=True) 71 | 72 | # distributed training parameters 73 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 74 | parser.add_argument('--local_rank', default=-1, type=int) 75 | parser.add_argument('--dist_on_itp', action='store_true') 76 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 77 | 78 | parser.add_argument('--vaq', action='store_true', help='vaq loss') 79 | parser.add_argument('--qav', action='store_true', help='qav loss') 80 | parser.add_argument('--bias', type=float, default=3., help='attention bias') 81 | parser.add_argument('--tau', type=float, default=100., help='tau') 82 | parser.add_argument('--sub', action='store_true', help='subtitles for VLEP and TVQA') 83 | parser.add_argument('--eval', action='store_true', help='eval') 84 | parser.add_argument('--test', action='store_true', help='test') 85 | parser.add_argument('--memory', action='store_true', help='meomory') 86 | parser.add_argument('--finetune', action='store_true', help='finetune') 87 | parser.add_argument('--data_ratio', type=float, default=1., help='tau') 88 | parser.add_argument('--textvid', action='store_true', help='virtual video training') 89 | parser.add_argument('--variance', type=float, default=0., help='variance') 90 | parser.add_argument('--evalall', action='store_true', help='evalall') 91 | parser.add_argument('--debug', action='store_true', help='debug') 92 | parser.add_argument('--onlyqa', action='store_true', help='onlyqa') 93 | parser.add_argument('--llama2', action='store_true', help='llama2') 94 | parser.add_argument('--llama3', action='store_true', help='llama3') 95 | parser.add_argument('--answer_balance', action='store_true', help='balance_abcde') 96 | parser.add_argument('--video_caption', action='store_true', help='video captioning training') 97 | parser.add_argument('--instruct', action='store_true', help='instruct') 98 | parser.add_argument('--openvqa', action='store_true', help='openvqa') 99 | parser.add_argument('--weight_captioning', type=float, default=1.0, help='weight_captioning') 100 | parser.add_argument('--webvid', action='store_true', help='webvidfituning') 101 | parser.add_argument('--openvqa_eval', action='store_true', help='logits for MCQA') 102 | parser.add_argument('--single_frame', action='store_true', help='single_frame') 103 | 104 | 105 | 106 | 107 | return parser 108 | 109 | 110 | def main(args): 111 | misc.init_distributed_mode(args) 112 | 113 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 114 | print("{}".format(args).replace(', ', ',\n')) 115 | 116 | device = torch.device(args.device) 117 | 118 | # fix the seed for reproducibility 119 | seed = args.seed + misc.get_rank() 120 | torch.manual_seed(seed) 121 | np.random.seed(seed) 122 | 123 | cudnn.benchmark = True 124 | if args.llama3: 125 | tokenizer = Tokenizer_llama3(model_path=f'{args.llama_model_path}./tokenizer.model') 126 | else: 127 | tokenizer = Tokenizer(model_path=f'{args.llama_model_path}./tokenizer.model') 128 | 129 | 130 | model = LLaMA_VQA(args) 131 | model.to(device) 132 | 133 | model_without_ddp = model 134 | # print("Model = %s" % str(model_without_ddp)) 135 | 136 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 137 | 138 | if args.lr is None: # only base_lr is specified 139 | args.lr = args.blr * eff_batch_size / 256 140 | 141 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 142 | print("actual lr: %.2e" % args.lr) 143 | 144 | print("accumulate grad iterations: %d" % args.accum_iter) 145 | print("effective batch size: %d" % eff_batch_size) 146 | 147 | if args.distributed: 148 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 149 | model_without_ddp = model.module 150 | 151 | # following timm: set wd as 0 for bias and norm layers 152 | 153 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) 154 | 155 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 156 | print(optimizer) 157 | loss_scaler = NativeScaler() 158 | best_acc = 0. 159 | 160 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 161 | 162 | 163 | print(f"Start training for {args.epochs} epochs") 164 | start_time = time.time() 165 | 166 | if args.eval: 167 | 168 | epoch=0 169 | if args.dataset == 'egos': 170 | args.batch_size=1 171 | args.max_seq_len = 600 172 | 173 | model.module.re_init_freqs(600) 174 | if args.test: 175 | data_loader_val = load_data(args, tokenizer, split='test') 176 | if args.distributed: 177 | data_loader_val.sampler.set_epoch(epoch) 178 | val_stats = test_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args) 179 | else: 180 | data_loader_val = load_data(args, tokenizer, split='val') 181 | if args.distributed: 182 | data_loader_val.sampler.set_epoch(epoch) 183 | val_stats = val_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args) 184 | log_stats = {**{f'val_{k}': v for k, v in val_stats.items()}} 185 | 186 | if args.output_dir and misc.is_main_process(): 187 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 188 | f.write(json.dumps(log_stats) + "\n") 189 | 190 | elif args.textvid: 191 | data_loader_vals = {} 192 | batch_size = args.batch_size 193 | max_seq_len = args.max_seq_len 194 | data_loader_train = load_data(args, tokenizer, split='train') 195 | 196 | eval_datasets = ['egos'] 197 | 198 | for dataset_name in eval_datasets: 199 | # for dataset_name in ['egos','tvqa']: 200 | args.dataset=dataset_name 201 | if dataset_name in ['egos']: 202 | args.batch_size=1 203 | args.max_seq_len = 600 204 | else: 205 | args.batch_size= batch_size 206 | args.max_seq_len = 200 207 | data_loader_vals[dataset_name] = load_data(args, tokenizer, split='val') 208 | 209 | for epoch in range(args.start_epoch, args.epochs): 210 | 211 | if args.distributed: 212 | data_loader_train.sampler.set_epoch(epoch) 213 | model.module.re_init_freqs(600) 214 | train_stats = train_one_epoch(model, data_loader_train, optimizer, epoch, loss_scaler, args=args) 215 | val_stats = {} 216 | for key,data_loader_val in data_loader_vals.items(): 217 | if args.distributed: 218 | data_loader_val.sampler.set_epoch(epoch) 219 | val_stats[key] = val_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args) 220 | 221 | if True: 222 | model_name = f"checkpoint_{epoch}" 223 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, name=model_name) 224 | 225 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch} 226 | 227 | if args.output_dir and misc.is_main_process(): 228 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 229 | f.write(json.dumps(log_stats) + "\n") 230 | for key, val_stat in val_stats.items(): 231 | log_stat = {'dataset:':key, **{f'val_{k}': v for k, v in val_stat.items()}} 232 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 233 | f.write(json.dumps(log_stat) + "\n") 234 | 235 | else: 236 | model.module.re_init_freqs(600) 237 | data_loader_train = load_data(args, tokenizer, split='train') 238 | data_loader_val = load_data(args, tokenizer, split='val') 239 | for epoch in range(args.start_epoch, args.epochs): 240 | 241 | if args.distributed: 242 | data_loader_train.sampler.set_epoch(epoch) 243 | data_loader_val.sampler.set_epoch(epoch) 244 | 245 | train_stats = train_one_epoch(model, data_loader_train, optimizer, epoch, loss_scaler, args=args) 246 | val_stats = val_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args) 247 | 248 | if args.output_dir and best_acc < val_stats['acc']: 249 | best_acc = val_stats['acc'] 250 | model_name = 'checkpoint_best' 251 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, name=model_name) 252 | if True: 253 | model_name = f"checkpoint_{epoch}" 254 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, name=model_name) 255 | 256 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, **{f'val_{k}': v for k, v in val_stats.items()}} 257 | 258 | if args.output_dir and misc.is_main_process(): 259 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 260 | f.write(json.dumps(log_stats) + "\n") 261 | 262 | total_time = time.time() - start_time 263 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 264 | print('Training time {}'.format(total_time_str)) 265 | 266 | 267 | if __name__ == '__main__': 268 | args = get_args_parser() 269 | args = args.parse_args() 270 | if args.output_dir: 271 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 272 | save_arguments(args, args.output_dir+'/args.json') 273 | 274 | main(args) 275 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch import inf 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 26 | """ 27 | 28 | def __init__(self, window_size=20, fmt=None): 29 | if fmt is None: 30 | fmt = "{median:.4f} ({global_avg:.4f})" 31 | self.deque = deque(maxlen=window_size) 32 | self.total = 0.0 33 | self.count = 0 34 | self.fmt = fmt 35 | 36 | def update(self, value, n=1): 37 | self.deque.append(value) 38 | self.count += n 39 | self.total += value * n 40 | 41 | def synchronize_between_processes(self): 42 | """ 43 | Warning: does not synchronize the deque! 44 | """ 45 | if not is_dist_avail_and_initialized(): 46 | return 47 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 48 | dist.barrier() 49 | dist.all_reduce(t) 50 | t = t.tolist() 51 | self.count = int(t[0]) 52 | self.total = t[1] 53 | 54 | @property 55 | def median(self): 56 | d = torch.tensor(list(self.deque)) 57 | return d.median().item() 58 | 59 | @property 60 | def avg(self): 61 | d = torch.tensor(list(self.deque), dtype=torch.float32) 62 | return d.mean().item() 63 | 64 | @property 65 | def global_avg(self): 66 | return self.total / self.count 67 | 68 | @property 69 | def max(self): 70 | return max(self.deque) 71 | 72 | @property 73 | def value(self): 74 | return self.deque[-1] 75 | 76 | def __str__(self): 77 | return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, n=1, **kwargs): 86 | for k, v in kwargs.items(): 87 | if v is None: 88 | continue 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v, n=n) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) 100 | 101 | def __str__(self): 102 | loss_str = [] 103 | for name, meter in self.meters.items(): 104 | loss_str.append("{}: {}".format(name, str(meter))) 105 | return self.delimiter.join(loss_str) 106 | 107 | def synchronize_between_processes(self): 108 | for meter in self.meters.values(): 109 | meter.synchronize_between_processes() 110 | 111 | def add_meter(self, name, meter): 112 | self.meters[name] = meter 113 | 114 | def log_every(self, iterable, print_freq, header=None): 115 | i = 0 116 | if not header: 117 | header = '' 118 | start_time = time.time() 119 | end = time.time() 120 | iter_time = SmoothedValue(fmt='{avg:.4f}') 121 | data_time = SmoothedValue(fmt='{avg:.4f}') 122 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 123 | log_msg = [header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}'] 124 | if torch.cuda.is_available(): 125 | log_msg.append('max mem: {memory:.0f}') 126 | log_msg = self.delimiter.join(log_msg) 127 | MB = 1024.0 * 1024.0 128 | for obj in iterable: 129 | data_time.update(time.time() - end) 130 | yield obj 131 | iter_time.update(time.time() - end) 132 | if i % print_freq == 0 or i == len(iterable) - 1: 133 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 134 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 135 | if torch.cuda.is_available(): 136 | print(log_msg.format(i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) 137 | else: 138 | print(log_msg.format(i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) 139 | i += 1 140 | end = time.time() 141 | total_time = time.time() - start_time 142 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 143 | print('{} Total time: {} ({:.4f} s / it)'.format(header, total_time_str, total_time / len(iterable))) 144 | 145 | 146 | def setup_for_distributed(is_master): 147 | """ 148 | This function disables printing when not in master process 149 | """ 150 | builtin_print = builtins.print 151 | 152 | def print(*args, **kwargs): 153 | force = kwargs.pop('force', False) 154 | force = force or (get_world_size() > 8) 155 | if is_master or force: 156 | now = datetime.datetime.now().time() 157 | builtin_print('[{}] '.format(now), end='') # print with time stamp 158 | builtin_print(*args, **kwargs) 159 | 160 | builtins.print = print 161 | 162 | 163 | def is_dist_avail_and_initialized(): 164 | if not dist.is_available(): 165 | return False 166 | if not dist.is_initialized(): 167 | return False 168 | return True 169 | 170 | 171 | def get_world_size(): 172 | if not is_dist_avail_and_initialized(): 173 | return 1 174 | return dist.get_world_size() 175 | 176 | 177 | def get_rank(): 178 | if not is_dist_avail_and_initialized(): 179 | return 0 180 | return dist.get_rank() 181 | 182 | 183 | def is_main_process(): 184 | return get_rank() == 0 185 | 186 | 187 | def save_on_master(*args, **kwargs): 188 | if is_main_process(): 189 | torch.save(*args, **kwargs) 190 | 191 | 192 | def init_distributed_mode(args): 193 | if args.dist_on_itp: 194 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 195 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 196 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 197 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 198 | os.environ['LOCAL_RANK'] = str(args.gpu) 199 | os.environ['RANK'] = str(args.rank) 200 | os.environ['WORLD_SIZE'] = str(args.world_size) 201 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 202 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 203 | args.rank = int(os.environ["RANK"]) 204 | args.world_size = int(os.environ['WORLD_SIZE']) 205 | args.gpu = int(os.environ['LOCAL_RANK']) 206 | elif 'SLURM_PROCID' in os.environ: 207 | args.rank = int(os.environ['SLURM_PROCID']) 208 | args.gpu = args.rank % torch.cuda.device_count() 209 | else: 210 | print('Not using distributed mode') 211 | setup_for_distributed(is_master=True) # hack 212 | args.distributed = False 213 | return 214 | 215 | args.distributed = True 216 | 217 | torch.cuda.set_device(args.gpu) 218 | args.dist_backend = 'nccl' 219 | print('| distributed init (rank {}): {}, gpu {}'.format(args.rank, args.dist_url, args.gpu), flush=True) 220 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) 221 | torch.distributed.barrier() 222 | 223 | setup_for_distributed(args.rank == 0) 224 | 225 | 226 | class NativeScalerWithGradNormCount: 227 | state_dict_key = "amp_scaler" 228 | 229 | def __init__(self): 230 | self._scaler = torch.cuda.amp.GradScaler() 231 | 232 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 233 | self._scaler.scale(loss).backward(create_graph=create_graph) 234 | if update_grad: 235 | if clip_grad is not None: 236 | assert parameters is not None 237 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 238 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 239 | else: 240 | self._scaler.unscale_(optimizer) 241 | norm = get_grad_norm_(parameters) 242 | self._scaler.step(optimizer) 243 | self._scaler.update() 244 | else: 245 | norm = None 246 | return norm 247 | 248 | def state_dict(self): 249 | return self._scaler.state_dict() 250 | 251 | def load_state_dict(self, state_dict): 252 | self._scaler.load_state_dict(state_dict) 253 | 254 | 255 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 256 | if isinstance(parameters, torch.Tensor): 257 | parameters = [parameters] 258 | parameters = [p for p in parameters if p.grad is not None] 259 | norm_type = float(norm_type) 260 | if len(parameters) == 0: 261 | return torch.tensor(0.) 262 | device = parameters[0].grad.device 263 | if norm_type == inf: 264 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 265 | else: 266 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 267 | return total_norm 268 | 269 | 270 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, name): 271 | output_dir = Path(args.output_dir) 272 | epoch_name = str(epoch) 273 | if loss_scaler is not None: 274 | checkpoint_paths = [output_dir / (f'{name}.pth')] 275 | 276 | unfrozen_model = {} 277 | for n, p in model_without_ddp.named_parameters(): 278 | if ('gate' in n) or ('adapter' in n) or ('temporal_emb' in n) or ('visual_proj' in n): 279 | unfrozen_model[n] = p 280 | 281 | for checkpoint_path in checkpoint_paths: 282 | to_save = { 283 | 'model': unfrozen_model, 284 | 'optimizer': optimizer.state_dict(), 285 | 'epoch': epoch, 286 | 'scaler': loss_scaler.state_dict(), 287 | 'args': args, 288 | } 289 | 290 | save_on_master(to_save, checkpoint_path) 291 | else: 292 | client_state = {'epoch': epoch} 293 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 294 | 295 | 296 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 297 | if args.resume: 298 | if args.resume.startswith('https'): 299 | checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) 300 | else: 301 | checkpoint = torch.load(args.resume, map_location='cpu') 302 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 303 | print("Resume checkpoint %s" % args.resume) 304 | if not args.finetune: 305 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 306 | optimizer.load_state_dict(checkpoint['optimizer']) 307 | args.start_epoch = checkpoint['epoch'] + 1 308 | if 'scaler' in checkpoint: 309 | loss_scaler.load_state_dict(checkpoint['scaler']) 310 | print("With optim & sched!") 311 | 312 | 313 | def all_reduce_mean(x): 314 | world_size = get_world_size() 315 | if world_size > 1: 316 | x_reduce = torch.tensor(x).cuda() 317 | dist.all_reduce(x_reduce) 318 | x_reduce /= world_size 319 | return x_reduce.item() 320 | else: 321 | return x 322 | 323 | 324 | def getCount(freq): 325 | count, total = freq[0], freq[1] 326 | return count / total if total != 0 else 0.0 327 | 328 | def log_qtype(data, eval, metric_logger, args): 329 | ep = 1e-10 330 | 331 | if args.dataset == 'nextqa': 332 | qtype2id= {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8} 333 | elif args.dataset == "star": 334 | qtype2id= {'In': 1, 'Seq': 2, 'Pre': 3, 'Feas': 4} 335 | else: 336 | return 337 | 338 | q_freq = {i : [0., 0.] for i in qtype2id.values()} 339 | q_freq[0] = [0., 0.] 340 | for i, v in enumerate(eval): 341 | qt = data['qtype'][i].item() 342 | q_freq[qt][0] += v.item() 343 | q_freq[qt][1] += 1 344 | q_freq[0][0] += v.item() 345 | q_freq[0][1] += 1 346 | 347 | if args.dataset == 'nextqa': 348 | metric_logger.update(n=(q_freq[1][1]+q_freq[2][1]+ ep), C=(q_freq[1][0]+q_freq[2][0]) / (q_freq[1][1]+q_freq[2][1]+ ep)) 349 | metric_logger.update(n=(q_freq[3][1]+q_freq[4][1]+ q_freq[5][1]+ ep), T=(q_freq[3][0]+q_freq[4][0]+q_freq[5][0]) / (q_freq[3][1]+q_freq[4][1]+ q_freq[5][1]+ ep)) 350 | metric_logger.update(n=(q_freq[6][1]+q_freq[7][1]+ q_freq[8][1]+ ep), D=(q_freq[6][0]+q_freq[7][0]+q_freq[8][0]) / (q_freq[6][1]+q_freq[7][1]+ q_freq[8][1]+ ep)) 351 | metric_logger.update(n=q_freq[0][1]+ep, Total=getCount(q_freq[0])) 352 | elif args.dataset == "star": 353 | metric_logger.update(n=q_freq[1][1]+ep, In=getCount(q_freq[1])) 354 | metric_logger.update(n=q_freq[2][1]+ep, Seq=getCount(q_freq[2])) 355 | metric_logger.update(n=q_freq[3][1]+ep, Pre=getCount(q_freq[3])) 356 | metric_logger.update(n=q_freq[4][1]+ep, Feas=getCount(q_freq[4])) 357 | metric_logger.update(n=q_freq[0][1]+ep, Total=getCount(q_freq[0])) 358 | -------------------------------------------------------------------------------- /vqa_checkpoint/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/vqa_checkpoint/.gitkeep --------------------------------------------------------------------------------