├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── .gitkeep
├── dataloader
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── dramaqa.py
│   ├── egoschema.py
│   ├── how2qa.py
│   ├── nextqa.py
│   ├── perception.py
│   ├── star.py
│   ├── textvid.py
│   ├── tvqa.py
│   ├── vlep.py
│   └── webvid.py
├── demos
│   ├── 13B_msrvtt_results.json
│   ├── 7B_msrvtt_results.json
│   ├── Eval_Cap_MSRVTT.ipynb
│   ├── Eval_Cap_VATEX.ipynb
│   ├── mvbench.ipynb
│   ├── upload_leaderboard_13B_zero_shot.json
│   ├── upload_leaderboard_7B_ZS.json
│   └── video_transforms.py
├── engine.py
├── llama
│   ├── __init__.py
│   ├── generation.py
│   ├── model.py
│   ├── model_llama3.py
│   ├── tokenizer.py
│   └── tokenizer_llama3.py
├── llama_vqa.py
├── pics
│   └── topa_framework.jpg
├── pretrained
│   └── .gitkeep
├── scripts
│   ├── baseline
│   │   ├── llama2_13b.sh
│   │   ├── llama2_7b.sh
│   │   └── llama3_8b.sh
│   ├── eval
│   │   ├── zeroshot_eval_egos.sh
│   │   ├── zeroshot_eval_nextqa.sh
│   │   ├── zeroshot_eval_star.sh
│   │   └── zeroshot_eval_tvqa.sh
│   ├── finetune
│   │   ├── LLama2_13b_finetune.sh
│   │   ├── LLama2_7b_finetune.sh
│   │   └── LLama3_finetune.sh
│   └── pretrain
│       ├── llama2_13b.sh
│       ├── llama2_7b.sh
│       └── llama3_8b.sh
├── setup.sh
├── train.py
├── util
│   ├── lr_sched.py
│   └── misc.py
└── vqa_checkpoint
    └── .gitkeep
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore
2 | data/*
3 | !data/.gitkeep
4 | pretrained/*
5 | !pretrained/.gitkeep
6 | vqa_checkpoint/*
7 | !vqa_checkpoint/.gitkeep
8 | /python
9 | results/*
10 |
11 | *ipynb_checkpoints
12 | *.pyc
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MLV Lab (Machine Learning and Vision Lab at Korea University)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TOPA: Extend Large Language Models for Video Understanding via Text-Only Pre-Alignment
2 |
3 | **NeurIPS 2024 Spotlight**
4 |
5 |
6 |
7 |
8 |
9 | ![TOPA framework](pics/topa_framework.jpg)
10 |
11 |
12 | ## Data Preparation:
13 | Prepare the data as follows.
14 |
15 | **TextVid**: Download TextVid at [TextVid only](https://drive.google.com/file/d/12xocihCDYocHVtsdzymii3BnTJmlh430/view?usp=sharing), or download TextVid together with the preprocessed features at [TextVid and features](https://drive.google.com/file/d/1hfMIlABeAl9D_qhG5EcLUVM2HcZArTUx/view?usp=sharing).
16 |
17 | **NExT-QA, STAR and TVQA**:
18 | The preprocessed features are available [here](https://github.com/mlvlab/Flipped-VQA).
19 |
20 | **EgoSchema**:
21 | Download raw videos from [EgoSchema](https://github.com/egoschema/EgoSchema). We provide preprocessed features [here](https://drive.google.com/file/d/1yCAw101BZvOtSntDToNX7TR8ErdnfSdk/view?usp=sharing).
22 |
23 | **MVBench**:
24 | Download raw videos from [Hugging Face](https://huggingface.co/datasets/OpenGVLab/MVBench).
25 |
26 | **MSRVTT**:
27 | Download raw videos from [MSRVTT](https://github.com/crux82/msr-vtt-it).
28 |
29 | ```
30 | ./data
31 | |─ nextqa
32 | | |─ train.csv
33 | | |─ val.csv
34 | | └─ clipvitl14.pth
35 | |─ star
36 | | :
37 | |─ tvqa
38 | | :
39 | └─ egos
40 | :
41 | ```
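A quick way to sanity-check the prepared data (a sketch, assuming `clipvitl14.pth` stores a dict that maps each video id to a `[num_frames, 768]` CLIP ViT-L/14 feature tensor, which is how `dataloader/nextqa.py` reads it):

```
import torch
import pandas as pd

# load the preprocessed NExT-QA annotations and frame features
train = pd.read_csv('./data/nextqa/train.csv')
features = torch.load('./data/nextqa/clipvitl14.pth')

vid = train['video'].values[0]
print(len(features), 'videos with features')
print(features[f'{vid}'].shape)  # per-frame CLIP features, e.g. [num_frames, 768]
```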
42 | ## Model Preparation:
43 | Prepare the model as follows.
44 |
45 | **LLMs**: Download the pretrained Llama models from [Llama2](https://github.com/meta-llama/llama) and [Llama3](https://github.com/meta-llama/llama3).
46 |
47 | **TOPA Checkpoints**: Download our [pretrained models](https://drive.google.com/file/d/1-Ce6LC-1TeKvUbg_BeCWzsps6XBf-dlG/view?usp=sharing).
48 | ```
49 | ./pretrained
50 | |─ llama2
51 | |  |─ 7B
52 | |  |  |─ consolidated.00.pth
53 | |  |  └─ params.json
54 | |  |─ 13B
55 | |  |  :
56 | |  |  :
57 | |  └─ tokenizer.model
58 | └─ llama3
59 |    |─ 8B
60 |    |  |─ consolidated.00.pth
61 |    |  └─ params.json
62 |    └─ tokenizer.model
63 |
64 | ./vqa_checkpoint
65 | └─ checkpoint_pretrain
66 | |─ llama2_7b
67 | |─ llama2_13b
68 | └─ llama3_8b
69 | ```
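Before training, it can help to verify that the expected files are in place (a minimal sketch based on the layout above; the checkpoint names under `./vqa_checkpoint/checkpoint_pretrain` depend on which TOPA archive you downloaded):

```
from pathlib import Path

# LLM weights and tokenizers expected by the layout above
expected = [
    'pretrained/llama2/7B/consolidated.00.pth',
    'pretrained/llama2/7B/params.json',
    'pretrained/llama2/tokenizer.model',
    'pretrained/llama3/8B/consolidated.00.pth',
    'pretrained/llama3/8B/params.json',
    'pretrained/llama3/tokenizer.model',
]
for p in expected:
    print('ok  ' if Path(p).exists() else 'MISSING  ', p)
```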
70 |
71 | ## Training & Evaluation
72 | ### Text-only Pre-alignment
73 | ```
74 | ./scripts/pretrain/llama2_7b.sh
75 | ```
76 | ### Zero-shot inference
77 | ```
78 | ./scripts/eval/zeroshot_eval_egos.sh
79 | ./scripts/eval/zeroshot_eval_nextqa.sh
80 | ./scripts/eval/zeroshot_eval_star.sh
81 | ./scripts/eval/zeroshot_eval_tvqa.sh
82 | ```
83 | ### Evaluate on MVBench
84 | [mvbench.ipynb](demos/mvbench.ipynb)
85 |
86 | ### Evaluate on video captioning benchmarks
87 | [MSRVTT.ipynb](demos/Eval_Cap_MSRVTT.ipynb)
88 |
89 | [VATEX.ipynb](demos/Eval_Cap_VATEX.ipynb)
90 |
91 | ## Acknowledgements
92 | This repo is built upon [Flipped-VQA](https://github.com/mlvlab/Flipped-VQA) and benefits from [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [DeCap](https://github.com/dhg-wei/DeCap), [MVBench](https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/MVBENCH.md), [Llama2](https://github.com/meta-llama/llama) and [Llama3](https://github.com/meta-llama/llama3).
93 |
94 |
95 | ## Citations
96 |
97 | ```
98 | @article{li2024topa,
99 | title={TOPA: Extend Large Language Models for Video Understanding via Text-Only Pre-Alignment},
100 | author={Li, Wei and Fan, Hehe and Wong, Yongkang and Kankanhalli, Mohan and Yang, Yi},
101 | journal={arXiv preprint arXiv:2405.13911},
102 | year={2024}
103 | }
104 | ```
105 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/data/.gitkeep
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util import misc
3 | from .nextqa import NextQA
4 | from .dramaqa import DramaQA
5 | from .star import STAR
6 | from .vlep import VLEP
7 | from .tvqa import TVQA
8 | from .textvid import TextVid
9 | from .egoschema import EgoSchema
10 | from .perception import PerceptionTest
11 | from .how2qa import How2qa
12 | from .webvid import Webvid
13 |
14 | from torch.utils.data import DataLoader, ConcatDataset
15 |
16 | dataset_mapping = {'nextqa': NextQA, 'star': STAR, 'dramaqa': DramaQA, 'vlep': VLEP, 'tvqa': TVQA,'textvid':TextVid,'egos':EgoSchema,'perc':PerceptionTest,'how2qa':How2qa,'webvid':Webvid}
17 | num_options_mapping = {'nextqa': 5, 'star': 4, 'dramaqa': 5, 'vlep': 2, 'tvqa': 5,'textvid': 5,'egos':5,'perc':3,'how2qa':4,'webvid':5}
18 |
19 | def load_data(args, tokenizer, split='train'):
20 | if split=='train' and args.textvid:
21 | args.num_options = num_options_mapping['textvid']
22 | dataset = dataset_mapping['textvid'](args=args, tokenizer=tokenizer, split=split)
23 | else:
24 | args.num_options = num_options_mapping[args.dataset]
25 | dataset = dataset_mapping[args.dataset](args=args, tokenizer=tokenizer, split=split)
26 |
27 | num_tasks = misc.get_world_size()
28 | global_rank = misc.get_rank()
29 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
30 | data_loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=batch_collate,
31 | pin_memory=args.pin_mem, drop_last=False)
32 |
33 | return data_loader
34 |
35 |
36 |
37 | def load_data_instruct(args, tokenizer, split='train'):
38 |
39 | instruct_datasets = ['nextqa','star','tvqa']
40 | data_loader_instruct=[]
41 | for dataset_name in instruct_datasets:
42 | args.dataset=dataset_name
43 | args.num_options = num_options_mapping[args.dataset]
44 | dataset = dataset_mapping[args.dataset](args=args, tokenizer=tokenizer, split=split)
45 |
46 | data_loader_instruct.append(dataset)
47 |
48 | dataset = ConcatDataset(data_loader_instruct)
49 |
50 | num_tasks = misc.get_world_size()
51 | global_rank = misc.get_rank()
52 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
53 | data_loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=batch_collate,
54 | pin_memory=args.pin_mem, drop_last=False)
55 |
56 | return data_loader
57 |
58 | def batch_collate(batch):  # collate per-sample dicts into batched tensors, grouped by task ('vqa', 'vaq', 'qav')
59 | q_index = None
60 | bs = len(batch)
61 | vid = [batch[i]["vid"] for i in range(bs)]
62 | video = torch.stack([batch[i]["video"] for i in range(bs)])
63 | video_len = torch.tensor([batch[i]["video_len"] for i in range(bs)], dtype=torch.long)
64 | text = [batch[i]["text"] for i in range(bs)]
65 | qid = [batch[i]["qid"] for i in range(bs)]
66 | qtype = torch.tensor([batch[i]['qtype'] for i in range(bs)])
67 |
68 | vqa_id = torch.stack([batch[i]['text_id']['vqa'] for i in range(bs)])
69 | vaq_id = torch.stack([batch[i]['text_id']['vaq'] for i in range(bs)])
70 | qav_id = torch.stack([batch[i]['text_id']['qav'] for i in range(bs)])
71 | text_id = {'vqa': vqa_id, 'vaq': vaq_id, 'qav': qav_id}
72 |
73 | vqa_label = torch.stack([batch[i]['label']['vqa'] for i in range(bs)])
74 | vaq_label = torch.stack([batch[i]['label']['vaq'] for i in range(bs)])
75 | qav_label = torch.stack([batch[i]['label']['qav'] for i in range(bs)])
76 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label}
77 |
78 | vqa_video_start = [batch[i]["video_start"]['vqa'] for i in range(bs)]
79 | vaq_video_start = [batch[i]["video_start"]['vaq'] for i in range(bs)]
80 | qav_video_start = [batch[i]["video_start"]['qav'] for i in range(bs)]
81 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_video_start}
82 | # q_index = [batch[i]["q_index"] for i in range(bs)]
83 |
84 | vqa_video_index = torch.stack([batch[i]["video_index"]['vqa'] for i in range(bs)])
85 | vaq_video_index = torch.stack([batch[i]["video_index"]['vaq'] for i in range(bs)])
86 | qav_video_index = torch.stack([batch[i]["video_index"]['qav'] for i in range(bs)])
87 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index}
88 |
89 | vqa_label_mask = torch.stack([batch[i]["label_mask"]['vqa'] for i in range(bs)])
90 | vaq_label_mask = torch.stack([batch[i]["label_mask"]['vaq'] for i in range(bs)])
91 | qav_label_mask = torch.stack([batch[i]["label_mask"]['qav'] for i in range(bs)])
92 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask}
93 |
94 | answer = torch.tensor([batch[i]["answer"] for i in range(bs)])
95 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
96 | "video_index": video_index, "label_mask": label_mask, "qid": qid, "answer": answer, "qtype": qtype,"q_index":q_index}
97 |
--------------------------------------------------------------------------------
/dataloader/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import copy
4 |
5 | class BaseDataset(Dataset):
6 | def __init__(self, args, tokenizer, split):
7 | self.args = args
8 | self.max_feats = args.max_feats
9 | self.tokenizer = tokenizer
10 | self.max_seq_len = args.max_seq_len
11 | self.split = split
12 |
13 | self.features_dim = 768
14 |
15 | def _get_padding_id(self, text_id):  # pad each token sequence with -1 (or truncate) to max_seq_len
16 | padding_text_id = torch.zeros((len(text_id), self.max_seq_len), dtype=torch.int64) - 1
17 |
18 | for i, tid in enumerate(text_id):
19 | # print(len(tid))
20 | padding = self.max_seq_len - len(tid)
21 | # print(padding)
22 | if padding >= 0:
23 | padding_text_id[i, :len(tid)] = tid
24 | else:
25 | padding_text_id[i] = tid[:self.max_seq_len]
26 | # raise Exception
27 | # print('max sequence length overflow')
28 | return padding_text_id
29 |
30 | def _get_text_token(self, text, answer):  # build token ids, labels and masks for the vqa, vaq and qav objectives
31 | vqa_id, vqa_prefix_index, vqa_video_start = self.tokenizer.encode_vqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
32 | vaq_id, vaq_prefix_index, vaq_video_start = self.tokenizer.encode_vaq(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
33 | qav_id, qav_prefix_index = self.tokenizer.encode_qav(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
34 |
35 | vqa_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vqa_id]
36 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id]
37 | qav_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in qav_id]
38 |
39 | vqa_padding_text_id = self._get_padding_id(vqa_id)
40 | vaq_padding_text_id = self._get_padding_id(vaq_id)
41 | qav_padding_text_id = self._get_padding_id(qav_id)
42 |
43 | # label
44 | vqa_label = copy.deepcopy(vqa_padding_text_id)
45 | vqa_label[:, :vqa_prefix_index] = -1
46 | vqa_label_mask = vqa_label.ge(0)
47 | vqa_label[~vqa_label_mask] = 0
48 | vqa_label_mask = vqa_label_mask.float()
49 |
50 | vaq_label = copy.deepcopy(vaq_padding_text_id)
51 | vaq_label[:, :vaq_prefix_index] = -1
52 | vaq_label_mask = vaq_label.ge(0)
53 | vaq_label[~vaq_label_mask] = 0
54 | vaq_label_mask = vaq_label_mask.float()
55 |
56 | qav_label = torch.ones_like(qav_padding_text_id) * -1
57 | qav_label[:, qav_prefix_index:qav_prefix_index+self.max_feats] = torch.arange(self.max_feats)
58 | qav_label_mask = torch.zeros_like(qav_padding_text_id)
59 | qav_label_mask[:, qav_prefix_index] = 1
60 | qav_label_mask = qav_label_mask.float()
61 |
62 | # text mask
63 | vqa_text_mask = vqa_padding_text_id.ge(0)
64 | vqa_padding_text_id[~vqa_text_mask] = 0
65 | vaq_text_mask = vaq_padding_text_id.ge(0)
66 | vaq_padding_text_id[~vaq_text_mask] = 0
67 | qav_text_mask = qav_padding_text_id.ge(0)
68 | qav_padding_text_id[~qav_text_mask] = 0
69 |
70 | # video index
71 | vqa_video_index = torch.arange(vqa_prefix_index, vqa_prefix_index + self.max_feats)
72 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats)
73 | qav_video_index = torch.arange(qav_prefix_index, qav_prefix_index + self.max_feats)
74 |
75 |
76 | text_id = {'vqa': vqa_padding_text_id, 'vaq': vaq_padding_text_id, 'qav': qav_padding_text_id}
77 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label}
78 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_prefix_index}
79 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index}
80 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask}
81 | return text_id, label, video_start, video_index, label_mask
82 |
83 |
84 | def _get_caption_token(self, text, answer):
85 |
86 | vaq_id, vaq_prefix_index, vaq_video_start = self.tokenizer.encode_videocap(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
87 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id]
88 |
89 | vaq_padding_text_id = self._get_padding_id(vaq_id)
90 |
91 | # label
92 |
93 |
94 | vaq_label = copy.deepcopy(vaq_padding_text_id)
95 | vaq_label[:, :vaq_prefix_index] = -1
96 | vaq_label_mask = vaq_label.ge(0)
97 | vaq_label[~vaq_label_mask] = 0
98 | vaq_label_mask = vaq_label_mask.float()
99 |
100 | # text mask
101 |
102 | vaq_text_mask = vaq_padding_text_id.ge(0)
103 | vaq_padding_text_id[~vaq_text_mask] = 0
104 |
105 | # video index
106 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats)
107 |
108 |
109 | text_id = {'vaq': vaq_padding_text_id}
110 | label = {'vaq': vaq_label}
111 | video_start = {'vaq': vaq_video_start}
112 | video_index = {'vaq': vaq_video_index}
113 | label_mask = {'vaq': vaq_label_mask}
114 | return text_id, label, video_start, video_index, label_mask
115 |
116 | def _get_openvqa_token(self, text, answer):
117 |
118 | vaq_id, vaq_prefix_index, vaq_video_start = self.tokenizer.encode_openvqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
119 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id]
120 |
121 | vaq_padding_text_id = self._get_padding_id(vaq_id)
122 |
123 | vaq_label = copy.deepcopy(vaq_padding_text_id)
124 | if type(vaq_prefix_index)==list:
125 | for i,index_ in enumerate(vaq_prefix_index):
126 | vaq_label[i, :index_] = -1
127 | else:
128 | vaq_label[:, :vaq_prefix_index] = -1
129 | vaq_label_mask = vaq_label.ge(0)
130 | vaq_label[~vaq_label_mask] = 0
131 | vaq_label_mask = vaq_label_mask.float()
132 |
133 | # text mask
134 |
135 | vaq_text_mask = vaq_padding_text_id.ge(0)
136 | vaq_padding_text_id[~vaq_text_mask] = 0
137 |
138 | # video index
139 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats)
140 |
141 |
142 | text_id = {'vaq': vaq_padding_text_id}
143 | label = {'vaq': vaq_label}
144 | video_start = {'vaq': vaq_video_start}
145 | video_index = {'vaq': vaq_video_index}
146 | label_mask = {'vaq': vaq_label_mask}
147 | return text_id, label, video_start, video_index, label_mask
--------------------------------------------------------------------------------
/dataloader/dramaqa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import json
4 |
5 | class DramaQA(BaseDataset):
6 | def __init__(self, args=None, tokenizer=None, split='train'):
7 | super().__init__(args, tokenizer, split)
8 | self.data = json.load(open(f'./data/dramaqa/AnotherMissOhQA_{split}_set.json', "r"))
9 | self.features = torch.load(f'./data/dramaqa/clipvitl14.pth')
10 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'}
11 | self.num_options = 5
12 | print(f"Num {split} data: {len(self.data)}")
13 |
14 | def _get_text(self, idx):
15 | question = self.data[idx]["que"].capitalize().strip()
16 | if question[-1] != "?":
17 | question = str(question) + "?"
18 |
19 | options = self.data[idx]['answers']
20 |
21 | q_text = f"Question: {question}\n"
22 | o_text = "Choices: \n"
23 | for i in range(self.num_options):
24 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
25 | a_text = "Answer: The answer is "
26 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options}
27 | return text
28 |
29 | def _get_video(self, video_id , idx):
30 |
31 | scene = True
32 | # Scene
33 | if video_id[-4:] == '0000':
34 | shots = self.data[idx]['shot_contained']
35 | start, end = shots[0], shots[1]
36 |
37 | for i in range(start, end+1):
38 | v_name = video_id[:-4] + f'{i:04}'
39 |
40 | if v_name not in self.features.keys():
41 | print(v_name, " Not in features")
42 | nxt_vid = torch.zeros(1, self.features_dim)
43 | else: nxt_vid = self.features[v_name].float()
44 |
45 | if i == start: video = nxt_vid
46 | else: video = torch.concat((video, nxt_vid), dim = 0)
47 | # Shot
48 | else:
49 | scene = False
50 | if video_id not in self.features.keys():
51 | print(video_id, "Not in features")
52 | video = torch.zeros(1, self.features_dim)
53 | else:
54 | video = self.features[video_id].float()
55 |
56 | if len(video) > self.max_feats:
57 | sampled = []
58 | for j in range(self.max_feats):
59 | sampled.append(video[(j * len(video)) // self.max_feats])
60 | video = torch.stack(sampled)
61 | video_len = self.max_feats
62 | elif len(video) < self.max_feats:
63 | video_len = len(video)
64 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0)
65 | else:
66 | video_len = self.max_feats
67 |
68 | return video, video_len, scene
69 |
70 | def __getitem__(self, idx):
71 | vid = self.data[idx]['vid']
72 | qtype = -1
73 | answer = self.data[idx]['correct_idx']
74 | text = self._get_text(idx)
75 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
76 | video, video_len, scene = self._get_video(f'{vid}', idx)
77 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
78 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype}
79 |
80 | def __len__(self):
81 | return len(self.data)
--------------------------------------------------------------------------------
/dataloader/egoschema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import pandas as pd
4 | import pickle
5 | import numpy as np
6 | import numpy
7 | import random
8 |
9 | class EgoSchema(BaseDataset):
10 | def __init__(self, args=None, tokenizer=None, split='train'):
11 | super().__init__(args, tokenizer, split)
12 | self.data = pd.read_csv(f'./data/egos/{split}.csv')
13 |
14 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'}
15 | self.num_options = 5
16 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
17 |
18 | self.features_path = './data/egos/features/'
19 |
20 | print(f"Num {split} data: {len(self.data)}")
21 |
22 | def _get_text(self, idx):
23 | question = self.data["question"].values[idx].capitalize().strip()
24 | if question[-1] != "?" and question[-1] != ".":
25 | question = str(question) + "?"
26 |
27 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)]
28 |
29 | q_text = f"Question: {question}\n"
30 | o_text = "Choices: \n"
31 | for i in range(self.num_options):
32 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
33 |
34 | a_text = "Answer: The correct choice is "
35 | open_options = [f"Answer: {option}" for option in options]
36 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options,'open_options':open_options}
37 | return text
38 |
39 |
40 | def _get_video(self, video):
41 | video=torch.from_numpy(video).float()
42 | video = video/video.norm(dim=-1,keepdim=True)
43 |
44 | if len(video) > self.max_feats:
45 | sampled = []
46 | for j in range(self.max_feats):
47 | sampled.append(video[(j * len(video)) // self.max_feats])
48 | video = torch.stack(sampled)
49 | video_len = self.max_feats
50 | elif len(video) < self.max_feats:
51 | video_len = len(video)
52 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0)
53 | else:
54 | video_len = self.max_feats
55 | if self.args.single_frame:
56 | video = video[::2]
57 | video=video.repeat_interleave(2,dim=0)
58 | return video, video_len
59 |
60 | def __getitem__(self, idx):
61 | while True:
62 | try:
63 | vid = self.data['uid'].values[idx]
64 |
65 | qtype = 1
66 | answer = self.data['answer'].values[idx]
67 |
68 | text = self._get_text(idx)
69 |
70 |
71 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
72 | q_index = len(self.tokenizer.sp_model.encode(text['q_text']))+1+video_start['vqa']+self.max_feats
73 | if self.args.openvqa_eval:
74 | text['oa_text'] = f"_"
75 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer)
76 | text_id.update(text_id_c)
77 | label.update(label_c)
78 | video_start.update(video_start_c)
79 | video_index.update(video_index_c)
80 | label_mask.update(label_mask_c)
81 |
82 | v_path = f'{self.features_path}{vid}.npy'
83 | with open(v_path,'rb') as f:
84 | video = numpy.load(f)
85 | video, video_len = self._get_video(video)
86 |
87 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start, "video_index": video_index, "label_mask": label_mask, "qid": vid, "answer": answer, "qtype": qtype,"q_index": q_index}
88 | except:
89 | print(idx)
90 | idx = np.random.randint(0, len(self)-1)
91 |
92 | def __len__(self):
93 | return len(self.data)
--------------------------------------------------------------------------------
/dataloader/how2qa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import pandas as pd
4 | import pickle
5 | import numpy as np
6 | import numpy
7 | class How2qa(BaseDataset):
8 | def __init__(self, args=None, tokenizer=None, split='train'):
9 | super().__init__(args, tokenizer, split)
10 | self.data = pd.read_csv(f'./data/how2qa/{split}.csv')
11 |
12 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'}
13 | self.num_options = 4
14 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
15 |
16 | self.features_path = '../../EgoSchema/benchmarking/FrozenBilm/features_how2qa/'
17 |
18 | print(f"Num {split} data: {len(self.data)}")
19 |
20 | def _get_text(self, idx):
21 | question = self.data["question"].values[idx].capitalize().strip()
22 | if question[-1] != "?":
23 | question = str(question) + "?"
24 |
25 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)]
26 |
27 | q_text = f"Question: {question}\n"
28 | o_text = "Choices: \n"
29 | for i in range(self.num_options):
30 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
31 |
32 | a_text = "Answer: The correct choice is "
33 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options}
34 | return text
35 |
36 | def _get_video(self, video):
37 | video=torch.from_numpy(video).float()
38 | video = video/video.norm(dim=-1,keepdim=True)
39 | # video = torch.zeros(1, self.features_dim)
40 | if len(video) > self.max_feats:
41 | sampled = []
42 | for j in range(self.max_feats):
43 | sampled.append(video[(j * len(video)) // self.max_feats])
44 | video = torch.stack(sampled)
45 | video_len = self.max_feats
46 | elif len(video) < self.max_feats:
47 | video_len = len(video)
48 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0)
49 | else:
50 | video_len = self.max_feats
51 |
52 | return video, video_len
53 |
54 | def __getitem__(self, idx):
55 | while True:
56 | try:
57 | # if True:
58 | vid = self.data['uid'].values[idx]
59 | video_id = self.data['video_id'].values[idx]
60 | qtype = 1
61 | answer = self.data['answer'].values[idx]
62 | text = self._get_text(idx)
63 | # print(text)
64 | # print(answer)
65 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
66 |
67 | v_path = f'{self.features_path}{video_id}_{vid}.npy'
68 | # v_path = f'{self.features_path}507441ee-3eb4-4dc6-bac2-26bec2b66380.npy'
69 | with open(v_path,'rb') as f:
70 | video = numpy.load(f)
71 | video, video_len = self._get_video(video)
72 |
73 | # print(label_mask)
74 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
75 | "video_index": video_index, "label_mask": label_mask, "qid": vid, "answer": answer, "qtype": qtype}
76 | except:
77 | print(idx)
78 | idx = np.random.randint(0, len(self)-1)
79 |
80 | def __len__(self):
81 | return len(self.data)
--------------------------------------------------------------------------------
/dataloader/nextqa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import pandas as pd
4 | import pickle
5 | import json
6 | import numpy
7 |
8 | class NextQA(BaseDataset):
9 | def __init__(self, args=None, tokenizer=None, split='train'):
10 | super().__init__(args, tokenizer, split)
11 | self.split =split
12 | self.data = pd.read_csv(f'./data/nextqa/{split}.csv')
13 |
14 |
15 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'}
16 | self.num_options = 5
17 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
18 |
19 |
20 | self.features = torch.load(f'./data/{args.dataset}/clipvitl14.pth')
21 | print(f"Num {split} data: {len(self.data)}")
22 | if self.split=='train':
23 | self.train_ratio = int(len(self.data)*self.args.data_ratio)
24 | else:
25 | self.train_ratio = int(len(self.data)*1)
26 |
27 | def _get_text(self, idx):
28 | question = self.data["question"].values[idx].capitalize().strip()
29 | if question[-1] != "?":
30 | question = str(question) + "?"
31 |
32 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)]
33 |
34 | q_text = f"Question: {question}\n"
35 | o_text = "Choices: \n"
36 | for i in range(self.num_options):
37 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
38 |
39 | a_text = "Answer: The correct choice is "
40 | open_options = [f"\nAnswer: {option}" for option in options]
41 |
42 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options,'open_options':open_options}
43 | return text
44 |
45 |
46 | def _get_video(self, video):
47 | video = video/video.norm(dim=-1,keepdim=True)
48 | # video = torch.zeros(1, self.features_dim)
49 | # video = video.repeat(,1)
50 |
51 | if len(video) > self.max_feats:
52 | sampled = []
53 | for j in range(self.max_feats):
54 | sampled.append(video[(j * len(video)) // self.max_feats])
55 | video = torch.stack(sampled)
56 | video_len = self.max_feats
57 | elif len(video) < self.max_feats:
58 | video_len = len(video)
59 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0)
60 | else:
61 | video_len = self.max_feats
62 | if self.args.single_frame:
63 | video = video[::2]
64 | video=video.repeat_interleave(2,dim=0)
65 | return video, video_len
66 |
67 | def __getitem__(self, idx):
68 | idx = idx%self.train_ratio
69 | vid = self.data['video'].values[idx]
70 | qid = self.data['qid'].values[idx]
71 | # print(vid)
72 | qtype = self.qtype_mapping[self.data['type'].values[idx]]
73 | answer = self.data['answer'].values[idx]
74 | text = self._get_text(idx)
75 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
76 |
77 |
78 | video = self.features[f'{vid}'].float()
79 | video, video_len = self._get_video(video)
80 |
81 | if self.args.openvqa_eval:
82 | text['oa_text'] = f"_"
83 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer)
84 | text_id.update(text_id_c)
85 | label.update(label_c)
86 | video_start.update(video_start_c)
87 | video_index.update(video_index_c)
88 | label_mask.update(label_mask_c)
89 |
90 | # print(label_mask)
91 | return {"vid": str(vid), "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
92 | "video_index": video_index, "label_mask": label_mask, "qid": qid, "answer": answer, "qtype": qtype}
93 |
94 | def __len__(self):
95 | if self.args.debug:
96 | return len(self.data[:2000])
97 | return len(self.data[:])
98 |
--------------------------------------------------------------------------------
/dataloader/perception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import pandas as pd
4 | import pickle
5 | import numpy as np
6 | import numpy
7 |
8 | class PerceptionTest(BaseDataset):
9 | def __init__(self, args=None, tokenizer=None, split='train'):
10 | super().__init__(args, tokenizer, split)
11 | self.data = pd.read_csv(f'./data/perception/{split}.csv')
12 |
13 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)'}
14 | self.num_options = 3
15 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
16 | self.features_path = '../../EgoSchema/benchmarking/FrozenBilm/features_perception_test_2fps/'
17 |
18 |
19 | print(f"Num {split} data: {len(self.data)}")
20 |
21 | def _get_text(self, idx):
22 | question = self.data["question"].values[idx].capitalize().strip()
23 | if question[-1] != "?":
24 | question = str(question) + "?"
25 |
26 | options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)]
27 |
28 | q_text = f"Question: {question}\n"
29 | o_text = "Choices: \n"
30 | for i in range(self.num_options):
31 |
32 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
33 |
34 | a_text = "Answer: The correct choice is "
35 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options}
36 | return text
37 |
38 | def _get_video(self, video):
39 | video=torch.from_numpy(video).float()
40 |
41 |
42 | video = video/video.norm(dim=-1,keepdim=True)
43 |
44 | if len(video) > self.max_feats:
45 | sampled = []
46 | for j in range(self.max_feats):
47 | sampled.append(video[(j * len(video)) // self.max_feats])
48 | video = torch.stack(sampled)
49 | video_len = self.max_feats
50 | elif len(video) < self.max_feats:
51 | video_len = len(video)
52 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0)
53 | else:
54 | video_len = self.max_feats
55 |
56 | return video, video_len
57 |
58 | def __getitem__(self, idx):
59 | while True:
60 | try:
61 | vid = self.data['uid'].values[idx]
62 | # print(vid)
63 | qtype = 1
64 | answer = self.data['answer'].values[idx]
65 | text = self._get_text(idx)
66 | # print(text)
67 | # print(answer)
68 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
69 |
70 | v_path = f'{self.features_path}{vid}.npy'
71 |
72 | with open(v_path,'rb') as f:
73 | video = numpy.load(f)
74 | video, video_len = self._get_video(video)
75 |
76 | # print(label_mask)
77 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
78 | "video_index": video_index, "label_mask": label_mask, "qid": vid, "answer": answer, "qtype": qtype}
79 | except:
80 | print(idx)
81 | idx = np.random.randint(0, len(self)-1)
82 |
83 | def __len__(self):
84 | return len(self.data)
--------------------------------------------------------------------------------
/dataloader/star.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import json
4 |
5 | class STAR(BaseDataset):
6 | def __init__(self, args=None, tokenizer=None, split='train'):
7 | super().__init__(args, tokenizer, split)
8 | self.data = json.load(open(f'./data/star/STAR_{split}.json', 'r'))
9 | self.features = torch.load(f'./data/star/clipvitl14.pth')
10 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)'}
11 | self.qtype_mapping = {'Interaction': 1, 'Sequence': 2, 'Prediction': 3, 'Feasibility': 4}
12 | self.num_options = 4
13 | self.split = split
14 | print(f"Num {split} data: {len(self.data)}")
15 | if split=='train':
16 | self.train_ratio = int(len(self.data)*self.args.data_ratio)
17 | else:
18 | self.train_ratio = int(len(self.data))
19 |
20 | def _get_text(self, idx):
21 | question = self.data[idx]["question"].capitalize().strip()
22 | if question[-1] != "?":
23 | question = str(question) + "?"
24 |
25 | options = {x['choice_id']: x['choice'] for x in self.data[idx]['choices']}
26 | options = [options[i] for i in range(self.num_options)]
27 | answer = options.index(self.data[idx]['answer'])
28 |
29 | q_text = f"Question: {question}\n"
30 | o_text = "Choices: \n"
31 | for i in range(self.num_options):
32 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
33 | a_text = "Answer: The answer is "
34 | open_options = [f"Answer: {option}" for option in options]
35 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options,'open_options':open_options}
36 | return text, answer
37 |
38 | def _get_video(self, video_id, start, end):
39 | if video_id not in self.features:
40 | print(video_id)
41 | video = torch.zeros(1, self.features_dim)
42 | else:
43 | video = self.features[video_id][start: end +1, :].float() # ts
44 | video = video/video.norm(dim=-1,keepdim=True)
45 | if len(video) > self.max_feats:
46 | sampled = []
47 | for j in range(self.max_feats):
48 | sampled.append(video[(j * len(video)) // self.max_feats])
49 | video = torch.stack(sampled)
50 | video_len = self.max_feats
51 | elif len(video) < self.max_feats:
52 | video_len = len(video)
53 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0)
54 | else:
55 | video_len = self.max_feats
56 |
57 | return video, video_len
58 |
59 | def __getitem__(self, idx):
60 | idx = idx%self.train_ratio
61 | vid = self.data[idx]['video_id']
62 | qtype = self.qtype_mapping[self.data[idx]['question_id'].split('_')[0]]
63 | text, answer = self._get_text(idx)
64 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
65 | start, end = round(self.data[idx]['start']), round(self.data[idx]['end'])
66 | video, video_len = self._get_video(f'{vid}', start, end)
67 |
68 | if self.args.openvqa_eval:
69 | text['oa_text'] = f"_"
70 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer)
71 | text_id.update(text_id_c)
72 | label.update(label_c)
73 | video_start.update(video_start_c)
74 | video_index.update(video_index_c)
75 | label_mask.update(label_mask_c)
76 |
77 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
78 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype}
79 |
80 |
81 | def __len__(self):
82 | return len(self.data)
--------------------------------------------------------------------------------
/dataloader/textvid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import pandas as pd
4 | import json
5 | import random
6 | import pickle
7 | import numpy as np
8 | import math
9 | import pickle
10 | def noise_injection(x, variance=0.001, modality_offset=None, uniform_noise=False, dont_norm=False):  # add scaled Gaussian noise to L2-normalized features
11 | if variance == 0.0:
12 | return x
13 | std = math.sqrt(variance)
14 | if not dont_norm:
15 | x = torch.nn.functional.normalize(x, dim=1)
16 |
17 |     x = x + (torch.randn(x.shape, device=x.device) * std)  # todo: by some conventions, multivariate noise should be divided by sqrt of dim
18 | noise = (torch.randn(x.shape, device=x.device) * std)
19 | noise/=noise.norm(dim=-1,keepdim=True)
20 | x = x+noise*std
21 | if modality_offset is not None:
22 | x = x + modality_offset
23 | return torch.nn.functional.normalize(x, dim=-1)
24 |
25 | class TextVid(BaseDataset):
26 | def __init__(self, args=None, tokenizer=None, split='train'):
27 | super().__init__(args, tokenizer, split)
28 |
29 | with open('./data/textvid/textvid.json','r') as f:
30 | self.data = json.load(f)
31 | self.feature_path = './data/textvid/features'
32 | # with open('./data/textvid/feature_small.pkl','rb') as f:
33 | # self.feature = pickle.load(f)
34 | self.letter2number = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
35 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'}
36 | self.num_options = 5
37 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
38 | print(f"Num {split} data: {len(self.data)}")
39 |
40 | def _get_text(self, question,options):
41 | question = question.capitalize().strip()
42 | if question[-1] != "?":
43 | question = str(question) + "?"
44 |
45 | # options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)]
46 |
47 | q_text = f"Question: {question}\n"
48 | o_text = "Choices: \n"
49 | for i in range(len(options)):
50 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
51 |
52 | a_text = "Answer: The correct choice is "
53 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options}
54 | return text
55 |
56 | def _get_video(self, video):
57 | video=video.float()
58 | if len(video) > self.max_feats:
59 | sampled = []
60 | for j in range(self.max_feats):
61 | sampled.append(video[(j * len(video)) // self.max_feats])
62 | video = torch.stack(sampled)
63 | video_len = self.max_feats
64 | elif len(video) < self.max_feats:
65 | video_len = len(video)
66 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0)
67 |
68 | else:
69 | video_len = self.max_feats
70 |
71 | return video, video_len
72 |
73 |     def _get_frame_feature(self, features, frames):  # mean-pool per-frame CLIP text features, re-normalize, then inject noise
74 | start = 0
75 | video = []
76 | features = features/features.norm(dim=-1,keepdim=True)
77 |
78 | for frame in frames:
79 | feature = features[start:start+len(list(frame.keys()))]
80 | feature = feature.mean(dim=0,keepdim=True)
81 | feature = feature/feature.norm(dim=-1,keepdim=True)
82 | video.append(feature)
83 | start+=len(list(frame.keys()))
84 | video = torch.cat(video,dim=0)
85 | video = noise_injection(video, variance=self.args.variance)
86 | return video
87 |
88 |
89 | def __getitem__(self, idx):
90 | while True:
91 | try:
92 |
93 | video_meta = self.data[idx]
94 |
95 | vid = video_meta['idx']
96 | with open(f'{self.feature_path}/{vid}.pkl','rb') as f:
97 | features = pickle.load(f)
98 | # features = self.feature[vid]
99 | video = self._get_frame_feature(features,video_meta['frames'])
100 |
101 | qtype = 1
102 | qa = random.choice(video_meta['QAs'])
103 | question = qa['question']
104 |
105 | # if question.strip().lower().split(' ')[0] =='what':
106 | # if random.random()>0.5:
107 | # qa = random.choice(video_meta['QAs'])
108 | # question = qa['question']
109 | # if question.strip().lower().split(' ')[0] =='what':
110 | # raise
111 |
112 | answer = qa['answer']
113 | answer = self.letter2number[answer]
114 | options = qa['options']
115 | answer_text = options[answer]
116 | if self.args.answer_balance:
117 | if not (("both" in answer_text.lower()) or ("all" in answer_text.lower()) or ('none' in answer_text.lower()) or ('and' in answer_text.lower())):
118 | random.shuffle(options)
119 | answer = options.index(answer_text)
120 |
121 | text = self._get_text(question,options)
122 |
123 |
124 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
125 | video, video_len = self._get_video(video)
126 |
127 | if self.args.video_caption:
128 | caption = video_meta['global_video_caption']
129 | text['c_text'] = f"Description: {caption}\n"
130 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_caption_token(text, answer)
131 | text_id.update(text_id_c)
132 | label.update(label_c)
133 | video_start.update(video_start_c)
134 | video_index.update(video_index_c)
135 | label_mask.update(label_mask_c)
136 |
137 | if self.args.openvqa and (random.random()>0.5):
138 | if 'answer_open_ended' in qa.keys():
139 | if ('both' not in answer_text.lower()) and ('all' not in answer_text.lower()):
140 | answer_open_ended = qa['answer_open_ended'].strip()
141 | text['oa_text'] = f"Answer: {answer_open_ended}\n"
142 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_openvqa_token(text, answer)
143 | text_id.update(text_id_c)
144 | label.update(label_c)
145 | video_start.update(video_start_c)
146 | video_index.update(video_index_c)
147 | label_mask.update(label_mask_c)
148 |
149 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
150 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype}
151 |
152 | except:
153 | idx = np.random.randint(0, len(self)-1)
154 | # print(f'Error reading {idx}')
155 |
156 | def __len__(self):
157 | if self.args.debug:
158 | return len(self.data[:10000])
159 | else:
160 | return len(self.data)
--------------------------------------------------------------------------------
/dataloader/tvqa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import json
4 | import copy
5 | import pysrt
6 | import numpy as np
7 | class TVQA(BaseDataset):
8 | def __init__(self, args=None, tokenizer=None, split='train'):
9 | super().__init__(args, tokenizer, split)
10 | self.split=split
11 | json_path = f'./data/tvqa/tvqa_{split}.jsonl'
12 | feature_path = f'./data/tvqa/clipvitl14.pth'
13 |
14 | with open(json_path, "r") as f:
15 | data_list = list(f)
16 | self.data = [json.loads(x) for x in data_list]
17 | self.features = torch.load(feature_path)
18 | self.subtitle_path = f'./data/tvqa/tvqa_subtitles/' # provided as castle_s01e01_seg02_clip_00.srt
19 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3 : '(D)', 4: '(E)'}
20 | self.num_options = 5
21 | self.sub = args.sub
22 | print(f"Num {split} data: {len(self.data)}")
23 | if split=='train':
24 | self.train_ratio = int(len(self.data)*self.args.data_ratio)
25 | else:
26 | self.train_ratio = int(len(self.data))
27 |
28 | def _get_text(self, idx, choices, vid, start, end):
29 | question = self.data[idx]["q"].capitalize().strip()
30 | if question[-1] != "?":
31 | question = str(question) + "?"
32 |
33 | if self.sub:
34 | dialogue = ''
35 |
36 | for t in pysrt.open(self.subtitle_path+f'{vid}'+'.srt'):
37 | txt = t.text.replace('\n', ' ')
38 | st = t.start.minutes * 60 + t.start.seconds
39 | et = t.end.minutes * 60 + t.end.seconds
40 | if (st >= start and et <= end) or (st <= start and et <= end and start <= et):
41 | dialogue += ' ' + txt
42 |
43 | if dialogue != '': d_text = f"Dialogue: {dialogue}\n"
44 | else: d_text = ''
45 |
46 | else:
47 | d_text = ""
48 |
49 | q_text = f"Question: {question}\n"
50 | o_text = f"Choices: \n"
51 |
52 | assert len(choices) == self.num_options, "Double check number of choices"
53 | for i, option in enumerate(choices):
54 | o_text += f"{self.answer_mapping[i]} {option}\n"
55 |
56 | a_text = f"Answer: The correct choice is "
57 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'd_text': d_text}
58 | return text
59 |
60 | def _get_video(self, video_id, start, end):
61 | if video_id not in self.features:
62 | print(video_id)
63 | video = torch.zeros(1, self.features_dim)
64 | else:
65 | video = self.features[video_id][start * 3: (end + 1) * 3, :].float() # 3fps
66 | video = video/video.norm(dim=-1,keepdim=True)
67 | if len(video) > self.max_feats:
68 | sampled = []
69 | for j in range(self.max_feats):
70 | sampled.append(video[(j * len(video)) // self.max_feats])
71 | video = torch.stack(sampled)
72 | video_len = self.max_feats
73 | elif len(video) < self.max_feats:
74 | video_len = len(video)
75 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0)
76 | else:
77 | video_len = self.max_feats
78 |
79 | return video, video_len
80 |
81 | def _get_padding_id(self, text_id, prefix_index, prefix_i, prefix_main, type):
82 | padding_text_id = torch.zeros((len(text_id), self.max_seq_len), dtype=torch.int64) - 1
83 |
84 | prefix = prefix_index
85 | for i, tid in enumerate(text_id):
86 | padding = self.max_seq_len - len(tid)
87 | # print(padding)
88 | if padding >= 0:
89 | padding_text_id[i, :len(tid)] = tid
90 | prefix = prefix_index
91 | else:
92 | if self.sub and prefix_i != prefix_main:
93 | pad = self.max_seq_len - ((prefix_i) + (len(tid) - prefix_main))
94 | padding_text_id[i, :prefix_i] = tid[:prefix_i]
95 | padding_text_id[i, prefix_i: prefix_i + pad] = tid[prefix_i: prefix_i + pad]
96 | padding_text_id[i, prefix_i + pad :] = tid[prefix_main:]
97 |
98 | if type == "vqa":
99 | prefix = len(padding_text_id[i]) - 4
100 | elif type == "vaq":
101 | if self.split == "train":
102 | try:
103 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1].item() + 2
104 | except:
105 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1][0].item() + 2
106 | else:
107 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1][0].item() + 2
108 | else:
109 | prefix = len(padding_text_id[i]) - self.max_feats - 1
110 | else:
111 | padding_text_id[i] = tid[:self.max_seq_len]
112 | prefix = prefix_index
113 | # print('max sequence length overflow')
114 |
115 | return padding_text_id, prefix
116 |
117 | def _get_text_token(self, text, answer):
118 | vqa_id, vqa_prefix_index, vqa_video_start, vqa_prefix_i, vqa_prefix_q = self.tokenizer.encode_dvqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
119 | vaq_id, vaq_prefix_index, vaq_video_start, vaq_prefix_i, vaq_prefix_q = self.tokenizer.encode_dvaq(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
120 | qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q = self.tokenizer.encode_dqav(text=text, max_feats=self.max_feats, max_seq_len=self.max_seq_len, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
121 |
122 | vqa_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vqa_id]
123 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id]
124 | qav_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in qav_id]
125 |
126 | vqa_padding_text_id, vqa_prefix_index = self._get_padding_id(vqa_id, vqa_prefix_index, vqa_prefix_i, vqa_prefix_q, "vqa")
127 | vaq_padding_text_id, vaq_prefix_index = self._get_padding_id(vaq_id, vaq_prefix_index, vaq_prefix_i, vaq_prefix_q, "vaq")
128 | qav_padding_text_id, qav_prefix_index = self._get_padding_id(qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q, "qav")
129 |
130 | # label
131 | vqa_label = copy.deepcopy(vqa_padding_text_id)
132 | vqa_label[:, :vqa_prefix_index] = -1
133 | vqa_label_mask = vqa_label.ge(0)
134 | vqa_label[~vqa_label_mask] = 0
135 | vqa_label_mask = vqa_label_mask.float()
136 |
137 | vaq_label = copy.deepcopy(vaq_padding_text_id)
138 | vaq_label[:, :vaq_prefix_index] = -1
139 | vaq_label_mask = vaq_label.ge(0)
140 | vaq_label[~vaq_label_mask] = 0
141 | vaq_label_mask = vaq_label_mask.float()
142 |
143 | qav_label = torch.ones_like(qav_padding_text_id) * -1
144 | qav_label[:, qav_prefix_index:qav_prefix_index+self.max_feats] = torch.arange(self.max_feats)
145 | qav_label_mask = torch.zeros_like(qav_padding_text_id)
146 | qav_label_mask[:, qav_prefix_index] = 1
147 | qav_label_mask = qav_label_mask.float()
148 |
149 | # text mask
150 | vqa_text_mask = vqa_padding_text_id.ge(0)
151 | vqa_padding_text_id[~vqa_text_mask] = 0
152 | vaq_text_mask = vaq_padding_text_id.ge(0)
153 | vaq_padding_text_id[~vaq_text_mask] = 0
154 | qav_text_mask = qav_padding_text_id.ge(0)
155 | qav_padding_text_id[~qav_text_mask] = 0
156 |
157 | # video index
158 | vqa_video_index = torch.arange(vqa_prefix_index, vqa_prefix_index + self.max_feats)
159 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats)
160 | qav_video_index = torch.arange(qav_prefix_index, qav_prefix_index + self.max_feats)
161 |
162 | text_id = {'vqa': vqa_padding_text_id, 'vaq': vaq_padding_text_id, 'qav': qav_padding_text_id}
163 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label}
164 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_prefix_index}
165 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index}
166 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask}
167 | return text_id, label, video_start, video_index, label_mask
168 |
169 | def __getitem__(self, idx):
170 | while True:
171 | try:
172 | idx = idx%self.train_ratio
173 | vid = self.data[idx]['vid_name']
174 | qtype = -1
175 | choices = [ self.data[idx][f'a{i}'] for i in range(self.num_options)]
176 | answer = self.data[idx]['answer_idx']
177 |
178 | start, end = map(float, self.data[idx]['ts'].split('-'))
179 | try:
180 | start, end = round(start), round(end)
181 | except:
182 | start, end = -1000, 1000
183 |
184 | video, video_len = self._get_video(f'{vid}', start, end)
185 | text = self._get_text(idx, choices, f'{vid}', start, end)
186 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
187 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
188 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype}
189 | except:
190 | idx = np.random.randint(0, len(self)-1)
191 |
192 |
193 | def __len__(self):
194 |
195 | return len(self.data[:])
196 |
--------------------------------------------------------------------------------
/dataloader/vlep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import json
4 | import copy
5 | import numpy as np
6 | class VLEP(BaseDataset):
7 | def __init__(self, args=None, tokenizer=None, split='train'):
8 | super().__init__(args, tokenizer, split)
9 | if split == 'val':
10 | json_path = f'./data/vlep/vlep_dev_release.jsonl'
11 | else:
12 | json_path = f'./data/vlep/vlep_{split}_release.jsonl'
13 | feature_path = f'./data/vlep/clipvitl14.pth'
14 | sub_path = f'./data/vlep/vlep_subtitles.jsonl'
15 |
16 | with open(json_path, "r") as f:
17 | data_list = list(f)
18 | with open(sub_path, "r") as s:
19 | sub_list = list(s)
20 | self.data = [json.loads(x) for x in data_list]
21 | self.subtitle = [json.loads(x) for x in sub_list]
22 | self.features = torch.load(feature_path)
23 | self.answer_mapping = {0: '(A)', 1: '(B)'}
24 | self.num_options = 2
25 | self.sub = args.sub
26 | print(f"Num {split} data: {len(self.data)}")
27 |
28 | def _get_text(self, choices, vid, start, end):
29 | question = "Which event is more likely to happen right after?".capitalize().strip()
30 |
31 | if self.sub:
32 | text = [x['sub'] for x in self.subtitle if x['vid_name'] == vid][0]
33 | dialogue = ''
34 | for txt in text:
35 | s, e, t = round(int(txt['start'])), int(txt['end']), txt['text'].replace('-', '')
36 | if (s >= start and e <= end) or (s <= start and e <= end and start <= e):
37 | dialogue+= t
38 | d_text = f"Dialogue: {dialogue}\n" # subtitles
39 | else:
40 | d_text = ""
41 |
42 | q_text = f"Question: {question}\n"
43 | o_text = f"Choices: \n"
44 |
45 | assert len(choices) == self.num_options, "Double check number of choices"
46 | for i, option in enumerate(choices):
47 | o_text += f"{self.answer_mapping[i]} {option}\n"
48 |
49 | a_text = f"Answer: The answer is "
50 |
51 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'd_text': d_text}
52 | return text
53 |
54 | def _get_video(self, video_id, start, end):
55 | if video_id not in self.features:
56 | print(video_id)
57 | video = torch.zeros(1, self.features_dim)
58 | else:
59 | video = self.features[video_id][start: end +1, :].float()
60 | video = video/video.norm(dim=-1,keepdim=True)
61 |
62 | if len(video) > self.max_feats:
63 | sampled = []
64 | for j in range(self.max_feats):
65 | sampled.append(video[(j * len(video)) // self.max_feats])
66 | video = torch.stack(sampled)
67 | video_len = self.max_feats
68 | elif len(video) < self.max_feats:
69 | video_len = len(video)
70 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0)
71 | else:
72 | video_len = self.max_feats
73 | return video, video_len
74 |
75 | def _get_padding_id(self, text_id, prefix_index, prefix_i, prefix_main, type):
76 | padding_text_id = torch.zeros((len(text_id), self.max_seq_len), dtype=torch.int64) - 1
77 |
78 | prefix = prefix_index
79 | for i, tid in enumerate(text_id):
80 | padding = self.max_seq_len - len(tid)
81 | if padding >= 0:
82 | padding_text_id[i, :len(tid)] = tid
83 | prefix = prefix_index
84 | else:
85 | if self.sub and prefix_i != prefix_main:
86 | pad = self.max_seq_len - ((prefix_i) + (len(tid) - prefix_main))
87 | padding_text_id[i, :prefix_i] = tid[:prefix_i]
88 | padding_text_id[i, prefix_i: prefix_i + pad] = tid[prefix_i: prefix_i + pad]
89 | padding_text_id[i, prefix_i + pad :] = tid[prefix_main:]
90 |
91 | if type == "vqa":
92 | prefix = len(padding_text_id[i]) - 4
93 | elif type == "vaq":
94 | if self.split == "train":
95 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1].item() + 2
96 | else:
97 | prefix = (padding_text_id == self.tokenizer.q_token_id).nonzero(as_tuple=True)[1][0].item() + 2
98 | else:
99 | prefix = len(padding_text_id[i]) - self.max_feats - 1
100 | else:
101 | padding_text_id[i] = tid[:self.max_seq_len]
102 | prefix = prefix_index
103 | return padding_text_id, prefix
104 |
105 |
106 | def _get_text_token(self, text, answer):
107 | vqa_id, vqa_prefix_index, vqa_video_start, vqa_prefix_i, vqa_prefix_q = self.tokenizer.encode_dvqa(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
108 | vaq_id, vaq_prefix_index, vaq_video_start, vaq_prefix_i, vaq_prefix_q = self.tokenizer.encode_dvaq(text=text, max_feats=self.max_feats, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
109 | qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q = self.tokenizer.encode_dqav(text=text, max_feats=self.max_feats, max_seq_len=self.max_seq_len, split=self.split, answer_mapping=self.answer_mapping, answer=answer)
110 |
111 | vqa_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vqa_id]
112 | vaq_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in vaq_id]
113 | qav_id = [torch.tensor(v_id, dtype=torch.int64) for v_id in qav_id]
114 |
115 | vqa_padding_text_id, vqa_prefix_index = self._get_padding_id(vqa_id, vqa_prefix_index, vqa_prefix_i, vqa_prefix_q, "vqa")
116 | vaq_padding_text_id, vaq_prefix_index = self._get_padding_id(vaq_id, vaq_prefix_index, vaq_prefix_i, vaq_prefix_q, "vaq")
117 | qav_padding_text_id, qav_prefix_index = self._get_padding_id(qav_id, qav_prefix_index, qav_prefix_i, qav_prefix_q, "qav")
118 |
119 | # label
120 | vqa_label = copy.deepcopy(vqa_padding_text_id)
121 | vqa_label[:, :vqa_prefix_index] = -1
122 | vqa_label_mask = vqa_label.ge(0)
123 | vqa_label[~vqa_label_mask] = 0
124 | vqa_label_mask = vqa_label_mask.float()
125 |
126 | vaq_label = copy.deepcopy(vaq_padding_text_id)
127 | vaq_label[:, :vaq_prefix_index] = -1
128 | vaq_label_mask = vaq_label.ge(0)
129 | vaq_label[~vaq_label_mask] = 0
130 | vaq_label_mask = vaq_label_mask.float()
131 |
132 | qav_label = torch.ones_like(qav_padding_text_id) * -1
133 | qav_label[:, qav_prefix_index:qav_prefix_index+self.max_feats] = torch.arange(self.max_feats)
134 | qav_label_mask = torch.zeros_like(qav_padding_text_id)
135 | qav_label_mask[:, qav_prefix_index] = 1
136 | qav_label_mask = qav_label_mask.float()
137 |
138 | # text mask
139 | vqa_text_mask = vqa_padding_text_id.ge(0)
140 | vqa_padding_text_id[~vqa_text_mask] = 0
141 | vaq_text_mask = vaq_padding_text_id.ge(0)
142 | vaq_padding_text_id[~vaq_text_mask] = 0
143 | qav_text_mask = qav_padding_text_id.ge(0)
144 | qav_padding_text_id[~qav_text_mask] = 0
145 |
146 | # video index
147 | vqa_video_index = torch.arange(vqa_prefix_index, vqa_prefix_index + self.max_feats)
148 | vaq_video_index = torch.arange(vaq_prefix_index, vaq_prefix_index + self.max_feats)
149 | qav_video_index = torch.arange(qav_prefix_index, qav_prefix_index + self.max_feats)
150 |
151 | text_id = {'vqa': vqa_padding_text_id, 'vaq': vaq_padding_text_id, 'qav': qav_padding_text_id}
152 | label = {'vqa': vqa_label, 'vaq': vaq_label, 'qav': qav_label}
153 | video_start = {'vqa': vqa_video_start, 'vaq': vaq_video_start, 'qav': qav_prefix_index}
154 | video_index = {'vqa': vqa_video_index, 'vaq': vaq_video_index, 'qav': qav_video_index}
155 | label_mask = {'vqa': vqa_label_mask, 'vaq': vaq_label_mask, 'qav': qav_label_mask}
156 | return text_id, label, video_start, video_index, label_mask
157 |
158 |
159 | def __getitem__(self, idx):
160 | while True:
161 | try:
162 | vid = self.data[idx]['vid_name']
163 | qtype = -1
164 | choices = self.data[idx]['events']
165 | answer = self.data[idx]['answer']
166 | ts = self.data[idx]['ts']
167 | start, end = round(ts[0]), round(ts[1])
168 | video, video_len = self._get_video(f'{vid}', start, end)
169 | text = self._get_text(choices, f'{vid}', start, end)
170 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
171 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
172 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype}
173 | except:
174 | print(idx)
175 | idx = np.random.randint(0, len(self)-1)
176 |
177 | def __len__(self):
178 | return len(self.data)
--------------------------------------------------------------------------------
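The `_get_video` method above (shared, with minor variations, by the other dataloaders) reduces a variable-length sequence of CLIP frame features to a fixed `max_feats` length: uniform-stride subsampling when the clip is too long, zero-padding when it is too short. Below is a minimal standalone sketch of that behavior, using illustrative values for `max_feats` and `features_dim` rather than the repo's actual configuration.

import torch

def pad_or_sample(video: torch.Tensor, max_feats: int, features_dim: int):
    """Uniformly subsample or zero-pad a (T, D) feature tensor to (max_feats, D)."""
    if len(video) > max_feats:
        # pick max_feats frames at a uniform stride over the clip
        idx = [(j * len(video)) // max_feats for j in range(max_feats)]
        return video[idx], max_feats
    elif len(video) < max_feats:
        video_len = len(video)
        pad = torch.zeros(max_feats - video_len, features_dim)
        return torch.cat([video, pad], dim=0), video_len
    return video, max_feats

# toy usage: a 23-frame clip squeezed into 10 slots, a 4-frame clip padded up
long_clip, n1 = pad_or_sample(torch.randn(23, 768), max_feats=10, features_dim=768)
short_clip, n2 = pad_or_sample(torch.randn(4, 768), max_feats=10, features_dim=768)
print(long_clip.shape, n1)   # torch.Size([10, 768]) 10
print(short_clip.shape, n2)  # torch.Size([10, 768]) 4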
/dataloader/webvid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_dataset import BaseDataset
3 | import pandas as pd
4 | import json
5 | import random
6 | import pickle
7 | import numpy as np
8 | import math
9 | import pickle
10 | def noise_injection(x, variance=0.001, modality_offset=None, uniform_noise=False, dont_norm=False):
11 | if variance == 0.0:
12 | return x
13 | std = math.sqrt(variance)
14 | if not dont_norm:
15 | x = torch.nn.functional.normalize(x, dim=1)
16 | # if uniform_noise:
17 | # x = x + get_uniform_ball_noise(x.shape, radius=std)
18 | # else:
19 |     x = x + (torch.randn(x.shape, device=x.device) * std)  # todo: by some conventions multivariate noise should be divided by sqrt of dim
20 |     noise = (torch.randn(x.shape, device=x.device) * std)  # a second, direction-normalized noise term is added on top of the raw one above
21 |     noise /= noise.norm(dim=-1, keepdim=True)
22 |     x = x + noise * std
23 | if modality_offset is not None:
24 | x = x + modality_offset
25 | return torch.nn.functional.normalize(x, dim=-1)
26 |
27 | class Webvid(BaseDataset):
28 | def __init__(self, args=None, tokenizer=None, split='train'):
29 | super().__init__(args, tokenizer, split)
30 |
31 | with open('./data/webvid/train.json','r') as f:
32 | self.data = json.load(f)
33 | with open('./data/webvid/features.pkl','rb') as f:
34 | self.feature = pickle.load(f)
35 | self.letter2number = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
36 | self.answer_mapping = {0: '(A)', 1: '(B)', 2: '(C)', 3: '(D)', 4: '(E)'}
37 | self.num_options = 5
38 | self.qtype_mapping = {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
39 | print(f"Num {split} data: {len(self.data)}")
40 |
41 | def _get_text(self, question,options):
42 | question = question.capitalize().strip()
43 | if question[-1] != "?":
44 | question = str(question) + "?"
45 |
46 | # options = [self.data[f'a{i}'].values[idx] for i in range(self.num_options)]
47 |
48 | q_text = f"Question: {question}\n"
49 | o_text = "Choices: \n"
50 | for i in range(self.num_options):
51 | o_text += f"{self.answer_mapping[i]} {options[i]}\n"
52 |
53 | a_text = "Answer: The correct choice is "
54 | text = {'q_text': q_text, 'o_text': o_text, 'a_text': a_text, 'options': options}
55 | return text
56 |
57 | def _get_video(self, video):
58 | video=video.float()
59 | if len(video) > self.max_feats:
60 | sampled = []
61 | for j in range(self.max_feats):
62 | sampled.append(video[(j * len(video)) // self.max_feats])
63 | video = torch.stack(sampled)
64 | video_len = self.max_feats
65 | elif len(video) < self.max_feats:
66 | video_len = len(video)
67 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], dim=0)
68 |
69 | # elif len(video) < self.max_feats:
70 | # video_len = len(video)
71 | # filled_tensor = torch.zeros(self.max_feats, self.features_dim)
72 |
73 | # indices = random.sample(range(self.max_feats),video_len)
74 | # indices.sort()
75 | # for i,indice in enumerate(indices):
76 | # filled_tensor[indice] = video[i]
77 | # video_len = self.max_feats
78 | # video = filled_tensor
79 | else:
80 | video_len = self.max_feats
81 | # print(video_len)
82 | return video, video_len
83 |
84 | def _get_frame_feature(self, features,frames):
85 | start = 0
86 | video = []
87 | features = features/features.norm(dim=-1,keepdim=True)
88 |
89 | for frame in frames:
90 | feature = features[start:start+len(list(frame.keys()))]
91 | feature = feature.mean(dim=0,keepdim=True)
92 | features = features/features.norm(dim=-1,keepdim=True)
93 | video.append(feature)
94 | start+=len(list(frame.keys()))
95 | video = torch.cat(video,dim=0)
96 | video = noise_injection(video, variance=self.args.variance)
97 | return video
98 |
99 |
100 | def __getitem__(self, idx):
101 | # while True:
102 | # try:
103 |
104 | video_meta = self.data[idx]
105 |
106 | vid = video_meta['feature_idx']
107 | video = self.feature[vid]
108 | # video = self._get_frame_feature(features,video_meta['frames'])
109 | qtype = 1
110 |
111 | video, video_len = self._get_video(video)
112 |
113 |
114 | caption = video_meta['caption']
115 | question = "What does this video show?"
116 | options=['1','1','1','1','1']
117 | answer=0
118 | text = self._get_text(question,options)
119 | text_id, label, video_start, video_index, label_mask = self._get_text_token(text, answer)
120 |
121 | text['c_text'] = f"Description: {caption}."
122 | text_id_c, label_c, video_start_c, video_index_c, label_mask_c = self._get_caption_token(text, answer)
123 | text_id.update(text_id_c)
124 | label.update(label_c)
125 | video_start.update(video_start_c)
126 | video_index.update(video_index_c)
127 | label_mask.update(label_mask_c)
128 |
129 |
130 |
131 | return {"vid": vid, "video": video, "video_len": video_len, "text": text, "text_id": text_id, "label": label, "video_start": video_start,
132 | "video_index": video_index, "label_mask": label_mask, "qid": idx, "answer": answer, "qtype": qtype}
133 |
134 | # except:
135 | # idx = np.random.randint(0, len(self)-1)
136 | # print(f'Error reading {idx}')
137 |
138 | def __len__(self):
139 | if self.args.debug:
140 | return len(self.data[:10000])
141 | else:
142 | return len(self.data)
--------------------------------------------------------------------------------
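`noise_injection` perturbs the frame features before they are used for training: the input is L2-normalized, Gaussian noise with standard deviation sqrt(variance) is mixed in (the function above adds both a raw-scaled term and a direction-normalized term), and the result is projected back onto the unit sphere. The following is a minimal single-noise sketch of the idea with an illustrative `variance`; it is not the exact function used above.

import math
import torch
import torch.nn.functional as F

def simple_noise_injection(x: torch.Tensor, variance: float = 0.001) -> torch.Tensor:
    """Add isotropic Gaussian noise of the given variance to unit-normalized features."""
    if variance == 0.0:
        return x
    std = math.sqrt(variance)
    x = F.normalize(x, dim=-1)            # project onto the unit sphere
    x = x + torch.randn_like(x) * std     # perturb with Gaussian noise
    return F.normalize(x, dim=-1)         # re-normalize the result

feats = torch.randn(10, 768)              # e.g. 10 frame-level CLIP features
noisy = simple_noise_injection(feats, variance=0.016)
print(noisy.norm(dim=-1))                 # all ones (up to float error)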
/demos/13B_msrvtt_results.json:
--------------------------------------------------------------------------------
1 | {"Bleu_1": 0.6634161036124164, "Bleu_2": 0.4809019909504734, "Bleu_3": 0.3141303218433441, "Bleu_4": 0.2110267573252259, "METEOR": 0.2179849241732958, "ROUGE_L": 0.4981662383625836, "CIDEr": 0.33443181238316033, "SPICE": 0.05282364436867422}
--------------------------------------------------------------------------------
/demos/7B_msrvtt_results.json:
--------------------------------------------------------------------------------
1 | {"Bleu_1": 0.654746090605047, "Bleu_2": 0.4820976690148554, "Bleu_3": 0.3151958750789693, "Bleu_4": 0.2078218649347474, "METEOR": 0.21875795472104748, "ROUGE_L": 0.5030358497611536, "CIDEr": 0.32867167801902625, "SPICE": 0.050667287735789934}
--------------------------------------------------------------------------------
/demos/Eval_Cap_MSRVTT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "ebdbbd25-038d-4987-ac84-d5adbdcd40ad",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import argparse\n",
11 | "import json\n",
12 | "import torch\n",
13 | "import json\n",
14 | "import os"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "id": "460fe558-fca6-4906-93f3-facd215d9ab4",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "def save_arguments(args, filepath):\n",
25 | " with open(filepath, 'w') as file:\n",
26 | " json.dump(vars(args), file)\n",
27 | "\n",
28 | "def load_arguments(filepath):\n",
29 | " with open(filepath, 'r') as file:\n",
30 | " args_dict = json.load(file)\n",
31 | " return args_dict\n",
32 | "\n",
33 | "# Optionally, repopulate argparse.ArgumentParser with these arguments\n",
34 | "def repopulate_arguments(args_dict):\n",
35 | " parser = argparse.ArgumentParser(description=\"Example script\")\n",
36 | " for key, value in args_dict.items():\n",
37 | " parser.add_argument(f'--{key}', type=type(value),default=value)\n",
38 | " return parser.parse_args([])"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 3,
44 | "id": "c9c50017-a3cf-4fa3-8a7f-2e24c501a1a0",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "path = '../vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips'\n",
49 | "\n",
50 | "loaded_args = load_arguments(path+'/args.json')\n",
51 | "\n",
52 | "args = repopulate_arguments(loaded_args)\n",
53 | "args.llama_model_path = '.' +args.llama_model_path\n",
54 | "args.resume='../vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth'"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "id": "074499b9-b268-4a4e-b05f-9e0f6d5a189f",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "# path = '../vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips'\n",
65 | "\n",
66 | "# loaded_args = load_arguments(path+'/argsa.json')\n",
67 | "\n",
68 | "# args = repopulate_arguments(loaded_args)\n",
69 | "# args.llama_model_path = '.' +args.llama_model_path\n",
70 | "# args.resume='../vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth'"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 4,
76 | "id": "98596d9e-0760-4eef-995f-7e2d0e47d33f",
77 | "metadata": {},
78 | "outputs": [
79 | {
80 | "name": "stderr",
81 | "output_type": "stream",
82 | "text": [
83 | "/home/users/nus/idmwyk/scratch/anaconda3/envs/llama/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
84 | " from .autonotebook import tqdm as notebook_tqdm\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "import sys\n",
90 | "sys.path.append('../')\n",
91 | "from llama import Tokenizer\n",
92 | "from llama_vqa import LLaMA_VQA\n",
93 | "from dataloader import load_data"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 6,
99 | "id": "9272202e-d5ed-4645-8411-98463087b6ca",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "checkpoint = torch.load(args.resume, map_location='cpu')\n",
104 | "model.load_state_dict(checkpoint['model'], strict=False)\n",
105 | "tokenizer = Tokenizer(model_path=f'{args.llama_model_path}./tokenizer.model')"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "c2b12971-d47f-4aef-ba77-4d1d163160d8",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": []
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "id": "1e125471-d745-4f57-b51a-b311df9de478",
120 | "metadata": {},
121 | "outputs": [],
122 | "source": []
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 7,
127 | "id": "80ba3180-0ce1-4f7c-817f-79aa02c0ccc5",
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "def decoding(model, tokenizer, prompt1,prompt2,video=None):\n",
132 | " adapter = model.adapter_query.weight.reshape(-1, model.adapter_len, model.params.dim).unsqueeze(1)\n",
133 | " freqs= model.freqs_cis.cuda()\n",
134 | " \n",
135 | " tokens = [tokenizer.bos_id] + tokenizer.sp_model.encode(prompt1)\n",
136 | " query = torch.tensor(tokens, dtype=torch.int64).cuda()\n",
137 | " input_embedding = model.tok_embeddings(query)\n",
138 | "\n",
139 | " tokens_2 = tokenizer.sp_model.encode(prompt2)\n",
140 | " query_2 = torch.tensor(tokens_2, dtype=torch.int64).cuda()\n",
141 | " input_embedding_2 = model.tok_embeddings(query_2)\n",
142 | " tokens.extend(tokens_2)\n",
143 | " video = video.cuda().float()\n",
144 | " video/=video.norm(dim=-1,keepdim=True)\n",
145 | " if True:\n",
146 | " sim = video@model.memory.T\n",
147 | "\n",
148 | " sim = (sim*100).softmax(dim=-1)\n",
149 | "\n",
150 | " video = sim@model.memory\n",
151 | " video = video/video.norm(dim=-1,keepdim=True)\n",
152 | " \n",
153 | " video_feature = model.visual_proj(video)\n",
154 | " video_feature = (video_feature + model.temporal_emb.weight[:, :]).type(model.llamatype)\n",
155 | " vqa_video_start=input_embedding.shape[0]\n",
156 | " # print(video_feature.shape)\n",
157 | " input_embedding = torch.cat([input_embedding,video_feature,input_embedding_2])\n",
158 | " start_pos=0\n",
159 | " for j in range(10):\n",
160 | " vqa_h = input_embedding.unsqueeze(0)\n",
161 | " seqlen = vqa_h.shape[-2]\n",
162 | " freqs_cis = freqs[:seqlen]\n",
163 | " mask = None\n",
164 | " mask = torch.full((1, 1, seqlen, seqlen), float(\"-inf\"), device=vqa_h.device)\n",
165 | " mask = torch.triu(mask, diagonal=0 + 1).type_as(vqa_h)\n",
166 | "\n",
167 | " for i, layer in enumerate(model.layers[-1 * model.adapter_layer:]):\n",
168 | " vqa_h = layer(vqa_h, start_pos, freqs_cis, mask, adapter[i].type(model.llamatype), vqa_video_start)\n",
169 | " vqa_h = model.norm(vqa_h)\n",
170 | " vqa_output = model.output(vqa_h)\n",
171 | " vqa_output = vqa_output.reshape(-1, model.vocab_size)\n",
172 | " vqa_output[-1,920]=-100\n",
173 | " vqa_output[-1,1128]=-100\n",
174 | " next_token = vqa_output[-1,:].argmax()\n",
175 | " tokens.append(next_token.item())\n",
176 | " token_emb = model.tok_embeddings(next_token.unsqueeze(0))\n",
177 | " input_embedding = torch.cat([input_embedding,token_emb],dim=0)\n",
178 | " return tokens"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 8,
184 | "id": "7d043520-913a-40f1-a392-1b432196b78e",
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "import os\n",
189 | "dataset_path = '../data/videos/msrvtt/'\n",
190 | "test_file = os.path.join(dataset_path,'test_videodatainfo.json')\n",
191 | "\n",
192 | "with open(test_file,'r') as f:\n",
193 | " test_info = json.load(f)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 9,
199 | "id": "a6a2f83d-843d-4983-9a4e-cb1eeb6e9227",
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "annotations_gt = {}\n",
204 | "annotations_gt['images'] = test_info['videos']\n",
205 | "\n",
206 | "for sentence in test_info['sentences']:\n",
207 | " sentence['image_id'] = int(sentence['video_id'].replace('video',''))\n",
208 | " sentence['id'] = sentence['video_id']\n",
209 | "annotations_gt['annotations'] =test_info['sentences']"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 10,
215 | "id": "8b766fa8-7843-4627-b8fd-5ebaebef8f03",
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "import cv2\n",
220 | "import numpy as np\n",
221 | "from PIL import Image\n",
222 | "def sample_images_from_video(video_path, num_samples=10):\n",
223 | " # Open the video file\n",
224 | " cap = cv2.VideoCapture(video_path)\n",
225 | "\n",
226 | " # Get the total number of frames in the video\n",
227 | " total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
228 | " frame_rate = cap.get(cv2.CAP_PROP_FPS)\n",
229 | "\n",
230 | " # Calculate total duration in seconds\n",
231 | " total_duration = total_frames / frame_rate\n",
232 | " # print(total_duration)\n",
233 | " # Check if the video opened successfully\n",
234 | " if not cap.isOpened():\n",
235 | " print(\"Error opening video file.\")\n",
236 | " return []\n",
237 | "\n",
238 | " # Calculate the interval for sampling\n",
239 | " interval = total_frames // num_samples\n",
240 | "\n",
241 | " # Initialize a list to store the sampled images\n",
242 | " sampled_images = []\n",
243 | "\n",
244 | " for i in range(num_samples):\n",
245 | " # Set the frame position\n",
246 | " frame_id = i * interval\n",
247 | " cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)\n",
248 | "\n",
249 | " # Read the frame\n",
250 | " ret, frame = cap.read()\n",
251 | "\n",
252 | " # If frame reading was successful, save the frame\n",
253 | " if ret:\n",
254 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
255 | " pil_image = Image.fromarray(frame)\n",
256 | " sampled_images.append(pil_image)\n",
257 | " \n",
258 | " else:\n",
259 | " print(f\"Error reading frame at position {frame_id}\")\n",
260 | "\n",
261 | " # Release the video capture object\n",
262 | " cap.release()\n",
263 | "\n",
264 | " return sampled_images, total_frames\n"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 11,
270 | "id": "d8f3cd5d-6b6a-4b75-8e01-3502967c71a5",
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "import clip"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 12,
280 | "id": "7d2bf58b-c2f0-4bda-9c55-d25b53d1863d",
281 | "metadata": {},
282 | "outputs": [],
283 | "source": [
284 | "clip_model, preprocess = clip.load(\"ViT-L/14\")\n",
285 | "clip_model.eval()\n",
286 | "clip_model = clip_model.cuda()"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "id": "c8012c00-b980-4cea-8eb2-5d5f4624a14b",
293 | "metadata": {},
294 | "outputs": [],
295 | "source": []
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 13,
300 | "id": "05527846-984b-43a5-8136-15f5267ce783",
301 | "metadata": {},
302 | "outputs": [],
303 | "source": [
304 | "from pycocotools.coco import COCO\n",
305 | "from pycocoevalcap.eval import COCOEvalCap\n",
306 | "\n",
307 | "from json import encoder\n",
308 | "encoder.FLOAT_REPR = lambda o: format(o, '.3f')\n",
309 | "import sys"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 14,
315 | "id": "c299f86f-e4ee-404d-8ad4-32b625e33c9e",
316 | "metadata": {},
317 | "outputs": [
318 | {
319 | "name": "stdout",
320 | "output_type": "stream",
321 | "text": [
322 | "loading annotations into memory...\n",
323 | "Done (t=0.04s)\n",
324 | "creating index...\n",
325 | "index created!\n"
326 | ]
327 | }
328 | ],
329 | "source": [
330 | "annotation_file = os.path.join(dataset_path,'annotation_file')\n",
331 | "with open(annotation_file,'w') as f:\n",
332 | " json.dump(annotations_gt,f)\n",
333 | "coco = COCO(annotation_file)\n",
334 | "video_ids = coco.getImgIds()"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 15,
340 | "id": "6898eb1a-7233-46f3-8060-fea7fd076a67",
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "# prompt = \"Instruction: Predict the answer based on the video and question.\\nVideo:\"\n",
345 | "# prompt2 = \"\\nQuestion: Summarize the video.\\nAnswer: It is a video showing\" #26.7\n",
346 | "\n",
347 | "prompt = \"Instruction: Predict the answer based on the video and question.\\nVideo:\"\n",
348 | "prompt2 = \"\\nQuestion: Can you describe this video?\\nAnswer: It is a video of\""
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": null,
354 | "id": "af77f161-4b9b-4079-b585-595b27900be3",
355 | "metadata": {},
356 | "outputs": [
357 | {
358 | "name": "stderr",
359 | "output_type": "stream",
360 | "text": [
361 | " 5%|▌ | 151/2990 [02:28<36:09, 1.31it/s] "
362 | ]
363 | }
364 | ],
365 | "source": [
366 | "from tqdm import tqdm\n",
367 | "results = []\n",
368 | " \n",
369 | "for video_id in tqdm(video_ids[:]):\n",
370 | " with torch.no_grad():\n",
371 | " try:\n",
372 | " video_path = os.path.join(dataset_path,'TestVideo','video'+str(video_id)+'.mp4')\n",
373 | " sampled_images,_ = sample_images_from_video(video_path)\n",
374 | "\n",
375 | " image_features = [preprocess(image) for image in sampled_images]\n",
376 | " image_features = torch.stack(image_features,dim=0).cuda()\n",
377 | " image_features = clip_model.encode_image(image_features)\n",
378 | " image_features/=image_features.norm(dim=-1,keepdim=True)\n",
379 | " tokens = decoding(model,tokenizer,prompt,prompt2,image_features)\n",
380 | " generate_text = tokenizer.decode(tokens[:])\n",
381 | " generate_text = generate_text.split('It is a video of')[1].strip().split('.')[0]\n",
382 | " results.append({'image_id':video_id,'caption': generate_text})\n",
383 | " except:\n",
384 | " results.append({'image_id':video_id,'caption': 'A video'})"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": null,
390 | "id": "602b0e4f-462e-41c1-9fcc-4a59630f0c4f",
391 | "metadata": {},
392 | "outputs": [],
393 | "source": [
394 | "coco_result = coco.loadRes(results) \n",
395 | "\n",
396 | "coco_eval = COCOEvalCap(coco, coco_result)\n",
397 | "coco_eval.evaluate()\n",
398 |     "# print output evaluation scores\n",
399 | "scores = {}\n",
400 | "for metric, score in coco_eval.eval.items():\n",
401 | " print(f\"{metric}: {score:.3f}\")\n",
402 | " scores[metric] = score\n",
403 | "with open('13B_msrvtt_results.json','w') as f:\n",
404 | " json.dump(scores,f)"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": null,
410 | "id": "10d8bf4e-8ba2-4565-b080-feec9bf406c4",
411 | "metadata": {},
412 | "outputs": [],
413 | "source": []
414 | }
415 | ],
416 | "metadata": {
417 | "kernelspec": {
418 | "display_name": "llama",
419 | "language": "python",
420 | "name": "llama"
421 | },
422 | "language_info": {
423 | "codemirror_mode": {
424 | "name": "ipython",
425 | "version": 3
426 | },
427 | "file_extension": ".py",
428 | "mimetype": "text/x-python",
429 | "name": "python",
430 | "nbconvert_exporter": "python",
431 | "pygments_lexer": "ipython3",
432 | "version": "3.11.5"
433 | }
434 | },
435 | "nbformat": 4,
436 | "nbformat_minor": 5
437 | }
438 |
--------------------------------------------------------------------------------
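The checkpoint-loading cell in this notebook assumes a `model` object has already been constructed in an earlier, not-shown cell (presumably a `LLaMA_VQA` instance from `llama_vqa.py`, moved to GPU and set to eval mode). The inference-time step worth isolating is the memory re-projection inside `decoding`: normalized CLIP frame features are replaced by a temperature-sharpened, softmax-weighted mixture of entries from a stored text-feature memory bank before being projected into the LLM. A minimal sketch with hypothetical shapes, `memory` standing in for `model.memory`:

import torch

def memory_reproject(video: torch.Tensor, memory: torch.Tensor, temperature: float = 100.0):
    """Replace frame features by a softmax-weighted mixture of memory entries, as in `decoding`."""
    video = video / video.norm(dim=-1, keepdim=True)   # unit-normalize frame features
    sim = video @ memory.T                             # similarity to every memory entry
    weights = (sim * temperature).softmax(dim=-1)      # temperature-sharpened attention over the memory
    mixed = weights @ memory                           # convex combination of memory entries
    return mixed / mixed.norm(dim=-1, keepdim=True)    # re-normalize before the visual projection

memory = torch.randn(5000, 768)   # hypothetical stand-in for model.memory (CLIP text features)
frames = torch.randn(10, 768)     # 10 sampled CLIP frame features
print(memory_reproject(frames, memory).shape)   # torch.Size([10, 768])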
/demos/upload_leaderboard_13B_zero_shot.json:
--------------------------------------------------------------------------------
1 | {"Action Sequence": 38.0, "Action Prediction": 40.0, "Action Antonym": 42.5, "Fine-grained Action": 35.0, "Unexpected Action": 69.0, "Object Existence": 52.5, "Object Interaction": 58.5, "Object Shuffle": 29.5, "Moving Direction": 22.5, "Action Localization": 43.5, "Scene Transition": 80.5, "Action Count": 38.0, "Moving Count": 25.5, "Moving Attribute": 43.0, "State Change": 43.0, "Fine-grained Pose": 29.5, "Character Order": 37.5, "Egocentric Navigation": 38.5, "Episodic Reasoning": 50.0, "Counterfactual Inference": 32.5, "Avg": 42.449999999999996}
--------------------------------------------------------------------------------
/demos/upload_leaderboard_7B_ZS.json:
--------------------------------------------------------------------------------
1 | {"Action Sequence": 42.0, "Action Prediction": 38.5, "Action Antonym": 35.0, "Fine-grained Action": 34.5, "Unexpected Action": 66.0, "Object Existence": 52.5, "Object Interaction": 47.5, "Object Shuffle": 28.000000000000004, "Moving Direction": 22.0, "Action Localization": 37.5, "Scene Transition": 81.0, "Action Count": 38.0, "Moving Count": 24.0, "Moving Attribute": 42.5, "State Change": 41.5, "Fine-grained Pose": 28.499999999999996, "Character Order": 34.0, "Egocentric Navigation": 23.5, "Episodic Reasoning": 49.0, "Counterfactual Inference": 30.5, "Avg": 39.800000000000004}
--------------------------------------------------------------------------------
/demos/video_transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 | import math
7 | import torch
8 |
9 |
10 | class GroupRandomCrop(object):
11 | def __init__(self, size):
12 | if isinstance(size, numbers.Number):
13 | self.size = (int(size), int(size))
14 | else:
15 | self.size = size
16 |
17 | def __call__(self, img_group):
18 |
19 | w, h = img_group[0].size
20 | th, tw = self.size
21 |
22 | out_images = list()
23 |
24 | x1 = random.randint(0, w - tw)
25 | y1 = random.randint(0, h - th)
26 |
27 | for img in img_group:
28 | assert(img.size[0] == w and img.size[1] == h)
29 | if w == tw and h == th:
30 | out_images.append(img)
31 | else:
32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33 |
34 | return out_images
35 |
36 |
37 | class MultiGroupRandomCrop(object):
38 | def __init__(self, size, groups=1):
39 | if isinstance(size, numbers.Number):
40 | self.size = (int(size), int(size))
41 | else:
42 | self.size = size
43 | self.groups = groups
44 |
45 | def __call__(self, img_group):
46 |
47 | w, h = img_group[0].size
48 | th, tw = self.size
49 |
50 | out_images = list()
51 |
52 | for i in range(self.groups):
53 | x1 = random.randint(0, w - tw)
54 | y1 = random.randint(0, h - th)
55 |
56 | for img in img_group:
57 | assert(img.size[0] == w and img.size[1] == h)
58 | if w == tw and h == th:
59 | out_images.append(img)
60 | else:
61 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
62 |
63 | return out_images
64 |
65 |
66 | class GroupCenterCrop(object):
67 | def __init__(self, size):
68 | self.worker = torchvision.transforms.CenterCrop(size)
69 |
70 | def __call__(self, img_group):
71 | return [self.worker(img) for img in img_group]
72 |
73 |
74 | class GroupRandomHorizontalFlip(object):
75 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
76 | """
77 |
78 | def __init__(self, is_flow=False):
79 | self.is_flow = is_flow
80 |
81 | def __call__(self, img_group, is_flow=False):
82 | v = random.random()
83 | if v < 0.5:
84 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
85 | if self.is_flow:
86 | for i in range(0, len(ret), 2):
87 | # invert flow pixel values when flipping
88 | ret[i] = ImageOps.invert(ret[i])
89 | return ret
90 | else:
91 | return img_group
92 |
93 |
94 | class GroupNormalize(object):
95 | def __init__(self, mean, std):
96 | self.mean = mean
97 | self.std = std
98 |
99 | def __call__(self, tensor):
100 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
101 | rep_std = self.std * (tensor.size()[0] // len(self.std))
102 |
103 | # TODO: make efficient
104 | for t, m, s in zip(tensor, rep_mean, rep_std):
105 | t.sub_(m).div_(s)
106 |
107 | return tensor
108 |
109 |
110 | class GroupScale(object):
111 | """ Rescales the input PIL.Image to the given 'size'.
112 | 'size' will be the size of the smaller edge.
113 | For example, if height > width, then image will be
114 | rescaled to (size * height / width, size)
115 | size: size of the smaller edge
116 | interpolation: Default: PIL.Image.BILINEAR
117 | """
118 |
119 | def __init__(self, size, interpolation=Image.BILINEAR):
120 | self.worker = torchvision.transforms.Resize(size, interpolation)
121 |
122 | def __call__(self, img_group):
123 | return [self.worker(img) for img in img_group]
124 |
125 |
126 | class GroupOverSample(object):
127 | def __init__(self, crop_size, scale_size=None, flip=True):
128 | self.crop_size = crop_size if not isinstance(
129 | crop_size, int) else (crop_size, crop_size)
130 |
131 | if scale_size is not None:
132 | self.scale_worker = GroupScale(scale_size)
133 | else:
134 | self.scale_worker = None
135 | self.flip = flip
136 |
137 | def __call__(self, img_group):
138 |
139 | if self.scale_worker is not None:
140 | img_group = self.scale_worker(img_group)
141 |
142 | image_w, image_h = img_group[0].size
143 | crop_w, crop_h = self.crop_size
144 |
145 | offsets = GroupMultiScaleCrop.fill_fix_offset(
146 | False, image_w, image_h, crop_w, crop_h)
147 | oversample_group = list()
148 | for o_w, o_h in offsets:
149 | normal_group = list()
150 | flip_group = list()
151 | for i, img in enumerate(img_group):
152 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
153 | normal_group.append(crop)
154 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
155 |
156 | if img.mode == 'L' and i % 2 == 0:
157 | flip_group.append(ImageOps.invert(flip_crop))
158 | else:
159 | flip_group.append(flip_crop)
160 |
161 | oversample_group.extend(normal_group)
162 | if self.flip:
163 | oversample_group.extend(flip_group)
164 | return oversample_group
165 |
166 |
167 | class GroupFullResSample(object):
168 | def __init__(self, crop_size, scale_size=None, flip=True):
169 | self.crop_size = crop_size if not isinstance(
170 | crop_size, int) else (crop_size, crop_size)
171 |
172 | if scale_size is not None:
173 | self.scale_worker = GroupScale(scale_size)
174 | else:
175 | self.scale_worker = None
176 | self.flip = flip
177 |
178 | def __call__(self, img_group):
179 |
180 | if self.scale_worker is not None:
181 | img_group = self.scale_worker(img_group)
182 |
183 | image_w, image_h = img_group[0].size
184 | crop_w, crop_h = self.crop_size
185 |
186 | w_step = (image_w - crop_w) // 4
187 | h_step = (image_h - crop_h) // 4
188 |
189 | offsets = list()
190 | offsets.append((0 * w_step, 2 * h_step)) # left
191 | offsets.append((4 * w_step, 2 * h_step)) # right
192 | offsets.append((2 * w_step, 2 * h_step)) # center
193 |
194 | oversample_group = list()
195 | for o_w, o_h in offsets:
196 | normal_group = list()
197 | flip_group = list()
198 | for i, img in enumerate(img_group):
199 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
200 | normal_group.append(crop)
201 | if self.flip:
202 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
203 |
204 | if img.mode == 'L' and i % 2 == 0:
205 | flip_group.append(ImageOps.invert(flip_crop))
206 | else:
207 | flip_group.append(flip_crop)
208 |
209 | oversample_group.extend(normal_group)
210 | oversample_group.extend(flip_group)
211 | return oversample_group
212 |
213 |
214 | class GroupMultiScaleCrop(object):
215 |
216 | def __init__(self, input_size, scales=None, max_distort=1,
217 | fix_crop=True, more_fix_crop=True):
218 | self.scales = scales if scales is not None else [1, .875, .75, .66]
219 | self.max_distort = max_distort
220 | self.fix_crop = fix_crop
221 | self.more_fix_crop = more_fix_crop
222 | self.input_size = input_size if not isinstance(input_size, int) else [
223 | input_size, input_size]
224 | self.interpolation = Image.BILINEAR
225 |
226 | def __call__(self, img_group):
227 |
228 | im_size = img_group[0].size
229 |
230 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
231 | crop_img_group = [
232 | img.crop(
233 | (offset_w,
234 | offset_h,
235 | offset_w +
236 | crop_w,
237 | offset_h +
238 | crop_h)) for img in img_group]
239 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
240 | for img in crop_img_group]
241 | return ret_img_group
242 |
243 | def _sample_crop_size(self, im_size):
244 | image_w, image_h = im_size[0], im_size[1]
245 |
246 | # find a crop size
247 | base_size = min(image_w, image_h)
248 | crop_sizes = [int(base_size * x) for x in self.scales]
249 | crop_h = [
250 | self.input_size[1] if abs(
251 | x - self.input_size[1]) < 3 else x for x in crop_sizes]
252 | crop_w = [
253 | self.input_size[0] if abs(
254 | x - self.input_size[0]) < 3 else x for x in crop_sizes]
255 |
256 | pairs = []
257 | for i, h in enumerate(crop_h):
258 | for j, w in enumerate(crop_w):
259 | if abs(i - j) <= self.max_distort:
260 | pairs.append((w, h))
261 |
262 | crop_pair = random.choice(pairs)
263 | if not self.fix_crop:
264 | w_offset = random.randint(0, image_w - crop_pair[0])
265 | h_offset = random.randint(0, image_h - crop_pair[1])
266 | else:
267 | w_offset, h_offset = self._sample_fix_offset(
268 | image_w, image_h, crop_pair[0], crop_pair[1])
269 |
270 | return crop_pair[0], crop_pair[1], w_offset, h_offset
271 |
272 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273 | offsets = self.fill_fix_offset(
274 | self.more_fix_crop, image_w, image_h, crop_w, crop_h)
275 | return random.choice(offsets)
276 |
277 | @staticmethod
278 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
279 | w_step = (image_w - crop_w) // 4
280 | h_step = (image_h - crop_h) // 4
281 |
282 | ret = list()
283 | ret.append((0, 0)) # upper left
284 | ret.append((4 * w_step, 0)) # upper right
285 | ret.append((0, 4 * h_step)) # lower left
286 | ret.append((4 * w_step, 4 * h_step)) # lower right
287 | ret.append((2 * w_step, 2 * h_step)) # center
288 |
289 | if more_fix_crop:
290 | ret.append((0, 2 * h_step)) # center left
291 | ret.append((4 * w_step, 2 * h_step)) # center right
292 | ret.append((2 * w_step, 4 * h_step)) # lower center
293 | ret.append((2 * w_step, 0 * h_step)) # upper center
294 |
295 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
296 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
297 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
298 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
299 |
300 | return ret
301 |
302 |
303 | class GroupRandomSizedCrop(object):
304 |     """Randomly crops the given PIL.Image to a random size of (0.08 to 1.0) of the original size
305 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
306 | This is popularly used to train the Inception networks
307 | size: size of the smaller edge
308 | interpolation: Default: PIL.Image.BILINEAR
309 | """
310 |
311 | def __init__(self, size, interpolation=Image.BILINEAR):
312 | self.size = size
313 | self.interpolation = interpolation
314 |
315 | def __call__(self, img_group):
316 | for attempt in range(10):
317 | area = img_group[0].size[0] * img_group[0].size[1]
318 | target_area = random.uniform(0.08, 1.0) * area
319 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
320 |
321 | w = int(round(math.sqrt(target_area * aspect_ratio)))
322 | h = int(round(math.sqrt(target_area / aspect_ratio)))
323 |
324 | if random.random() < 0.5:
325 | w, h = h, w
326 |
327 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
328 | x1 = random.randint(0, img_group[0].size[0] - w)
329 | y1 = random.randint(0, img_group[0].size[1] - h)
330 | found = True
331 | break
332 | else:
333 | found = False
334 | x1 = 0
335 | y1 = 0
336 |
337 | if found:
338 | out_group = list()
339 | for img in img_group:
340 | img = img.crop((x1, y1, x1 + w, y1 + h))
341 | assert(img.size == (w, h))
342 | out_group.append(
343 | img.resize(
344 | (self.size, self.size), self.interpolation))
345 | return out_group
346 | else:
347 | # Fallback
348 | scale = GroupScale(self.size, interpolation=self.interpolation)
349 | crop = GroupRandomCrop(self.size)
350 | return crop(scale(img_group))
351 |
352 |
353 | class ConvertDataFormat(object):
354 | def __init__(self, model_type):
355 | self.model_type = model_type
356 |
357 | def __call__(self, images):
358 | if self.model_type == '2D':
359 | return images
360 | tc, h, w = images.size()
361 | t = tc // 3
362 | images = images.view(t, 3, h, w)
363 | images = images.permute(1, 0, 2, 3)
364 | return images
365 |
366 |
367 | class Stack(object):
368 |
369 | def __init__(self, roll=False):
370 | self.roll = roll
371 |
372 | def __call__(self, img_group):
373 | if img_group[0].mode == 'L':
374 | return np.concatenate([np.expand_dims(x, 2)
375 | for x in img_group], axis=2)
376 | elif img_group[0].mode == 'RGB':
377 | if self.roll:
378 | return np.concatenate([np.array(x)[:, :, ::-1]
379 | for x in img_group], axis=2)
380 | else:
381 | #print(np.concatenate(img_group, axis=2).shape)
382 | # print(img_group[0].shape)
383 | return np.concatenate(img_group, axis=2)
384 |
385 |
386 | class ToTorchFormatTensor(object):
387 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
388 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
389 |
390 | def __init__(self, div=True):
391 | self.div = div
392 |
393 | def __call__(self, pic):
394 | if isinstance(pic, np.ndarray):
395 | # handle numpy array
396 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
397 | else:
398 | # handle PIL Image
399 | img = torch.ByteTensor(
400 | torch.ByteStorage.from_buffer(
401 | pic.tobytes()))
402 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
403 | # put it from HWC to CHW format
404 | # yikes, this transpose takes 80% of the loading time/CPU
405 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
406 | return img.float().div(255) if self.div else img.float()
407 |
408 |
409 | class IdentityTransform(object):
410 |
411 | def __call__(self, data):
412 | return data
413 |
414 |
415 | if __name__ == "__main__":
416 | trans = torchvision.transforms.Compose([
417 | GroupScale(256),
418 | GroupRandomCrop(224),
419 | Stack(),
420 | ToTorchFormatTensor(),
421 | GroupNormalize(
422 | mean=[.485, .456, .406],
423 | std=[.229, .224, .225]
424 | )]
425 | )
426 |
427 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
428 |
429 | color_group = [im] * 3
430 | rst = trans(color_group)
431 |
432 | gray_group = [im.convert('L')] * 9
433 | gray_rst = trans(gray_group)
434 |
435 | trans2 = torchvision.transforms.Compose([
436 | GroupRandomSizedCrop(256),
437 | Stack(),
438 | ToTorchFormatTensor(),
439 | GroupNormalize(
440 | mean=[.485, .456, .406],
441 | std=[.229, .224, .225])
442 | ])
443 | print(trans2(color_group))
444 |
--------------------------------------------------------------------------------
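These transforms operate on a list of PIL frames and return one stacked tensor. The `__main__` block above already demonstrates the 2D pipeline; the piece it does not exercise is `ConvertDataFormat`, which reshapes the channel-stacked `(T*3, H, W)` tensor into the `(3, T, H, W)` layout expected by 3D backbones. A sketch on synthetic frames, assuming it is run from the `demos` directory so the module imports directly:

import numpy as np
import torchvision
from PIL import Image
from video_transforms import (GroupScale, GroupCenterCrop, Stack,
                              ToTorchFormatTensor, GroupNormalize, ConvertDataFormat)

# eight synthetic 320x240 RGB frames standing in for a decoded clip
frames = [Image.fromarray(np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8))
          for _ in range(8)]

pipeline = torchvision.transforms.Compose([
    GroupScale(256),        # resize the shorter side of every frame to 256
    GroupCenterCrop(224),   # take the same 224x224 center crop from each frame
    Stack(),                # concatenate frames along the channel axis -> (H, W, T*3)
    ToTorchFormatTensor(),  # to a float tensor in [0, 1], shape (T*3, H, W)
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ConvertDataFormat('3D') # reshape to (3, T, H, W) for 3D-conv style models
])

clip = pipeline(frames)
print(clip.shape)  # torch.Size([3, 8, 224, 224])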
/engine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import sys
4 | from typing import Iterable
5 | import util.misc as misc
6 | import util.lr_sched as lr_sched
7 | import json
8 | import random
9 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, epoch: int, loss_scaler, args=None):
10 | model.train(True)
11 | metric_logger = misc.MetricLogger(delimiter=" ")
12 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
13 | header = 'Epoch: [{}]'.format(epoch)
14 | print_freq = int(len(data_loader) / 50)
15 | accum_iter = args.accum_iter
16 |
17 | optimizer.zero_grad()
18 |
19 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
20 |
21 | if data_iter_step % accum_iter == 0:
22 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
23 | mode='vqa'
24 | if args.video_caption:
25 | if random.random()>0.5:
26 | mode='caption'
27 | if args.webvid:
28 | mode = 'caption'
29 | vqa_loss, vaq_loss, qav_loss = model(data,mode=mode)
30 | # print(vqa_loss, vaq_loss, qav_loss)
31 | loss = vqa_loss + vaq_loss*args.weight_captioning + qav_loss
32 | loss_value = loss.item()
33 | vqa_loss_value = vqa_loss.item()
34 | vaq_loss_value = vaq_loss.item()
35 | qav_loss_value = qav_loss.item()
36 |
37 |
38 | loss = loss / accum_iter
39 |
40 | loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0)
41 | if (data_iter_step + 1) % accum_iter == 0:
42 | optimizer.zero_grad()
43 |
44 | torch.cuda.synchronize()
45 |
46 | metric_logger.update(loss=loss_value)
47 | metric_logger.update(vqa_loss=vqa_loss_value)
48 | metric_logger.update(vaq_loss=vaq_loss_value)
49 | metric_logger.update(qav_loss=qav_loss_value)
50 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
51 |
52 | # gather the stats from all processes
53 | metric_logger.synchronize_between_processes()
54 | print("Averaged stats:", metric_logger)
55 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
56 |
57 |
58 | def val_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, epoch: int, args=None):
59 | model.eval()
60 |
61 | metric_logger = misc.MetricLogger(delimiter=" ")
62 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
63 | header = 'Epoch: [{}]'.format(epoch)
64 | print_freq = int(len(data_loader) / 10)
65 | results = {}
66 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
67 | answer = data['answer'].cuda()
68 | bsz = answer.shape[0]
69 | if args.openvqa_eval:
70 | mode = 'caption'
71 | else:
72 | mode = 'vqa'
73 | with torch.no_grad():
74 | logits = model(data, inference=True,mode=mode)
75 | count = (logits != 0).sum(-1)
76 | prediction = (logits.sum(-1) / count).argmin(-1)
77 |
78 | eval = (answer == prediction)
79 | acc = eval.sum().item() / bsz
80 |
81 | misc.log_qtype(data, eval, metric_logger, args)
82 |
83 | lr = optimizer.param_groups[0]["lr"]
84 | metric_logger.update(lr=lr)
85 | metric_logger.update(n=bsz, acc=acc)
86 |
87 | metric_logger.synchronize_between_processes()
88 | print("Averaged stats:", metric_logger)
89 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
90 |
91 |
92 | def test_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, epoch: int, args=None):
93 | model.eval()
94 | metric_logger = misc.MetricLogger(delimiter=" ")
95 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
96 | header = 'Epoch: [{}]'.format(epoch)
97 | print_freq = int(len(data_loader) /10)
98 | results = {}
99 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
100 | answer = data['answer'].cuda()
101 | bsz = answer.shape[0]
102 | if args.openvqa_eval:
103 | mode = 'caption'
104 | else:
105 | mode = 'vqa'
106 | with torch.no_grad():
107 | logits = model(data, inference=True,mode=mode)
108 |
109 | count = (logits != 0).sum(-1)
110 | prediction = (logits.sum(-1) / count).argmin(-1)
111 |
112 | results[data['qid'][0]]=prediction.item()
113 |
114 | eval = (answer == prediction)
115 | acc = eval.sum().item() / bsz
116 |
117 | misc.log_qtype(data, eval, metric_logger, args)
118 |
119 | lr = optimizer.param_groups[0]["lr"]
120 | metric_logger.update(lr=lr)
121 | metric_logger.update(n=bsz, acc=acc)
122 |
123 | # gather the stats from all processes
124 | with open(f'{args.output_dir}/egos.json','w') as f:
125 | json.dump(results,f)
126 | metric_logger.synchronize_between_processes()
127 | print("Averaged stats:", metric_logger)
128 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
129 |
--------------------------------------------------------------------------------
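In `val_one_epoch` and `test_one_epoch`, the model returns per-token loss values for every candidate answer (cross-entropy with zeros at masked positions); the predicted choice is the option whose mean non-zero token loss is lowest, which is why `argmin` rather than `argmax` is applied. A small numeric sketch of that scoring step, with made-up loss values:

import torch

# (batch=1, n_options=5, seq_len=6): per-token losses, 0.0 where the label was masked out
logits = torch.tensor([[[2.1, 1.9, 0.0, 0.0, 0.0, 0.0],
                        [0.8, 0.7, 0.9, 0.0, 0.0, 0.0],
                        [3.0, 2.5, 0.0, 0.0, 0.0, 0.0],
                        [1.5, 1.6, 1.4, 1.7, 0.0, 0.0],
                        [2.2, 2.4, 0.0, 0.0, 0.0, 0.0]]])

count = (logits != 0).sum(-1)                      # number of scored tokens per option
prediction = (logits.sum(-1) / count).argmin(-1)   # option with the lowest mean token loss
print(prediction)  # tensor([1])  -> choice (B)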
/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from .generation import LLaMA
5 | from .model import ModelArgs, Transformer
6 | from .tokenizer import Tokenizer
7 | from .tokenizer_llama3 import Tokenizer_llama3
8 | from .model_llama3 import Transformer_llama3,ModelArgs_llama3
9 |
--------------------------------------------------------------------------------
/llama/generation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import List
5 |
6 | import torch
7 |
8 | from llama.tokenizer import Tokenizer
9 | from llama.model import Transformer
10 |
11 |
12 | class LLaMA:
13 | def __init__(self, model: Transformer, tokenizer: Tokenizer):
14 | self.model = model
15 | self.tokenizer = tokenizer
16 |
17 | def generate(self, prompts: List[str], max_gen_len: int, temperature: float = 0.8, top_p: float = 0.95,) -> List[str]:
18 | bsz = len(prompts)
19 | params = self.model.params
20 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
21 | prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
22 |
23 | min_prompt_size = min([len(t) for t in prompt_tokens])
24 | max_prompt_size = max([len(t) for t in prompt_tokens])
25 |
26 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
27 |
28 | tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
29 | for k, t in enumerate(prompt_tokens):
30 | tokens[k, : len(t)] = torch.tensor(t).long()
31 | input_text_mask = tokens != self.tokenizer.pad_id
32 | start_pos = min_prompt_size
33 | prev_pos = 0
34 | for cur_pos in range(start_pos, total_len):
35 | logits = self.model.inference(None, tokens[:, prev_pos:cur_pos], prev_pos)
36 | if temperature > 0:
37 | probs = torch.softmax(logits / temperature, dim=-1)
38 | next_token = sample_top_p(probs, top_p)
39 | else:
40 | next_token = torch.argmax(logits, dim=-1)
41 | next_token = next_token.reshape(-1)
42 | # only replace token if prompt has already been generated
43 | next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
44 | tokens[:, cur_pos] = next_token
45 | prev_pos = cur_pos
46 |
47 | decoded = []
48 | for i, t in enumerate(tokens.tolist()):
49 | # cut to max gen len
50 | t = t[: len(prompt_tokens[i]) + max_gen_len]
51 | # cut to eos tok if any
52 | try:
53 | t = t[: t.index(self.tokenizer.eos_id)]
54 | except ValueError:
55 | pass
56 | decoded.append(self.tokenizer.decode(t))
57 | return decoded
58 |
59 |
60 | def sample_top_p(probs, p):
61 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
62 | probs_sum = torch.cumsum(probs_sort, dim=-1)
63 | mask = probs_sum - probs_sort > p
64 | probs_sort[mask] = 0.0
65 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
66 | next_token = torch.multinomial(probs_sort, num_samples=1)
67 | next_token = torch.gather(probs_idx, -1, next_token)
68 | return next_token
69 |
--------------------------------------------------------------------------------
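`sample_top_p` implements nucleus sampling: the distribution is sorted, every token outside the smallest set of highest-probability tokens whose total mass exceeds `p` is zeroed, the remainder is renormalized, and one token is drawn. A self-contained toy example (the helper is repeated here rather than imported, since importing `llama.generation` pulls in the full model dependencies):

import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus sampling, mirroring llama/generation.py: keep the smallest set of
    top tokens whose cumulative probability exceeds p, renormalize, sample one."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p          # True for tokens outside the nucleus
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)

torch.manual_seed(0)
probs = torch.tensor([[0.50, 0.30, 0.15, 0.04, 0.01]])  # toy next-token distribution
# with p=0.75 only token ids 0 and 1 can ever be drawn: their mass (0.8) already covers p
print([sample_top_p(probs, p=0.75).item() for _ in range(5)])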
/llama/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Optional, Tuple
5 | from dataclasses import dataclass
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | from torch.nn import Embedding, Linear
13 | import torch
14 | import pickle
15 | import timm.models.hub as timm_hub
16 |
17 |
18 | @dataclass
19 | class ModelArgs:
20 | dim: int = 512
21 | n_layers: int = 8
22 | n_heads: int = 8
23 | vocab_size: int = -1 # defined later by tokenizer
24 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
25 | norm_eps: float = 1e-5
26 | n_kv_heads: int=1
27 | max_batch_size: int = 32
28 | max_seq_len: int = 2048
29 | adapter_len: int=10
30 | adapter_layer: int=30
31 |
32 |
33 | class RMSNorm(torch.nn.Module):
34 | def __init__(self, dim: int, eps: float = 1e-6):
35 | super().__init__()
36 | self.eps = eps
37 | self.weight = nn.Parameter(torch.ones(dim))
38 |
39 | def _norm(self, x):
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | output = self._norm(x.float()).type_as(x)
44 | return output * self.weight
45 |
46 |
47 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
49 | t = torch.arange(end, device=freqs.device) # type: ignore
50 | freqs = torch.outer(t, freqs).float() # type: ignore
51 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
52 | return freqs_cis
53 |
54 |
55 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
56 | ndim = x.ndim
57 | assert 0 <= 1 < ndim
58 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
59 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
60 | return freqs_cis.view(*shape)
61 |
62 |
63 | def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
64 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
65 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
66 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
67 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
68 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
69 | return xq_out.type_as(xq), xk_out.type_as(xk)
70 |
71 |
72 | class Attention(nn.Module):
73 | def __init__(self, args: ModelArgs):
74 | super().__init__()
75 | self.n_local_heads = args.n_heads
76 | self.head_dim = args.dim // args.n_heads
77 | self.max_feats = args.max_feats
78 |
79 | self.wq = Linear(args.dim, args.n_heads * self.head_dim, bias=False)
80 | self.wk = Linear(args.dim, args.n_heads * self.head_dim, bias=False)
81 | self.wv = Linear(args.dim, args.n_heads * self.head_dim, bias=False)
82 | self.wo = Linear(args.n_heads * self.head_dim, args.dim, bias=False)
83 |
84 | self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()
85 | self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()
86 | self.gate1 = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
87 | self.gate2 = torch.nn.Parameter(torch.ones(1, self.n_local_heads, 1, 1) * -args.bias)
88 | self.llamatype = args.llamatype
89 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None):
90 | bsz, seqlen, _ = x.shape
91 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
92 |
93 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
94 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
95 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
96 |
97 |
98 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
99 | if adapter is not None:
100 | adapter_len = adapter.shape[1]
101 | adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
102 | adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
103 | xk = torch.cat([adapter_k, xk], dim=1)
104 | xv = torch.cat([adapter_v, xv], dim=1)
105 | extra_mask = torch.zeros(1, 1, seqlen, adapter_len).to(mask)
106 | mask = torch.cat([extra_mask, mask], dim=-1)
107 | keys = xk
108 | values = xv
109 |
110 | xq = xq.transpose(1, 2)
111 | keys = keys.transpose(1, 2)
112 | values = values.transpose(1, 2)
113 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
114 | if mask is not None:
115 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
116 | if adapter is not None:
117 | adapter_scores = F.softmax(scores[..., :adapter_len].float(), dim=-1).type_as(xq) * self.gate1.tanh().type(self.llamatype)
118 | if video_start is not None:
119 | vt_scores = scores[..., adapter_len:].clone()
120 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] = \
121 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] + self.gate2.type(self.llamatype)
122 | vt_scores = F.softmax(vt_scores.float(), dim=-1).type_as(xq)
123 | else:
124 | vt_scores = F.softmax(scores[..., adapter_len:], dim=-1)
125 | scores = torch.cat([adapter_scores, vt_scores], dim=-1)
126 | else:
127 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
128 | output = torch.matmul(scores, values)
129 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
130 | return self.wo(output)
131 |
132 |
133 | class FeedForward(nn.Module):
134 | def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
135 | super().__init__()
136 | hidden_dim = int(2 * hidden_dim / 3)
137 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
138 |
139 | self.w1 = Linear(dim, hidden_dim, bias=False)
140 | self.w2 = Linear(hidden_dim, dim, bias=False)
141 | self.w3 = Linear(dim, hidden_dim, bias=False)
142 |
143 | def forward(self, x):
144 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
145 |
146 |
147 | class TransformerBlock(nn.Module):
148 | def __init__(self, layer_id: int, args: ModelArgs):
149 | super().__init__()
150 | self.n_heads = args.n_heads
151 | self.dim = args.dim
152 | self.head_dim = args.dim // args.n_heads
153 | self.attention = Attention(args)
154 | self.feed_forward = FeedForward(dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of)
155 | self.layer_id = layer_id
156 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
157 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
158 |
159 |
160 |
161 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None):
162 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter, video_start)
163 | out = h + self.feed_forward.forward(self.ffn_norm(h))
164 | return out
165 |
166 |
167 | class Transformer(nn.Module):
168 | def __init__(self, params: ModelArgs, args):
169 | super().__init__()
170 | params.max_feats = args.max_feats
171 | params.bias = args.bias
172 | self.args = args
173 |
174 | if self.args.llama2:
175 | self.llamatype = torch.bfloat16
176 | params.llamatype = torch.bfloat16
177 | else:
178 | self.llamatype = torch.half
179 | params.llamatype = torch.half
180 |
181 | self.params = params
182 | self.vocab_size = params.vocab_size
183 | self.n_layers = params.n_layers
184 | self.max_feats = args.max_feats
185 |
186 |
187 | self.tok_embeddings = Embedding(params.vocab_size, params.dim)
188 |
189 | self.adapter_query = Embedding(params.adapter_len * params.adapter_layer, params.dim)
190 |
191 | clip_feature_dim=768
192 | self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False)
193 | self.temporal_emb = Embedding(self.max_feats, params.dim)
194 | self.adapter_len = params.adapter_len
195 | self.adapter_layer = params.adapter_layer
196 |
197 | self.vqa_criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
198 | self.vaq_criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
199 |
200 | self.inference_criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none')
201 |
202 | self.layers = torch.nn.ModuleList()
203 | for layer_id in range(params.n_layers):
204 | self.layers.append(TransformerBlock(layer_id, params))
205 |
206 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
207 | self.output = Linear(params.dim, params.vocab_size, bias=False)
208 |
209 | self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
210 |
211 | self.video_label = torch.arange(1, self.max_feats)
212 | self.tau = args.tau
213 | if self.args.memory:
214 | with open('./data/textvid/memory.pkl','rb') as f:
215 | self.memory = pickle.load(f).float()[:1000000].cuda()
216 |
217 |         self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False)  # note: re-creates visual_proj, overriding the layer defined earlier in __init__
218 |
219 |
220 | def re_init_freqs(self,max_seq_len):
221 | self.params.max_seq_len=max_seq_len
222 | self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
223 |
224 | def forward(self, data, inference=False,mode='vqa'):
225 |
226 | video = data['video'].cuda()
227 | # video = video/video.norm(dim=-1,keepdim=True)
228 | if self.args.memory and inference==True:
229 | sim = video@self.memory.T
230 |
231 | sim = (sim*100).softmax(dim=-1)
232 | video = sim@self.memory
233 | video = video/video.norm(dim=-1,keepdim=True)
234 | if self.args.onlyqa:
235 | video=video*0
236 |
237 |
238 | vqa_id, vaq_id, qav_id = data['text_id']['vqa'].cuda(), data['text_id']['vaq'].cuda(), data['text_id']['qav'].cuda()
239 | vqa_label, vaq_label, qav_label = data['label']['vqa'].cuda(), data['label']['vaq'].cuda(), data['label']['qav'].cuda()
240 | vqa_video_start, vaq_video_start, qav_video_index = data['video_start']['vqa'][0], data['video_start']['vaq'][0], data['video_index']['qav'].cuda()
241 |
242 | bsz, n_options, seqlen = vqa_id.shape
243 | vqa_id, vaq_id = vqa_id.reshape(-1, seqlen), vaq_id.reshape(-1, seqlen)
244 | vqa_label, vaq_label = vqa_label.reshape(-1, seqlen), vaq_label.reshape(-1, seqlen)
245 | vqa_label, vaq_label = vqa_label[:, 1:].flatten(), vaq_label[:, 1:].flatten()
246 |
247 | qav_id = qav_id.reshape(-1, seqlen)
248 | qav_label = qav_label.reshape(-1, seqlen)
249 | qav_video_mask = qav_label.ge(0)
250 | qav_label = qav_label[:, 1:].flatten()
251 |
252 |
253 | with torch.no_grad():
254 | vqa_h = self.tok_embeddings(vqa_id)
255 |
256 | if self.args.vaq and not inference:
257 | vaq_h = self.tok_embeddings(vaq_id)
258 | if self.args.openvqa_eval:
259 | vaq_h = self.tok_embeddings(vaq_id)
260 | if self.args.qav and not inference:
261 | qav_h = self.tok_embeddings(qav_id)
262 |
263 | freqs_cis = self.freqs_cis.to(vqa_h.device)
264 | freqs_cis = freqs_cis[:seqlen]
265 | mask = None
266 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=vqa_h.device)
267 | mask = torch.triu(mask, diagonal=0 + 1).type_as(vqa_h)
268 | start_pos = 0
269 | vqa_loss, vaq_loss, qav_loss = torch.tensor([0]).cuda(),torch.tensor([0]).cuda(), torch.tensor([0]).cuda()
270 |
271 | adapter = self.adapter_query.weight.reshape(-1, self.adapter_len, self.params.dim).unsqueeze(1)
272 |
273 | _video_feature = self.visual_proj(video)
274 |
275 | if inference:
276 | _video_feature = _video_feature.unsqueeze(1).repeat(1, n_options, 1, 1).view(-1, _video_feature.shape[-2], _video_feature.shape[-1])
277 |
278 | video_feature = (_video_feature + self.temporal_emb.weight[None, :, :]).type(self.llamatype)
279 |
280 |
281 | if mode == 'vqa':
282 |
283 | vqa_h = vqa_h.clone()
284 | vqa_h[:, vqa_video_start:vqa_video_start+self.max_feats] = video_feature
285 |
286 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]):
287 | vqa_h = layer(vqa_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vqa_video_start)
288 |
289 | vqa_h = self.norm(vqa_h)
290 | vqa_output = self.output(vqa_h)
291 | vqa_output = vqa_output[:, :-1, :].reshape(-1, self.vocab_size)
292 | vqa_loss = self.vqa_criterion(vqa_output, vqa_label)
293 |
294 | if mode =='caption':
295 | vaq_h = vaq_h.clone()
296 | vaq_h[:, vaq_video_start:vaq_video_start+self.max_feats] = video_feature
297 |
298 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]):
299 | vaq_h = layer(vaq_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vaq_video_start)
300 | vaq_h = self.norm(vaq_h)
301 | vaq_output = self.output(vaq_h)
302 | vaq_output = vaq_output[:, :-1, :].reshape(-1, self.vocab_size)
303 | vaq_loss = self.vaq_criterion(vaq_output, vaq_label)
304 |
305 | if inference:
306 | if self.args.openvqa_eval:
307 | logits = self.inference_criterion(vaq_output, vaq_label)
308 | logits = logits.reshape(bsz, n_options, -1)
309 | else:
310 | logits = self.inference_criterion(vqa_output, vqa_label)
311 | logits = logits.reshape(bsz, n_options, -1)
312 | return logits
313 | else:
314 | return vqa_loss, vaq_loss, qav_loss
315 |
--------------------------------------------------------------------------------
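The `Attention.forward` above prepends learned adapter tokens to the keys/values and splits the attention scores into two branches: the adapter columns are softmaxed separately and scaled by `tanh(gate1)` (zero-initialized, so the adapter contributes nothing at the start of training), while `gate2` adds a learned bias to text-to-video attention before the main softmax. A minimal sketch of that scoring rule, outside the repo's classes and with assumed tensor shapes:

```python
import math
import torch
import torch.nn.functional as F

def gated_adapter_attention(xq, keys, values, mask, adapter_len, gate1, gate2,
                            video_start, max_feats):
    # xq: (bsz, n_heads, seqlen, head_dim); keys/values: (bsz, n_heads,
    # adapter_len + seqlen, head_dim); mask: additive causal mask whose
    # adapter columns are already zero; gate1/gate2: (1, n_heads, 1, 1).
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(xq.shape[-1])
    scores = scores + mask
    # Adapter branch: separate softmax, scaled by tanh of a zero-init gate.
    adapter_scores = F.softmax(scores[..., :adapter_len].float(), dim=-1) * gate1.tanh()
    # Text/video branch: bias attention from text positions (after the video
    # slots) onto the video slots, then softmax.
    vt_scores = scores[..., adapter_len:].clone()
    vt_scores[:, :, video_start + max_feats:, video_start:video_start + max_feats] += gate2
    vt_scores = F.softmax(vt_scores.float(), dim=-1)
    scores = torch.cat([adapter_scores, vt_scores], dim=-1).type_as(values)
    return torch.matmul(scores, values)
```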
/llama/model_llama3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Optional, Tuple
5 | from dataclasses import dataclass
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | from torch.nn import Embedding, Linear
13 | import torch
14 | import pickle
15 | import timm.models.hub as timm_hub
16 | from fairscale.nn.model_parallel.layers import (
17 | ColumnParallelLinear,
18 | RowParallelLinear,
19 | VocabParallelEmbedding,
20 | )
21 |
22 |
23 | @dataclass
24 | class ModelArgs_llama3:
25 | dim: int = 4096
26 | n_layers: int = 32
27 | n_heads: int = 32
28 | n_kv_heads: Optional[int] = None
29 | vocab_size: int = -1
30 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
31 | ffn_dim_multiplier: Optional[float] = None
32 | norm_eps: float = 1e-5
33 | rope_theta: float = 500000
34 |
35 | max_batch_size: int = 32
36 | max_seq_len: int = 2048
37 |
38 | adapter_len: int=10
39 | adapter_layer: int=30
40 |
41 |
42 | class RMSNorm(torch.nn.Module):
43 | def __init__(self, dim: int, eps: float = 1e-6):
44 | super().__init__()
45 | self.eps = eps
46 | self.weight = nn.Parameter(torch.ones(dim))
47 |
48 | def _norm(self, x):
49 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
50 |
51 | def forward(self, x):
52 | output = self._norm(x.float()).type_as(x)
53 | return output * self.weight
54 |
55 |
56 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
57 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
58 | t = torch.arange(end, device=freqs.device, dtype=torch.float32)
59 | freqs = torch.outer(t, freqs)
60 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
61 | return freqs_cis
62 |
63 |
64 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
65 | ndim = x.ndim
66 | assert 0 <= 1 < ndim
67 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
68 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
69 | return freqs_cis.view(*shape)
70 |
71 |
72 | def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
73 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
74 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
75 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
76 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
77 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
78 | return xq_out.type_as(xq), xk_out.type_as(xk)
79 |
80 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
81 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
82 | bs, slen, n_kv_heads, head_dim = x.shape
83 | if n_rep == 1:
84 | return x
85 | return (
86 | x[:, :, :, None, :]
87 | .expand(bs, slen, n_kv_heads, n_rep, head_dim)
88 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
89 | )
90 |
91 |
92 | class Attention(nn.Module):
93 | def __init__(self, args: ModelArgs_llama3):
94 | super().__init__()
95 | self.n_local_heads = args.n_heads
96 | self.head_dim = args.dim // args.n_heads
97 | self.max_feats = args.max_feats
98 | self.n_kv_heads = args.n_kv_heads
99 | model_parallel_size=1
100 | self.n_local_heads = args.n_heads // model_parallel_size
101 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
102 | self.n_rep = self.n_local_heads // self.n_local_kv_heads
103 |
104 | self.wq = Linear(args.dim, args.n_heads * self.head_dim, bias=False)
105 | self.wk = Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
106 | self.wv = Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
107 | self.wo = Linear(args.n_heads * self.head_dim, args.dim, bias=False)
108 |
109 | self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()
110 | self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()
111 | self.gate1 = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
112 | self.gate2 = torch.nn.Parameter(torch.ones(1, self.n_local_heads, 1, 1) * -args.bias)
113 | self.llamatype = args.llamatype
114 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None):
115 | bsz, seqlen, _ = x.shape
116 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
117 |
118 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
119 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
120 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
121 |
122 |
123 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
124 | if adapter is not None:
125 | adapter_len = adapter.shape[1]
126 | adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_kv_heads, self.head_dim).repeat(bsz, 1, 1, 1)
127 | adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_kv_heads, self.head_dim).repeat(bsz, 1, 1, 1)
128 | xk = torch.cat([adapter_k, xk], dim=1)
129 | xv = torch.cat([adapter_v, xv], dim=1)
130 | extra_mask = torch.zeros(1, 1, seqlen, adapter_len).to(mask)
131 | mask = torch.cat([extra_mask, mask], dim=-1)
132 | keys = xk
133 | values = xv
134 | keys = repeat_kv(
135 | keys, self.n_rep
136 | ) # (bs, cache_len + seqlen, n_local_heads, head_dim)
137 | values = repeat_kv(
138 | values, self.n_rep
139 | ) # (bs, cache_len + seqlen, n_local_heads, head_dim)
140 |
141 | xq = xq.transpose(1, 2)
142 | keys = keys.transpose(1, 2)
143 | values = values.transpose(1, 2)
144 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
145 | if mask is not None:
146 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
147 | if adapter is not None:
148 | adapter_scores = F.softmax(scores[..., :adapter_len].float(), dim=-1).type_as(xq) * self.gate1.tanh().type(self.llamatype)
149 | if video_start is not None:
150 | vt_scores = scores[..., adapter_len:].clone()
151 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] = \
152 | vt_scores[:, :, video_start + self.max_feats:, video_start:video_start + self.max_feats] + self.gate2.type(self.llamatype)
153 | vt_scores = F.softmax(vt_scores.float(), dim=-1).type_as(xq)
154 | else:
155 | vt_scores = F.softmax(scores[..., adapter_len:], dim=-1)
156 | scores = torch.cat([adapter_scores, vt_scores], dim=-1)
157 | else:
158 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
159 | output = torch.matmul(scores, values)
160 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
161 | return self.wo(output)
162 |
163 |
164 | class FeedForward(nn.Module):
165 | def __init__(
166 | self,
167 | dim: int,
168 | hidden_dim: int,
169 | multiple_of: int,
170 | ffn_dim_multiplier: Optional[float],
171 | ):
172 | super().__init__()
173 | hidden_dim = int(2 * hidden_dim / 3)
174 | # custom dim factor multiplier
175 | if ffn_dim_multiplier is not None:
176 | hidden_dim = int(ffn_dim_multiplier * hidden_dim)
177 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
178 |
179 | self.w1 = Linear(dim, hidden_dim, bias=False)
180 | self.w2 = Linear(hidden_dim, dim, bias=False)
181 | self.w3 = Linear(dim, hidden_dim, bias=False)
182 |
183 | def forward(self, x):
184 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
185 |
186 |
187 | class TransformerBlock(nn.Module):
188 | def __init__(self, layer_id: int, args: ModelArgs_llama3):
189 | super().__init__()
190 | self.n_heads = args.n_heads
191 | self.dim = args.dim
192 | self.head_dim = args.dim // args.n_heads
193 | self.attention = Attention(args)
194 | self.feed_forward = FeedForward(
195 | dim=args.dim,
196 | hidden_dim=4 * args.dim,
197 | multiple_of=args.multiple_of,
198 | ffn_dim_multiplier=args.ffn_dim_multiplier,
199 | )
200 | self.layer_id = layer_id
201 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
202 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
203 |
204 |
205 |
206 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, video_start=None):
207 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter, video_start)
208 | out = h + self.feed_forward.forward(self.ffn_norm(h))
209 | return out
210 |
211 |
212 | class Transformer_llama3(nn.Module):
213 | def __init__(self, params: ModelArgs_llama3, args):
214 | super().__init__()
215 | params.max_feats = args.max_feats
216 | params.bias = args.bias
217 | self.args = args
218 |
219 | if self.args.llama2:
220 | self.llamatype = torch.bfloat16
221 | params.llamatype = torch.bfloat16
222 | else:
223 | self.llamatype = torch.half
224 | params.llamatype = torch.half
225 |
226 | self.params = params
227 | self.vocab_size = params.vocab_size
228 | self.n_layers = params.n_layers
229 | self.max_feats = args.max_feats
230 |
231 |
232 | self.tok_embeddings = Embedding(params.vocab_size, params.dim)
233 | self.adapter_query = Embedding(params.adapter_len * params.adapter_layer, params.dim)
234 |
235 | clip_feature_dim=768
236 | self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False)
237 | self.temporal_emb = Embedding(self.max_feats, params.dim)
238 | self.adapter_len = params.adapter_len
239 | self.adapter_layer = params.adapter_layer
240 |
241 | self.vqa_criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
242 | self.vaq_criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
243 | self.qav_criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
244 | self.inference_criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none')
245 |
246 | self.layers = torch.nn.ModuleList()
247 | for layer_id in range(params.n_layers):
248 | self.layers.append(TransformerBlock(layer_id, params))
249 |
250 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
251 | self.output = Linear(
252 | params.dim, params.vocab_size, bias=False
253 | )
254 |
255 | self.freqs_cis = precompute_freqs_cis(
256 | params.dim // params.n_heads,
257 | params.max_seq_len * 2,
258 | params.rope_theta,
259 | )
260 |
261 | self.video_label = torch.arange(1, self.max_feats)
262 | self.tau = args.tau
263 | if self.args.memory:
264 | with open('./data/textvid/memory.pkl','rb') as f:
265 | self.memory = pickle.load(f).float()[:1000000].cuda()
266 |
267 |         self.visual_proj = Linear(clip_feature_dim, params.dim, bias=False)  # note: re-creates visual_proj, overriding the layer defined earlier in __init__
268 |
269 |
270 |
271 |
272 | def re_init_freqs(self,max_seq_len):
273 | self.params.max_seq_len=max_seq_len
274 | self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
275 |
276 | def forward(self, data, inference=False,mode='vqa'):
277 |
278 | video = data['video'].cuda()
279 | # video = video/video.norm(dim=-1,keepdim=True)
280 | if self.args.memory and inference==True:
281 | sim = video@self.memory.T
282 |
283 | sim = (sim*100).softmax(dim=-1)
284 | video = sim@self.memory
285 | video = video/video.norm(dim=-1,keepdim=True)
286 | if self.args.onlyqa:
287 | video=video*0
288 |
289 | vqa_id, vaq_id, qav_id = data['text_id']['vqa'].cuda(), data['text_id']['vaq'].cuda(), data['text_id']['qav'].cuda()
290 | vqa_label, vaq_label, qav_label = data['label']['vqa'].cuda(), data['label']['vaq'].cuda(), data['label']['qav'].cuda()
291 | vqa_video_start, vaq_video_start, qav_video_index = data['video_start']['vqa'][0], data['video_start']['vaq'][0], data['video_index']['qav'].cuda()
292 |
293 | bsz, n_options, seqlen = vqa_id.shape
294 | vqa_id, vaq_id = vqa_id.reshape(-1, seqlen), vaq_id.reshape(-1, seqlen)
295 | vqa_label, vaq_label = vqa_label.reshape(-1, seqlen), vaq_label.reshape(-1, seqlen)
296 | vqa_label, vaq_label = vqa_label[:, 1:].flatten(), vaq_label[:, 1:].flatten()
297 |
298 | qav_id = qav_id.reshape(-1, seqlen)
299 | qav_label = qav_label.reshape(-1, seqlen)
300 | qav_video_mask = qav_label.ge(0)
301 | qav_label = qav_label[:, 1:].flatten()
302 |
303 |
304 | with torch.no_grad():
305 | vqa_h = self.tok_embeddings(vqa_id)
306 |
307 | if self.args.vaq and not inference:
308 | vaq_h = self.tok_embeddings(vaq_id)
309 | if self.args.openvqa_eval:
310 | vaq_h = self.tok_embeddings(vaq_id)
311 | if self.args.qav and not inference:
312 | qav_h = self.tok_embeddings(qav_id)
313 |
314 | freqs_cis = self.freqs_cis.to(vqa_h.device)
315 | freqs_cis = freqs_cis[:seqlen]
316 | mask = None
317 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=vqa_h.device)
318 | mask = torch.triu(mask, diagonal=0 + 1).type_as(vqa_h)
319 | start_pos = 0
320 | vqa_loss, vaq_loss, qav_loss = torch.tensor([0]).cuda(),torch.tensor([0]).cuda(), torch.tensor([0]).cuda()
321 |
322 | adapter = self.adapter_query.weight.reshape(-1, self.adapter_len, self.params.dim).unsqueeze(1)
323 |
324 | _video_feature = self.visual_proj(video)
325 |
326 | if inference:
327 | _video_feature = _video_feature.unsqueeze(1).repeat(1, n_options, 1, 1).view(-1, _video_feature.shape[-2], _video_feature.shape[-1])
328 | video_feature = (_video_feature + self.temporal_emb.weight[None, :, :]).type(self.llamatype)
329 |
330 |
331 | if mode == 'vqa':
332 |
333 | vqa_h = vqa_h.clone()
334 | vqa_h[:, vqa_video_start:vqa_video_start+self.max_feats] = video_feature
335 |
336 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]):
337 | vqa_h = layer(vqa_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vqa_video_start)
338 |
339 | vqa_h = self.norm(vqa_h)
340 | vqa_output = self.output(vqa_h)
341 | vqa_output = vqa_output[:, :-1, :].reshape(-1, self.vocab_size)
342 | vqa_loss = self.vqa_criterion(vqa_output, vqa_label)
343 |
344 | if mode =='caption':
345 | vaq_h = vaq_h.clone()
346 | vaq_h[:, vaq_video_start:vaq_video_start+self.max_feats] = video_feature
347 |
348 | for i, layer in enumerate(self.layers[-1 * self.adapter_layer:]):
349 | vaq_h = layer(vaq_h, start_pos, freqs_cis, mask, adapter[i].type(self.llamatype), vaq_video_start)
350 | vaq_h = self.norm(vaq_h)
351 | vaq_output = self.output(vaq_h)
352 | vaq_output = vaq_output[:, :-1, :].reshape(-1, self.vocab_size)
353 | vaq_loss = self.vaq_criterion(vaq_output, vaq_label)
354 |
355 | if inference:
356 | if self.args.openvqa_eval:
357 | logits = self.inference_criterion(vaq_output, vaq_label)
358 | logits = logits.reshape(bsz, n_options, -1)
359 | else:
360 | logits = self.inference_criterion(vqa_output, vqa_label)
361 | logits = logits.reshape(bsz, n_options, -1)
362 | return logits
363 | else:
364 | return vqa_loss, vaq_loss, qav_loss
365 |
--------------------------------------------------------------------------------
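At inference, `forward` returns per-token cross-entropy for every answer option (`reduction='none'`), reshaped to `(bsz, n_options, seqlen - 1)`. A small sketch of how such logits can be turned into a multiple-choice prediction; the answer-span mask (derived from `prefix_index` in the training/eval loop, which lives outside this file) is assumed here:

```python
import torch

def choose_option(token_losses: torch.Tensor, answer_mask: torch.Tensor) -> torch.Tensor:
    """token_losses: (bsz, n_options, seqlen - 1) per-token CE from forward().
    answer_mask: same shape, 1.0 on answer tokens, 0.0 elsewhere (assumed)."""
    per_option = (token_losses * answer_mask).sum(-1) / answer_mask.sum(-1).clamp(min=1)
    return per_option.argmin(dim=-1)  # (bsz,) index of the lowest-loss option
```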
/llama/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from sentencepiece import SentencePieceProcessor
5 | from logging import getLogger
6 | from typing import List
7 | import os
8 | import torch
9 |
10 | logger = getLogger()
11 |
12 |
13 | class Tokenizer:
14 | def __init__(self, model_path: str):
15 | # reload tokenizer
16 | assert os.path.isfile(model_path), model_path
17 | self.sp_model = SentencePieceProcessor(model_file=model_path)
18 | logger.info(f"Reloaded SentencePiece model from {model_path}")
19 |
20 | # BOS / EOS token IDs
21 | self.n_words: int = self.sp_model.vocab_size()
22 | self.bos_id: int = self.sp_model.bos_id()
23 | self.eos_id: int = self.sp_model.eos_id()
24 | self.pad_id: int = self.sp_model.pad_id()
25 |
26 | self.v_token_id = 15167
27 | self.q_token_id = 16492
28 | self.a_token_id = 22550
29 | self.c_token_id = 9868
30 | self.nl_id = 13
31 | logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
32 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
33 |
34 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
35 | assert type(s) is str
36 | t = self.sp_model.encode(s)
37 | if bos:
38 | t = [self.bos_id] + t
39 | if eos:
40 | t = t + [self.eos_id]
41 | return t
42 |
43 | def encode_vqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
44 | i_text = "Instruction: Choose the correct answer based on the video and question.\n"
45 | q_text = text['q_text']
46 | o_text = text['o_text']
47 | a_text = text['a_text']
48 |
49 | s1 = i_text + 'Video:'
50 | t1 = [self.bos_id] + self.sp_model.encode(s1)
51 | video_start = len(t1)
52 |
53 | s2 = q_text + o_text + a_text
54 |
55 | if split == 'train':
56 | s2 = s2 + answer_mapping[answer]
57 | t2 = self.sp_model.encode(s2) + [self.eos_id]
58 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
59 | prefix_index = t[0].index(self.a_token_id) + 4
60 | else:
61 | t = []
62 | for k, v in answer_mapping.items():
63 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id]
64 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
65 | prefix_index = t[answer].index(self.a_token_id) + 4
66 | return t, prefix_index, video_start
67 |
68 | def encode_vaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
69 | i_text = "Instruction: Predict the question based on the video and answer.\n"
70 | q_text = text['q_text'].strip()
71 | o_text = text['o_text']
72 | a_text = text['a_text']
73 |
74 | s1 = i_text + 'Video:'
75 | t1 = [self.bos_id] + self.sp_model.encode(s1)
76 | video_start = len(t1)
77 |
78 | s2 = o_text + a_text
79 |
80 | if split == 'train':
81 | s2 = s2 + answer_mapping[answer] + "\n" + q_text
82 | t2 = self.sp_model.encode(s2) + [self.eos_id]
83 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
84 | prefix_index = t[0].index(self.q_token_id) + 1
85 | else:
86 | t = []
87 | for k, v in answer_mapping.items():
88 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id]
89 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
90 | prefix_index = t[answer].index(self.q_token_id) + 1
91 | return t, prefix_index, video_start
92 |
93 |
94 | def encode_qav(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
95 | i_text = "Instruction: Predict the video based on the question and answer.\n"
96 | q_text = text['q_text']
97 | o_text = text['o_text']
98 | a_text = text['a_text']
99 |
100 | s1 = i_text + q_text + o_text + a_text
101 |
102 | if split == 'train':
103 | s1 = s1 + answer_mapping[answer] + "\n" + "Video:"
104 | t1 = [self.bos_id] + self.sp_model.encode(s1)
105 | t = [t1 + [-2 for _ in range(max_feats)] + [self.eos_id]]
106 | prefix_index = t[0].index(self.v_token_id) + 1
107 | else:
108 | t = []
109 | for k, v in answer_mapping.items():
110 | t1 = [self.bos_id] + self.sp_model.encode(s1 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id]
111 | t.append(t1)
112 | prefix_index = t[answer].index(self.v_token_id) + 1
113 | return t, prefix_index
114 |
115 | def decode(self, t: List[int]) -> str:
116 | return self.sp_model.decode(t)
117 |
118 | def encode_dvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
119 | i_text = "Instruction: Predict the answer based on the dialogue, video and question.\n"
120 | q_text = text['q_text']
121 | o_text = text['o_text']
122 | a_text = text['a_text']
123 | d_text = text['d_text']
124 |
125 | s1 = i_text + 'Video:'
126 | t1 = [self.bos_id] + self.sp_model.encode(s1)
127 | video_start = len(t1)
128 |
129 | prefix_i = video_start + max_feats + 1
130 | d1 = self.sp_model.encode(d_text)
131 | prefix_main = prefix_i + len(d1)
132 |
133 | s2 = q_text + o_text + a_text
134 |
135 | if split == 'train':
136 | s2 = s2 + answer_mapping[answer]
137 | t2 = self.sp_model.encode(s2) + [self.eos_id]
138 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2]
139 | else:
140 | t = []
141 | for k, v in answer_mapping.items():
142 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id]
143 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2)
144 |
145 | prefix_index = len(t[0]) - 4
146 |
147 | return t, prefix_index, video_start, prefix_i, prefix_main
148 |
149 | def encode_dvaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
150 | i_text = "Instruction: Predict the question based on the dialogue, video and answer.\n"
151 | q_text = text['q_text'].strip()
152 | o_text = text['o_text']
153 | a_text = text['a_text']
154 | d_text = text['d_text']
155 |
156 | s1 = i_text + 'Video:'
157 | t1 = [self.bos_id] + self.sp_model.encode(s1)
158 | video_start = len(t1)
159 |
160 | prefix_i = video_start + max_feats + 1
161 | d1 = self.sp_model.encode(d_text)
162 | prefix_main = prefix_i + len(d1)
163 |
164 | s2 = o_text + a_text
165 |
166 | if split == 'train':
167 | s2 = s2 + answer_mapping[answer] + "\n" + q_text
168 | t2 = self.sp_model.encode(s2) + [self.eos_id]
169 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2]
170 | else:
171 | t = []
172 | for k, v in answer_mapping.items():
173 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id]
174 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2)
175 |
176 | prefix_index = t[0].index(self.q_token_id) + 1
177 |
178 | return t, prefix_index, video_start, prefix_i, prefix_main
179 |
180 | def encode_dqav(self, text=None, max_feats=10, max_seq_len=128, split='train', answer_mapping=None, answer=None) -> List[int]:
181 | i_text = "Instruction: Predict the video based on the dialogue, question and answer.\n"
182 | d_text = text['d_text']
183 | q_text = text['q_text']
184 | o_text = text['o_text']
185 | a_text = text['a_text']
186 | s1, s2, s3 = i_text, d_text, q_text + o_text + a_text
187 |
188 | t1 = [self.bos_id] + self.sp_model.encode(s1)
189 | t2 = self.sp_model.encode(s2)
190 | prefix_i, prefix_q = len(t1), len(t1) + len(t2)
191 |
192 | if split == 'train':
193 | t3 = self.sp_model.encode(s3 + answer_mapping[answer] + "\n" + "Video:")
194 | t = [t1 + t2 + t3 + [-2 for _ in range(max_feats)] + [self.eos_id]]
195 | else:
196 | t = []
197 | for k, v in answer_mapping.items():
198 | t3 = self.sp_model.encode(s3 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id]
199 | t.append(t1 + t2 + t3)
200 |
201 | prefix_index = len(t[0]) - max_feats - 1
202 |
203 | return t, prefix_index, prefix_i, prefix_q
204 |
205 |
206 |
207 | def encode_videocap(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
208 | i_text = "Instruction: Generate a summary for the video.\n"
209 | # q_text = text['q_text'].strip()
210 | # o_text = text['o_text']
211 | # a_text = text['a_text']
212 |
213 | s1 = i_text + 'Video:'
214 | t1 = [self.bos_id] + self.sp_model.encode(s1)
215 | video_start = len(t1)
216 |
217 | s2 = text['c_text']
218 |
219 | if split == 'train':
220 | # s2 = s2 + answer_mapping[answer] + "\n" + q_text
221 | s2 = "\n"+s2
222 | t2 = self.sp_model.encode(s2) + [self.eos_id]
223 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
224 | prefix_index = t[0].index(self.c_token_id) + 1
225 | else:
226 | t = []
227 | for k, v in answer_mapping.items():
228 |                 t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id]  # note: q_text is undefined in this function; this eval branch appears unused
229 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
230 | prefix_index = t[answer].index(self.c_token_id) + 1
231 | return t, prefix_index, video_start
232 |
233 |
234 |
235 | def encode_openvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
236 | i_text = "Instruction: Predict the answer based on the video and question.\n"
237 | q_text = text['q_text']
238 | oa_text = text['oa_text']
239 |
240 | s1 = i_text + 'Video:'
241 | t1 = [self.bos_id] + self.sp_model.encode(s1)
242 | video_start = len(t1)
243 |
244 | if split == 'train':
245 | s2 = q_text + oa_text
246 | t2 = self.sp_model.encode(s2) + [self.eos_id]
247 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
248 | prefix_index = t[0].index(self.a_token_id)+1
249 | else:
250 | t = []
251 | for open_option in text['open_options']:
252 | s2 = q_text + open_option
253 | t2 = self.sp_model.encode(s2) + [self.eos_id]
254 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
255 | prefix_index = t[0].index(self.a_token_id)+1
256 | # else:
257 | # t = []
258 | # prefix_index = []
259 | # for open_option in text['open_options']:
260 | # s2 = q_text + open_option+'. For this video, this answer is correct.'
261 | # t2 = self.sp_model.encode(s2) + [self.eos_id]
262 | # t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
263 | # prefix_index.append(len(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)-50)
264 |
265 | return t, prefix_index, video_start
--------------------------------------------------------------------------------
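All of the `encode_*` helpers above share one layout: a tokenized instruction prefix ending in `Video:`, then `max_feats` placeholder ids of `-2` (later overwritten in the Transformer with projected CLIP frame features starting at `video_start`), a newline id, and the tokenized question/options/answer, with `prefix_index` marking where the supervised span begins. A toy sketch of that layout; `toy_encode` and the prompt strings are illustrative stand-ins, not the SentencePiece model or the dataset's exact `a_text`:

```python
BOS, EOS, NL, VIDEO_SLOT = 1, 2, 13, -2  # illustrative ids only

def toy_encode(s: str):  # hypothetical stand-in for sp_model.encode
    return [100 + (ord(c) % 900) for c in s]

def build_vqa_prompt(q_text: str, o_text: str, answer_letter: str, max_feats: int = 10):
    t1 = [BOS] + toy_encode("Instruction: Choose the correct answer based on the video and question.\nVideo:")
    video_start = len(t1)                      # first of the max_feats video slots
    t2 = toy_encode(q_text + o_text + "Answer: " + answer_letter) + [EOS]
    return t1 + [VIDEO_SLOT] * max_feats + [NL] + t2, video_start

tokens, video_start = build_vqa_prompt("Question: What happens?\n",
                                       "Options: (A) ... (B) ...\n", "(A)")
assert tokens[video_start:video_start + 10] == [VIDEO_SLOT] * 10
```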
/llama/tokenizer_llama3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
3 |
4 | import os
5 | from logging import getLogger
6 | from pathlib import Path
7 | from typing import (
8 | AbstractSet,
9 | cast,
10 | Collection,
11 | Dict,
12 | Iterator,
13 | List,
14 | Literal,
15 | Sequence,
16 | TypedDict,
17 | Union,
18 | )
19 |
20 | import tiktoken
21 | from tiktoken.load import load_tiktoken_bpe
22 |
23 |
24 | logger = getLogger(__name__)
25 |
26 |
27 | Role = Literal["system", "user", "assistant"]
28 |
29 |
30 | class Message(TypedDict):
31 | role: Role
32 | content: str
33 |
34 |
35 | Dialog = Sequence[Message]
36 |
37 |
38 | class Tokenizer_llama3:
39 | """
40 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
41 | """
42 |
43 | special_tokens: Dict[str, int]
44 |
45 | num_reserved_special_tokens = 256
46 |
47 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
48 |
49 | def __init__(self, model_path: str):
50 | """
51 | Initializes the Tokenizer with a Tiktoken model.
52 |
53 | Args:
54 | model_path (str): The path to the Tiktoken model file.
55 | """
56 | assert os.path.isfile(model_path), model_path
57 |
58 | mergeable_ranks = load_tiktoken_bpe(model_path)
59 | num_base_tokens = len(mergeable_ranks)
60 | special_tokens = [
61 | "<|begin_of_text|>",
62 | "<|end_of_text|>",
63 | "<|reserved_special_token_0|>",
64 | "<|reserved_special_token_1|>",
65 | "<|reserved_special_token_2|>",
66 | "<|reserved_special_token_3|>",
67 | "<|start_header_id|>",
68 | "<|end_header_id|>",
69 | "<|reserved_special_token_4|>",
70 | "<|eot_id|>", # end of turn
71 | ] + [
72 | f"<|reserved_special_token_{i}|>"
73 | for i in range(5, self.num_reserved_special_tokens - 5)
74 | ]
75 | self.special_tokens = {
76 | token: num_base_tokens + i for i, token in enumerate(special_tokens)
77 | }
78 | self.sp_model = tiktoken.Encoding(
79 | name=Path(model_path).name,
80 | pat_str=self.pat_str,
81 | mergeable_ranks=mergeable_ranks,
82 | special_tokens=self.special_tokens,
83 | )
84 | logger.info(f"Reloaded tiktoken model from {model_path}")
85 |
86 | self.n_words: int = self.sp_model.n_vocab
87 | # BOS / EOS token IDs
88 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
89 | self.eos_id: int = self.special_tokens["<|end_of_text|>"]
90 | self.pad_id: int = -1
91 | self.stop_tokens = {
92 | self.special_tokens["<|end_of_text|>"],
93 | self.special_tokens["<|eot_id|>"],
94 | }
95 | logger.info(
96 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
97 | )
98 |
99 | self.v_token_id = 10955
100 | self.q_token_id = 14924
101 | self.a_token_id = 16533
102 | self.c_token_id = 5116
103 | self.nl_id = 627
104 |
105 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
106 | assert type(s) is str
107 | t = self.sp_model.encode(s)
108 | if bos:
109 | t = [self.bos_id] + t
110 | if eos:
111 | t = t + [self.eos_id]
112 | return t
113 |
114 | def encode_vqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
115 | i_text = "Instruction: Choose the correct answer based on the video and question.\n"
116 | q_text = text['q_text']
117 | o_text = text['o_text']
118 | a_text = text['a_text']
119 |
120 | s1 = i_text + 'Video:'
121 | t1 = [self.bos_id] + self.sp_model.encode(s1)
122 | video_start = len(t1)
123 |
124 | s2 = q_text + o_text + a_text
125 |
126 | if split == 'train':
127 | s2 = s2 + answer_mapping[answer]
128 | t2 = self.sp_model.encode(s2) + [self.eos_id]
129 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
130 | prefix_index = t[0].index(self.a_token_id) + 5
131 | else:
132 | t = []
133 | for k, v in answer_mapping.items():
134 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id]
135 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
136 | prefix_index = t[answer].index(self.a_token_id) + 5
137 | return t, prefix_index, video_start
138 |
139 | def encode_vaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
140 | i_text = "Instruction: Predict the question based on the video and answer.\n"
141 | q_text = text['q_text'].strip()
142 | o_text = text['o_text']
143 | a_text = text['a_text']
144 |
145 | s1 = i_text + 'Video:'
146 | t1 = [self.bos_id] + self.sp_model.encode(s1)
147 | video_start = len(t1)
148 |
149 | s2 = o_text + a_text
150 |
151 | if split == 'train':
152 | s2 = s2 + answer_mapping[answer] + "\n" + q_text
153 | t2 = self.sp_model.encode(s2) + [self.eos_id]
154 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
155 | prefix_index = t[0].index(self.q_token_id) + 1
156 | else:
157 | t = []
158 | for k, v in answer_mapping.items():
159 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id]
160 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
161 | prefix_index = t[answer].index(self.q_token_id) + 1
162 | return t, prefix_index, video_start
163 |
164 |
165 | def encode_qav(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
166 | i_text = "Instruction: Predict the video based on the question and answer.\n"
167 | q_text = text['q_text']
168 | o_text = text['o_text']
169 | a_text = text['a_text']
170 |
171 | s1 = i_text + q_text + o_text + a_text
172 |
173 | if split == 'train':
174 | s1 = s1 + answer_mapping[answer] + "\n" + "Video:"
175 | t1 = [self.bos_id] + self.sp_model.encode(s1)
176 | t = [t1 + [-2 for _ in range(max_feats)] + [self.eos_id]]
177 | prefix_index = t[0].index(self.v_token_id) + 1
178 | else:
179 | t = []
180 | for k, v in answer_mapping.items():
181 | t1 = [self.bos_id] + self.sp_model.encode(s1 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id]
182 | t.append(t1)
183 | prefix_index = t[answer].index(self.v_token_id) + 1
184 | return t, prefix_index
185 |
186 | def decode(self, t: List[int]) -> str:
187 | return self.sp_model.decode(t)
188 |
189 | def encode_dvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
190 | i_text = "Instruction: Predict the answer based on the dialogue, video and question.\n"
191 | q_text = text['q_text']
192 | o_text = text['o_text']
193 | a_text = text['a_text']
194 | d_text = text['d_text']
195 |
196 | s1 = i_text + 'Video:'
197 | t1 = [self.bos_id] + self.sp_model.encode(s1)
198 | video_start = len(t1)
199 |
200 | prefix_i = video_start + max_feats + 1
201 | d1 = self.sp_model.encode(d_text)
202 | prefix_main = prefix_i + len(d1)
203 |
204 | s2 = q_text + o_text + a_text
205 |
206 | if split == 'train':
207 | s2 = s2 + answer_mapping[answer]
208 | t2 = self.sp_model.encode(s2) + [self.eos_id]
209 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2]
210 | else:
211 | t = []
212 | for k, v in answer_mapping.items():
213 | t2 = self.sp_model.encode(s2 + v) + [self.eos_id]
214 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2)
215 |
216 | prefix_index = len(t[0]) - 4
217 |
218 | return t, prefix_index, video_start, prefix_i, prefix_main
219 |
220 | def encode_dvaq(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
221 | i_text = "Instruction: Predict the question based on the dialogue, video and answer.\n"
222 | q_text = text['q_text'].strip()
223 | o_text = text['o_text']
224 | a_text = text['a_text']
225 | d_text = text['d_text']
226 |
227 | s1 = i_text + 'Video:'
228 | t1 = [self.bos_id] + self.sp_model.encode(s1)
229 | video_start = len(t1)
230 |
231 | prefix_i = video_start + max_feats + 1
232 | d1 = self.sp_model.encode(d_text)
233 | prefix_main = prefix_i + len(d1)
234 |
235 | s2 = o_text + a_text
236 |
237 | if split == 'train':
238 | s2 = s2 + answer_mapping[answer] + "\n" + q_text
239 | t2 = self.sp_model.encode(s2) + [self.eos_id]
240 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2]
241 | else:
242 | t = []
243 | for k, v in answer_mapping.items():
244 | t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id]
245 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + d1 + t2)
246 |
247 | prefix_index = t[0].index(self.q_token_id) + 1
248 |
249 | return t, prefix_index, video_start, prefix_i, prefix_main
250 |
251 | def encode_dqav(self, text=None, max_feats=10, max_seq_len=128, split='train', answer_mapping=None, answer=None) -> List[int]:
252 | i_text = "Instruction: Predict the video based on the dialogue, question and answer.\n"
253 | d_text = text['d_text']
254 | q_text = text['q_text']
255 | o_text = text['o_text']
256 | a_text = text['a_text']
257 | s1, s2, s3 = i_text, d_text, q_text + o_text + a_text
258 |
259 | t1 = [self.bos_id] + self.sp_model.encode(s1)
260 | t2 = self.sp_model.encode(s2)
261 | prefix_i, prefix_q = len(t1), len(t1) + len(t2)
262 |
263 | if split == 'train':
264 | t3 = self.sp_model.encode(s3 + answer_mapping[answer] + "\n" + "Video:")
265 | t = [t1 + t2 + t3 + [-2 for _ in range(max_feats)] + [self.eos_id]]
266 | else:
267 | t = []
268 | for k, v in answer_mapping.items():
269 | t3 = self.sp_model.encode(s3 + v + "\n" + "Video:") + [-2 for _ in range(max_feats)] + [self.eos_id]
270 | t.append(t1 + t2 + t3)
271 |
272 | prefix_index = len(t[0]) - max_feats - 1
273 |
274 | return t, prefix_index, prefix_i, prefix_q
275 |
276 |
277 |
278 | def encode_videocap(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
279 | i_text = "Instruction: Generate a dense description for the video.\n"
280 | # q_text = text['q_text'].strip()
281 | # o_text = text['o_text']
282 | # a_text = text['a_text']
283 |
284 | s1 = i_text + 'Video:'
285 | t1 = [self.bos_id] + self.sp_model.encode(s1)
286 | video_start = len(t1)
287 |
288 | s2 = text['c_text']
289 |
290 | if split == 'train':
291 | # s2 = s2 + answer_mapping[answer] + "\n" + q_text
292 | s2 = "\n"+s2
293 | t2 = self.sp_model.encode(s2) + [self.eos_id]
294 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
295 | prefix_index = t[0].index(self.c_token_id) + 1
296 | else:
297 | t = []
298 | for k, v in answer_mapping.items():
299 |                 t2 = self.sp_model.encode(s2 + v + "\n" + q_text) + [self.eos_id]  # note: q_text is undefined in this function; this eval branch appears unused
300 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
301 | prefix_index = t[answer].index(self.c_token_id) + 1
302 | return t, prefix_index, video_start
303 |
304 |
305 |
306 | def encode_openvqa(self, text=None, max_feats=10, split='train', answer_mapping=None, answer=None) -> List[int]:
307 | i_text = "Instruction: Predict the answer based on the video and question.\n"
308 | q_text = text['q_text']
309 | oa_text = text['oa_text']
310 |
311 | s1 = i_text + 'Video:'
312 | t1 = [self.bos_id] + self.sp_model.encode(s1)
313 | video_start = len(t1)
314 |
315 | if split == 'train':
316 | s2 = q_text + oa_text
317 | t2 = self.sp_model.encode(s2) + [self.eos_id]
318 | t = [t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2]
319 | prefix_index = t[0].index(self.a_token_id)+1
320 | else:
321 | t = []
322 | for open_option in text['open_options']:
323 | s2 = q_text + open_option
324 | t2 = self.sp_model.encode(s2) + [self.eos_id]
325 | t.append(t1 + [-2 for _ in range(max_feats)] + [self.nl_id] + t2)
326 | prefix_index = t[0].index(self.a_token_id)+1
327 | return t, prefix_index, video_start
--------------------------------------------------------------------------------
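A brief usage sketch for the tiktoken-based tokenizer above; the import path follows `llama/__init__.py` (as used in `llama_vqa.py`), and the model path is an assumption matching the repo's default layout:

```python
from llama import Tokenizer_llama3

tok = Tokenizer_llama3(model_path="./pretrained/llama3/tokenizer.model")  # assumed path
ids = tok.encode("Instruction: Choose the correct answer based on the video and question.",
                 bos=True, eos=False)
print(tok.decode(ids[1:]))  # drop BOS before decoding back to text
```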
/llama_vqa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | from llama import ModelArgs, Tokenizer, Transformer
4 | from llama import Transformer_llama3,Tokenizer_llama3,ModelArgs_llama3
5 | from pathlib import Path
6 |
7 | def LLaMA_VQA(args, **kwargs):
8 | with open(f'{args.llama_model_path}{args.model}/params.json', "r") as f:
9 | params = json.loads(f.read())
10 |
11 | if args.llama3:
12 | tokenizer = Tokenizer_llama3(model_path=f'{args.llama_model_path}/tokenizer.model')
13 | else:
14 | tokenizer = Tokenizer(model_path=f'{args.llama_model_path}/tokenizer.model')
15 | print(f"Using model: {args.model}")
16 |
17 |
18 | checkpoints = (Path(args.llama_model_path) / args.model).glob("*.pth")
19 | checkpoints = sorted(checkpoints)
20 |
21 | loaded = []
22 | for x in checkpoints:
23 | print("loading from", x)
24 | loaded.append(torch.load(x, map_location="cpu"))
25 |
26 | if len(loaded) == 1:
27 | full_state_dict = loaded[0]
28 | else:
29 | full_state_dict = {}
30 | split_dims = {}
31 |
32 | def add_weight_with_split_dim(name, dim):
33 | if dim < 0: # bcast without split
34 | full_state_dict[name] = loaded[0][name].clone()
35 | else:
36 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim)
37 | for x in loaded:
38 | del x[name]
39 | split_dims[name] = dim
40 |
41 | add_weight_with_split_dim("tok_embeddings.weight", 1)
42 | add_weight_with_split_dim("norm.weight", -1)
43 | add_weight_with_split_dim("output.weight", 0)
44 | for i in range(params["n_layers"]):
45 | print("gathering layer %d of %d" % (i, params["n_layers"]))
46 | layer_prefix = f"layers.{i}."
47 | bcast_names = ["attention_norm.weight", "ffn_norm.weight"]
48 | column_parallel_names = ["attention.wq.weight", "attention.wk.weight", "attention.wv.weight", "feed_forward.w1.weight", "feed_forward.w3.weight"]
49 | row_parallel_names = ["attention.wo.weight", "feed_forward.w2.weight"]
50 | for key in bcast_names:
51 | add_weight_with_split_dim(layer_prefix + key, -1)
52 | for key in column_parallel_names:
53 | add_weight_with_split_dim(layer_prefix + key, 0)
54 | for key in row_parallel_names:
55 | add_weight_with_split_dim(layer_prefix + key, 1)
56 |
57 | if args.llama3:
58 | model_args: ModelArgs = ModelArgs_llama3(max_seq_len=args.max_seq_len, max_batch_size=32, adapter_len=args.adapter_len, adapter_layer=args.adapter_layer, **params)
59 | else:
60 | model_args: ModelArgs = ModelArgs(max_seq_len=args.max_seq_len, max_batch_size=32, adapter_len=args.adapter_len, adapter_layer=args.adapter_layer, **params)
61 |
62 |
63 | model_args.vocab_size = tokenizer.n_words
64 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
65 | if args.llama3:
66 | model_llama_vqa = Transformer_llama3(model_args, args)
67 | else:
68 | model_llama_vqa = Transformer(model_args, args)
69 | torch.set_default_tensor_type(torch.FloatTensor)
70 | missing_keys, unexpected_keys = model_llama_vqa.load_state_dict(full_state_dict, strict=False)
71 |
72 | for name, param in model_llama_vqa.named_parameters():
73 |
74 | if ('gate' in name) or ('adapter' in name) or ('temporal_emb' in name) or ('visual_proj' in name) or ('query_tokens' in name):
75 | param.requires_grad = True
76 | param.data = param.data.float()
77 |
78 | else:
79 | # print(name)
80 | # print(param.data.dtype)
81 | if args.llama2:
82 | param.data = param.data.bfloat16()
83 | param.requires_grad = False
84 |
85 |
86 | return model_llama_vqa
--------------------------------------------------------------------------------
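When the base model ships as several model-parallel shards, `LLaMA_VQA` rebuilds a single state dict by concatenating column-parallel weights (`wq`/`wk`/`wv`, `w1`, `w3`, `output`) along dim 0, row-parallel weights (`wo`, `w2`) and `tok_embeddings` along dim 1, and keeping one copy of replicated norm weights. A minimal standalone sketch of that merge rule (the tensors are placeholders, not real checkpoints):

```python
import torch

def merge_shards(shards, name, dim):
    """shards: list of per-rank state_dicts; dim < 0 means the weight is replicated."""
    if dim < 0:
        return shards[0][name].clone()
    return torch.cat([sd[name] for sd in shards], dim=dim)

# placeholder shards standing in for two .pth files of a column-parallel weight
shards = [{"output.weight": torch.randn(8, 4)}, {"output.weight": torch.randn(8, 4)}]
merged = merge_shards(shards, "output.weight", dim=0)  # -> shape (16, 4)
```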
/pics/topa_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/pics/topa_framework.jpg
--------------------------------------------------------------------------------
/pretrained/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/pretrained/.gitkeep
--------------------------------------------------------------------------------
/scripts/baseline/llama2_13b.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \
2 | --max_seq_len 128 --batch_size 6 --epochs 10 --warmup_epochs 4 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \
3 | --blr 1e-2 --weight_decay 0.1 --accum_iter 8 --output_dir vqa_checkpoint/baseline/llama2_13b_nextqa_1e2 --adapter_len 50 \
4 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \
5 |
6 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \
7 | --max_seq_len 128 --batch_size 6 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset star \
8 | --blr 2e-2 --weight_decay 0.1 --accum_iter 8 --output_dir vqa_checkpoint/baseline/llama2_13b_star_2e2 --adapter_len 50 \
9 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40\
10 |
11 |
12 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \
13 | --max_seq_len 128 --batch_size 6 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \
14 |  --blr 2e-2 --weight_decay 0.1 --accum_iter 8 --output_dir vqa_checkpoint/baseline/llama2_13b_tvqa_2e2 --adapter_len 50 \
15 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40\
--------------------------------------------------------------------------------
/scripts/baseline/llama2_7b.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
2 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \
3 | --blr 1e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama2_7b_nextqa_2e2_acc4 --adapter_len 50 \
4 | --llama2 --llama_model_path ./pretrained/llama2/ \
5 |
6 |
7 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
8 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset star \
9 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama2_7b_star_1e2_acc4 --adapter_len 50 \
10 | --llama2 --llama_model_path ./pretrained/llama2/ \
11 |
12 |
13 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
14 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \
15 |  --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama2_7b_tvqa_1e2_acc4 --adapter_len 50 \
16 | --llama2 --llama_model_path ./pretrained/llama2/ \
17 |
18 |
19 |
--------------------------------------------------------------------------------
/scripts/baseline/llama3_8b.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
2 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \
3 | --blr 1e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama3_tvqa_1e2_acc4 --adapter_len 50 \
4 | --llama3 --llama_model_path ./pretrained/llama3/ \
5 |
6 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
7 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \
8 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama3_nextqa_2e2_acc4 --adapter_len 50 \
9 | --llama3 --llama_model_path ./pretrained/llama3/ \
10 |
11 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
12 | --max_seq_len 128 --batch_size 20 --epochs 10 --warmup_epochs 2 --bias 3.5 --tau 100. --max_feats 10 --dataset star \
13 | --blr 2e-2 --weight_decay 0.1 --accum_iter 4 --output_dir vqa_checkpoint/baseline/llama3_star_2e2_acc4 --adapter_len 50 \
14 | --llama3 --llama_model_path ./pretrained/llama3/ \
15 |
16 |
--------------------------------------------------------------------------------
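For reference, the effective batch size per optimizer step in the baseline runs above is the per-GPU batch size times the number of GPUs times `--accum_iter`, assuming the usual gradient-accumulation convention in the training loop:

```python
def effective_batch(batch_size: int, nproc_per_node: int, accum_iter: int) -> int:
    return batch_size * nproc_per_node * accum_iter

print(effective_batch(6, 4, 8))   # LLaMA2-13B runs -> 192
print(effective_batch(20, 4, 4))  # LLaMA2-7B / LLaMA3-8B runs -> 320
```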
/scripts/eval/zeroshot_eval_egos.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 1 train.py --model 7B \
2 | --max_seq_len 128 --batch_size 1 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset egos \
3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
4 | --output_dir results/llama2_7B/egos \
5 | --llama2 --llama_model_path ./pretrained/llama2/ \
6 | --memory \
7 | --test
8 |
9 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 1 train.py --model 13B --adapter_layer 40 \
10 | # --max_seq_len 128 --batch_size 1 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset egos \
11 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \
12 | # --output_dir results/llama2_13B/egos \
13 | # --llama2 --llama_model_path ./pretrained/llama2/ \
14 | # --memory \
15 | # --test
16 |
17 |
18 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 1 train.py --model 8B \
19 | # --max_seq_len 150 --batch_size 1 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset egos \
20 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
21 | # --output_dir results/llama3_8B/egos \
22 | # --llama3 --llama_model_path ./pretrained/llama3/ \
23 | # --memory \
24 | # --test
25 |
26 |
--------------------------------------------------------------------------------
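The `--memory` flag used in these zero-shot evaluations activates the retrieval step in `forward()` (see `llama/model.py` and `llama/model_llama3.py`): each CLIP frame feature is replaced by a temperature-sharpened soft retrieval over the text-feature bank loaded from `./data/textvid/memory.pkl`, then re-normalized, which keeps inference close to the text feature space used during text-only pre-alignment. A minimal sketch, assuming L2-normalized features:

```python
import torch

def retrieve_from_memory(frame_feats: torch.Tensor, memory: torch.Tensor,
                         scale: float = 100.0) -> torch.Tensor:
    """frame_feats: (n_frames, d) CLIP features; memory: (n_entries, d) text features."""
    sim = (frame_feats @ memory.T * scale).softmax(dim=-1)  # (n_frames, n_entries)
    out = sim @ memory                                      # soft nearest-text features
    return out / out.norm(dim=-1, keepdim=True)
```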
/scripts/eval/zeroshot_eval_nextqa.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
2 | --max_seq_len 128 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset nextqa \
3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
4 | --output_dir results/llama2_7B/nextqa \
5 | --llama2 --llama_model_path ./pretrained/llama2/ \
6 | --memory \
7 |
8 |
9 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B --adapter_layer 40 \
10 | # --max_seq_len 128 --batch_size 10 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset nextqa \
11 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \
12 | # --output_dir results/llama2_13B/nextqa \
13 | # --llama2 --llama_model_path ./pretrained/llama2/ \
14 | # --memory \
15 |
16 |
17 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
18 | # --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset nextqa \
19 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
20 | # --output_dir results/llama3_8B/nextqa \
21 | # --llama3 --llama_model_path ./pretrained/llama3/ \
22 | # --memory \
23 |
24 |
25 |
--------------------------------------------------------------------------------
/scripts/eval/zeroshot_eval_star.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
2 | --max_seq_len 128 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset star \
3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
4 | --output_dir results/llama2_7B/star \
5 | --llama2 --llama_model_path ./pretrained/llama2/ \
6 | --memory \
7 |
8 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B --adapter_layer 40 \
9 | # --max_seq_len 128 --batch_size 10 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset star \
10 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \
11 | # --output_dir results/llama2_13B/star \
12 | # --llama2 --llama_model_path ./pretrained/llama2/ \
13 | # --memory \
14 |
15 |
16 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
17 | # --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset star \
18 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
19 | # --output_dir results/llama3_8B/star \
20 | # --llama3 --llama_model_path ./pretrained/llama3/ \
21 | # --memory \
22 |
23 |
24 |
--------------------------------------------------------------------------------
/scripts/eval/zeroshot_eval_tvqa.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
2 | --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset tvqa \
3 | --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
4 | --output_dir results/llama2_7B/tvqa \
5 | --llama2 --llama_model_path ./pretrained/llama2/ \
6 | --memory \
7 |
8 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B --adapter_layer 40 \
9 | # --max_seq_len 150 --batch_size 10 --epochs 1 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset tvqa \
10 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs4_vnips/checkpoint_18.pth --adapter_len 50 --accum_iter 2 --eval \
11 | # --output_dir results/llama2_13B/tvqa \
12 | # --llama2 --llama_model_path ./pretrained/llama2/ \
13 | # --memory \
14 |
15 |
16 | # torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
17 | # --max_seq_len 150 --batch_size 10 --epochs 5 --warmup_epochs 2 --bias 3 --tau 100. --max_feats 10 --dataset tvqa \
18 | # --blr 9e-2 --weight_decay 0.16 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --adapter_len 50 --accum_iter 2 --eval \
19 | # --output_dir results/llama3_8B/tvqa \
20 | # --llama3 --llama_model_path ./pretrained/llama3/ \
21 | # --memory \
22 |
23 |
24 |
--------------------------------------------------------------------------------
/scripts/finetune/LLama2_13b_finetune.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \
2 | --max_seq_len 128 --batch_size 6 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \
3 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs6/checkpoint_18.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnip_llama2_13b_finetune_nextqa_5e3_acc16 --accum_iter 16 --adapter_len 50 --finetune \
4 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \
5 |
6 |
7 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \
8 | --max_seq_len 128 --batch_size 6 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset star \
9 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs6/checkpoint_18.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnip_llama2_13b_finetune_star_5e3_acc16 --accum_iter 16 --adapter_len 50 --finetune \
10 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \
11 |
12 |
13 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 13B \
14 | --max_seq_len 128 --batch_size 6 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \
15 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_13b_acc8_br8e3_bs6/checkpoint_18.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnip_llama2_13b_finetune_tvqa_5e3_acc16 --accum_iter 16 --adapter_len 50 --finetune \
16 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \
17 |
--------------------------------------------------------------------------------
/scripts/finetune/LLama2_7b_finetune.sh:
--------------------------------------------------------------------------------
1 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
2 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \
3 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnips_7b_finetune_nextqa_1e2_acc4 --accum_iter 4 --adapter_len 50 --finetune \
4 | --llama2 --llama_model_path ./pretrained/llama2/
5 |
6 |
7 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
8 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset star \
9 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnips_7b_finetune_realstar_1e2_acc4 --accum_iter 4 --adapter_len 50 --finetune \
10 | --llama2 --llama_model_path ./pretrained/llama2/
11 |
12 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 7B \
13 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \
14 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama2_7b_acc4_br5e3_correct_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/vnips_7b_finetune_tvqa_5e3_acc4 --accum_iter 4 --adapter_len 50 --finetune \
15 | --llama2 --llama_model_path ./pretrained/llama2/
--------------------------------------------------------------------------------
/scripts/finetune/LLama3_finetune.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
3 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset nextqa \
4 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/llama3_finetune_nextqa_1e2_acc4 --accum_iter 4 --adapter_len 50 --finetune \
5 | --llama3 --llama_model_path ./pretrained/llama3/ \
6 |
7 |
8 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
9 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset star \
10 | --blr 1e-2 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/llama3_finetune_star_1e2_acc4_ep10 --accum_iter 4 --adapter_len 50 --finetune \
11 | --llama3 --llama_model_path ./pretrained/llama3/ \
12 |
13 |
14 |
15 | torchrun --rdzv_endpoint 127.0.0.1:1234 --nproc_per_node 4 train.py --model 8B \
16 | --max_seq_len 128 --batch_size 20 --epochs 5 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset tvqa \
17 | --blr 5e-3 --weight_decay 0.1 --resume vqa_checkpoint/checkpoint_pretrain/llama3_7b_acc4_br5e3_ep20_vnips/checkpoint_19.pth --output_dir vqa_checkpoint/checkpoint_finetune/llama3_finetune_tvqa_5e3_acc4 --accum_iter 4 --adapter_len 50 --finetune \
18 | --llama3 --llama_model_path ./pretrained/llama3/ \
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/scripts/pretrain/llama2_13b.sh:
--------------------------------------------------------------------------------
1 |
2 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number
3 | torchrun --rdzv_endpoint 127.0.0.1:${randport} --nproc_per_node 4 train.py --model 13B \
4 | --max_seq_len 150 --batch_size 4 --epochs 20 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset textvid \
5 | --blr 8e-3 --weight_decay 0.1 --output_dir vqa_checkpoint/checkpoint_pretrain/llama2_13b_test --accum_iter 8 --textvid --variance 0.0 --memory --video_caption --vaq --openvqa --answer_balance --adapter_len 50 \
6 | --llama2 --llama_model_path ./pretrained/llama2/ --adapter_layer 40 \
7 |
8 |
--------------------------------------------------------------------------------
/scripts/pretrain/llama2_7b.sh:
--------------------------------------------------------------------------------
1 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number
2 | torchrun --rdzv_endpoint 127.0.0.1:${randport} --nproc_per_node 4 train.py --model 7B \
3 | --max_seq_len 150 --batch_size 18 --epochs 20 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset textvid \
4 | --blr 5e-3 --weight_decay 0.1 --output_dir vqa_checkpoint/checkpoint_pretrain/llama2_test --accum_iter 4 --textvid --variance 0.0 --video_caption --vaq --openvqa --answer_balance --adapter_len 50 --memory \
5 | --llama2 --llama_model_path ./pretrained/llama2/ \
6 |
--------------------------------------------------------------------------------
/scripts/pretrain/llama3_8b.sh:
--------------------------------------------------------------------------------
1 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number
2 | torchrun --rdzv_endpoint 127.0.0.1:${randport} --nproc_per_node 4 train.py --model 8B \
3 | --max_seq_len 150 --batch_size 14 --epochs 20 --warmup_epochs 1 --bias 3.5 --tau 100. --max_feats 10 --dataset textvid \
4 | --blr 5e-3 --weight_decay 0.1 --output_dir ./vqa_checkpoint/checkpoint_pretrain/llama3_test --accum_iter 8 --textvid --variance 0.0 --memory --video_caption --vaq --openvqa --answer_balance --adapter_len 50 \
5 | --llama3 --llama_model_path ./pretrained/llama3/ \
6 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
2 | pip install fairscale
3 | pip install fire
4 | pip install sentencepiece
5 | pip install transformers
6 | pip install timm
7 | pip install pandas
8 | pip install setuptools==59.5.0
9 | pip install pysrt
--------------------------------------------------------------------------------
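
setup.sh pins the CUDA 11.1 wheels of PyTorch 1.10 and torchvision 0.11. As a quick, optional sanity check before launching any of the torchrun scripts above, something like the following can be run (a suggested snippet, not part of the repository):

import torch
import torchvision

# Expect versions 1.10.0+cu111 / 0.11.0+cu111 and at least one visible GPU.
print(torch.__version__, torchvision.__version__)
print("CUDA available:", torch.cuda.is_available(), "| GPU count:", torch.cuda.device_count())
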
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import datetime
4 | import json
5 | import time
6 | import numpy as np
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 |
12 | import timm
13 | import timm.optim.optim_factory as optim_factory
14 |
15 | import util.misc as misc
16 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
17 | from engine import train_one_epoch, val_one_epoch, test_one_epoch
18 | from llama import Tokenizer, Tokenizer_llama3
19 | from llama_vqa import LLaMA_VQA
20 | from dataloader import load_data, load_data_instruct
21 | from torch.utils.data import DataLoader, ConcatDataset
22 |
23 | def save_arguments(args, filepath):
24 | with open(filepath, 'w') as file:
25 | json.dump(vars(args), file)
26 |
27 | def load_arguments(filepath):
28 | with open(filepath, 'r') as file:
29 | args_dict = json.load(file)
30 | return args_dict
31 |
32 | # Optionally, repopulate argparse.ArgumentParser with these arguments
33 | def repopulate_arguments(args_dict):
34 | parser = argparse.ArgumentParser(description="Example script")
35 | for key, value in args_dict.items():
36 | parser.add_argument(f'--{key}', type=type(value),default=value)
37 | return parser.parse_args([])
38 |
39 | def get_args_parser():
40 | parser = argparse.ArgumentParser('TOPA training', add_help=False)
41 | parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
42 | parser.add_argument('--epochs', default=400, type=int)
43 | parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
44 |
45 | # Model parameters
46 | parser.add_argument('--llama_model_path', default='./pretrained/llama/', type=str, help='path of llama model')
47 | parser.add_argument('--model', default='llama7B_adapter', type=str, metavar='MODEL', help='Name of model to train')
48 | parser.add_argument('--adapter_layer', type=int, default=32, metavar='LENGTH', help='the number of adapter layer')
49 | parser.add_argument('--adapter_len', type=int, default=10, metavar='LENGTH', help='the adapter length')
50 | parser.add_argument('--max_seq_len', type=int, default=512, metavar='LENGTH', help='the maximum sequence length')
51 | parser.add_argument('--max_feats', type=int, default=10, metavar='LENGTH', help='the maximum feature length')
52 |
53 | # Optimizer parameters
54 | parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
55 | parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
56 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
57 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
58 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
59 |
60 | # Dataset parameters
61 | parser.add_argument('--dataset', default='nextqa', type=str, help='dataset')
62 | parser.add_argument('--output_dir', default='./output_dir', help='path where to save, empty for no saving')
63 | parser.add_argument('--device', default='cuda', help='device to use for training / testing')
64 | parser.add_argument('--seed', default=0, type=int)
65 | parser.add_argument('--resume', default='', help='resume from checkpoint')
66 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
67 | parser.add_argument('--num_workers', default=2, type=int)
68 | parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
69 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
70 | parser.set_defaults(pin_mem=True)
71 |
72 | # distributed training parameters
73 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
74 | parser.add_argument('--local_rank', default=-1, type=int)
75 | parser.add_argument('--dist_on_itp', action='store_true')
76 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
77 |
78 | parser.add_argument('--vaq', action='store_true', help='vaq loss')
79 | parser.add_argument('--qav', action='store_true', help='qav loss')
80 | parser.add_argument('--bias', type=float, default=3., help='attention bias')
81 | parser.add_argument('--tau', type=float, default=100., help='tau')
82 | parser.add_argument('--sub', action='store_true', help='subtitles for VLEP and TVQA')
83 | parser.add_argument('--eval', action='store_true', help='eval')
84 | parser.add_argument('--test', action='store_true', help='test')
85 | parser.add_argument('--memory', action='store_true', help='memory')
86 | parser.add_argument('--finetune', action='store_true', help='finetune')
87 | parser.add_argument('--data_ratio', type=float, default=1., help='fraction of training data to use')
88 | parser.add_argument('--textvid', action='store_true', help='virtual video training')
89 | parser.add_argument('--variance', type=float, default=0., help='variance')
90 | parser.add_argument('--evalall', action='store_true', help='evalall')
91 | parser.add_argument('--debug', action='store_true', help='debug')
92 | parser.add_argument('--onlyqa', action='store_true', help='onlyqa')
93 | parser.add_argument('--llama2', action='store_true', help='llama2')
94 | parser.add_argument('--llama3', action='store_true', help='llama3')
95 | parser.add_argument('--answer_balance', action='store_true', help='balance answer choices across A-E')
96 | parser.add_argument('--video_caption', action='store_true', help='video captioning training')
97 | parser.add_argument('--instruct', action='store_true', help='instruct')
98 | parser.add_argument('--openvqa', action='store_true', help='openvqa')
99 | parser.add_argument('--weight_captioning', type=float, default=1.0, help='weight for the captioning loss')
100 | parser.add_argument('--webvid', action='store_true', help='webvid finetuning')
101 | parser.add_argument('--openvqa_eval', action='store_true', help='logits for MCQA')
102 | parser.add_argument('--single_frame', action='store_true', help='single_frame')
103 |
104 |
105 |
106 |
107 | return parser
108 |
109 |
110 | def main(args):
111 | misc.init_distributed_mode(args)
112 |
113 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
114 | print("{}".format(args).replace(', ', ',\n'))
115 |
116 | device = torch.device(args.device)
117 |
118 | # fix the seed for reproducibility
119 | seed = args.seed + misc.get_rank()
120 | torch.manual_seed(seed)
121 | np.random.seed(seed)
122 |
123 | cudnn.benchmark = True
124 | if args.llama3:
125 | tokenizer = Tokenizer_llama3(model_path=f'{args.llama_model_path}/tokenizer.model')
126 | else:
127 | tokenizer = Tokenizer(model_path=f'{args.llama_model_path}/tokenizer.model')
128 |
129 |
130 | model = LLaMA_VQA(args)
131 | model.to(device)
132 |
133 | model_without_ddp = model
134 | # print("Model = %s" % str(model_without_ddp))
135 |
136 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
137 |
138 | if args.lr is None: # only base_lr is specified
139 | args.lr = args.blr * eff_batch_size / 256
140 |
141 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
142 | print("actual lr: %.2e" % args.lr)
143 |
144 | print("accumulate grad iterations: %d" % args.accum_iter)
145 | print("effective batch size: %d" % eff_batch_size)
146 |
147 | if args.distributed:
148 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
149 | model_without_ddp = model.module
150 |
151 | # following timm: set wd as 0 for bias and norm layers
152 |
153 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
154 |
155 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
156 | print(optimizer)
157 | loss_scaler = NativeScaler()
158 | best_acc = 0.
159 |
160 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
161 |
162 |
163 | print(f"Start training for {args.epochs} epochs")
164 | start_time = time.time()
165 |
166 | if args.eval:
167 |
168 | epoch=0
169 | if args.dataset == 'egos':
170 | args.batch_size=1
171 | args.max_seq_len = 600
172 |
173 | model.module.re_init_freqs(600)
174 | if args.test:
175 | data_loader_val = load_data(args, tokenizer, split='test')
176 | if args.distributed:
177 | data_loader_val.sampler.set_epoch(epoch)
178 | val_stats = test_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args)
179 | else:
180 | data_loader_val = load_data(args, tokenizer, split='val')
181 | if args.distributed:
182 | data_loader_val.sampler.set_epoch(epoch)
183 | val_stats = val_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args)
184 | log_stats = {**{f'val_{k}': v for k, v in val_stats.items()}}
185 |
186 | if args.output_dir and misc.is_main_process():
187 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
188 | f.write(json.dumps(log_stats) + "\n")
189 |
190 | elif args.textvid:
191 | data_loader_vals = {}
192 | batch_size = args.batch_size
193 | max_seq_len = args.max_seq_len
194 | data_loader_train = load_data(args, tokenizer, split='train')
195 |
196 | eval_datasets = ['egos']
197 |
198 | for dataset_name in eval_datasets:
199 | # for dataset_name in ['egos','tvqa']:
200 | args.dataset=dataset_name
201 | if dataset_name in ['egos']:
202 | args.batch_size=1
203 | args.max_seq_len = 600
204 | else:
205 | args.batch_size= batch_size
206 | args.max_seq_len = 200
207 | data_loader_vals[dataset_name] = load_data(args, tokenizer, split='val')
208 |
209 | for epoch in range(args.start_epoch, args.epochs):
210 |
211 | if args.distributed:
212 | data_loader_train.sampler.set_epoch(epoch)
213 | model.module.re_init_freqs(600)
214 | train_stats = train_one_epoch(model, data_loader_train, optimizer, epoch, loss_scaler, args=args)
215 | val_stats = {}
216 | for key,data_loader_val in data_loader_vals.items():
217 | if args.distributed:
218 | data_loader_val.sampler.set_epoch(epoch)
219 | val_stats[key] = val_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args)
220 |
221 | if True:
222 | model_name = f"checkpoint_{epoch}"
223 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, name=model_name)
224 |
225 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}
226 |
227 | if args.output_dir and misc.is_main_process():
228 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
229 | f.write(json.dumps(log_stats) + "\n")
230 | for key, val_stat in val_stats.items():
231 | log_stat = {'dataset': key, **{f'val_{k}': v for k, v in val_stat.items()}}
232 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
233 | f.write(json.dumps(log_stat) + "\n")
234 |
235 | else:
236 | model.module.re_init_freqs(600)
237 | data_loader_train = load_data(args, tokenizer, split='train')
238 | data_loader_val = load_data(args, tokenizer, split='val')
239 | for epoch in range(args.start_epoch, args.epochs):
240 |
241 | if args.distributed:
242 | data_loader_train.sampler.set_epoch(epoch)
243 | data_loader_val.sampler.set_epoch(epoch)
244 |
245 | train_stats = train_one_epoch(model, data_loader_train, optimizer, epoch, loss_scaler, args=args)
246 | val_stats = val_one_epoch(model_without_ddp, data_loader_val, optimizer, epoch, args=args)
247 |
248 | if args.output_dir and best_acc < val_stats['acc']:
249 | best_acc = val_stats['acc']
250 | model_name = 'checkpoint_best'
251 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, name=model_name)
252 | if True:
253 | model_name = f"checkpoint_{epoch}"
254 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, name=model_name)
255 |
256 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, **{f'val_{k}': v for k, v in val_stats.items()}}
257 |
258 | if args.output_dir and misc.is_main_process():
259 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
260 | f.write(json.dumps(log_stats) + "\n")
261 |
262 | total_time = time.time() - start_time
263 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
264 | print('Training time {}'.format(total_time_str))
265 |
266 |
267 | if __name__ == '__main__':
268 | args = get_args_parser()
269 | args = args.parse_args()
270 | if args.output_dir:
271 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
272 | save_arguments(args, args.output_dir+'/args.json')
273 |
274 | main(args)
275 |
--------------------------------------------------------------------------------
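
train.py derives the absolute learning rate from the base rate via linear scaling: absolute_lr = blr * effective_batch_size / 256, where the effective batch size is batch_size * accum_iter * number of GPUs. A minimal sketch of that arithmetic, using the 7B zero-shot eval settings from the scripts above as an example (the function name is illustrative only):

def absolute_lr(blr, batch_size, accum_iter, num_gpus):
    # Mirrors the linear LR scaling in train.py: blr * eff_batch_size / 256.
    eff_batch_size = batch_size * accum_iter * num_gpus
    return blr * eff_batch_size / 256

# e.g. zeroshot_eval_*.sh with the 7B model: --blr 9e-2, --batch_size 10, --accum_iter 2, 4 GPUs
print(absolute_lr(9e-2, 10, 2, 4))  # 0.028125, with an effective batch size of 80
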
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 | """Decay the learning rate with half-cycle cosine after warmup"""
11 | if epoch < args.warmup_epochs:
12 | lr = args.lr * epoch / args.warmup_epochs
13 | else:
14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 |
--------------------------------------------------------------------------------
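
adjust_learning_rate applies linear warmup for the first warmup_epochs and then a half-cycle cosine decay from lr down to min_lr. A small standalone sketch of the resulting schedule (the per-param-group lr_scale handling is omitted); the numbers follow the llama2_7b pretraining script, where the absolute lr is 5e-3 * 288 / 256 = 0.005625 with 20 epochs and 1 warmup epoch:

import math

def scheduled_lr(epoch, lr, min_lr, warmup_epochs, epochs):
    # Same piecewise rule as util/lr_sched.adjust_learning_rate.
    if epoch < warmup_epochs:
        return lr * epoch / warmup_epochs
    return min_lr + (lr - min_lr) * 0.5 * (
        1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))

# Warmup for 1 epoch, then cosine decay toward min_lr = 0 over the remaining 19 epochs.
for e in (0, 1, 5, 10, 19):
    print(e, f"{scheduled_lr(e, 0.005625, 0., 1, 20):.6f}")
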
/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from math import inf
22 |
23 | class SmoothedValue(object):
24 | """Track a series of values and provide access to smoothed values over a
25 | window or the global series average.
26 | """
27 |
28 | def __init__(self, window_size=20, fmt=None):
29 | if fmt is None:
30 | fmt = "{median:.4f} ({global_avg:.4f})"
31 | self.deque = deque(maxlen=window_size)
32 | self.total = 0.0
33 | self.count = 0
34 | self.fmt = fmt
35 |
36 | def update(self, value, n=1):
37 | self.deque.append(value)
38 | self.count += n
39 | self.total += value * n
40 |
41 | def synchronize_between_processes(self):
42 | """
43 | Warning: does not synchronize the deque!
44 | """
45 | if not is_dist_avail_and_initialized():
46 | return
47 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
48 | dist.barrier()
49 | dist.all_reduce(t)
50 | t = t.tolist()
51 | self.count = int(t[0])
52 | self.total = t[1]
53 |
54 | @property
55 | def median(self):
56 | d = torch.tensor(list(self.deque))
57 | return d.median().item()
58 |
59 | @property
60 | def avg(self):
61 | d = torch.tensor(list(self.deque), dtype=torch.float32)
62 | return d.mean().item()
63 |
64 | @property
65 | def global_avg(self):
66 | return self.total / self.count
67 |
68 | @property
69 | def max(self):
70 | return max(self.deque)
71 |
72 | @property
73 | def value(self):
74 | return self.deque[-1]
75 |
76 | def __str__(self):
77 | return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
78 |
79 |
80 | class MetricLogger(object):
81 | def __init__(self, delimiter="\t"):
82 | self.meters = defaultdict(SmoothedValue)
83 | self.delimiter = delimiter
84 |
85 | def update(self, n=1, **kwargs):
86 | for k, v in kwargs.items():
87 | if v is None:
88 | continue
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v, n=n)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
100 |
101 | def __str__(self):
102 | loss_str = []
103 | for name, meter in self.meters.items():
104 | loss_str.append("{}: {}".format(name, str(meter)))
105 | return self.delimiter.join(loss_str)
106 |
107 | def synchronize_between_processes(self):
108 | for meter in self.meters.values():
109 | meter.synchronize_between_processes()
110 |
111 | def add_meter(self, name, meter):
112 | self.meters[name] = meter
113 |
114 | def log_every(self, iterable, print_freq, header=None):
115 | i = 0
116 | if not header:
117 | header = ''
118 | start_time = time.time()
119 | end = time.time()
120 | iter_time = SmoothedValue(fmt='{avg:.4f}')
121 | data_time = SmoothedValue(fmt='{avg:.4f}')
122 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
123 | log_msg = [header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}']
124 | if torch.cuda.is_available():
125 | log_msg.append('max mem: {memory:.0f}')
126 | log_msg = self.delimiter.join(log_msg)
127 | MB = 1024.0 * 1024.0
128 | for obj in iterable:
129 | data_time.update(time.time() - end)
130 | yield obj
131 | iter_time.update(time.time() - end)
132 | if i % print_freq == 0 or i == len(iterable) - 1:
133 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
134 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
135 | if torch.cuda.is_available():
136 | print(log_msg.format(i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB))
137 | else:
138 | print(log_msg.format(i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)))
139 | i += 1
140 | end = time.time()
141 | total_time = time.time() - start_time
142 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
143 | print('{} Total time: {} ({:.4f} s / it)'.format(header, total_time_str, total_time / len(iterable)))
144 |
145 |
146 | def setup_for_distributed(is_master):
147 | """
148 | This function disables printing when not in master process
149 | """
150 | builtin_print = builtins.print
151 |
152 | def print(*args, **kwargs):
153 | force = kwargs.pop('force', False)
154 | force = force or (get_world_size() > 8)
155 | if is_master or force:
156 | now = datetime.datetime.now().time()
157 | builtin_print('[{}] '.format(now), end='') # print with time stamp
158 | builtin_print(*args, **kwargs)
159 |
160 | builtins.print = print
161 |
162 |
163 | def is_dist_avail_and_initialized():
164 | if not dist.is_available():
165 | return False
166 | if not dist.is_initialized():
167 | return False
168 | return True
169 |
170 |
171 | def get_world_size():
172 | if not is_dist_avail_and_initialized():
173 | return 1
174 | return dist.get_world_size()
175 |
176 |
177 | def get_rank():
178 | if not is_dist_avail_and_initialized():
179 | return 0
180 | return dist.get_rank()
181 |
182 |
183 | def is_main_process():
184 | return get_rank() == 0
185 |
186 |
187 | def save_on_master(*args, **kwargs):
188 | if is_main_process():
189 | torch.save(*args, **kwargs)
190 |
191 |
192 | def init_distributed_mode(args):
193 | if args.dist_on_itp:
194 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
195 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
196 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
197 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
198 | os.environ['LOCAL_RANK'] = str(args.gpu)
199 | os.environ['RANK'] = str(args.rank)
200 | os.environ['WORLD_SIZE'] = str(args.world_size)
201 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
202 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
203 | args.rank = int(os.environ["RANK"])
204 | args.world_size = int(os.environ['WORLD_SIZE'])
205 | args.gpu = int(os.environ['LOCAL_RANK'])
206 | elif 'SLURM_PROCID' in os.environ:
207 | args.rank = int(os.environ['SLURM_PROCID'])
208 | args.gpu = args.rank % torch.cuda.device_count()
209 | else:
210 | print('Not using distributed mode')
211 | setup_for_distributed(is_master=True) # hack
212 | args.distributed = False
213 | return
214 |
215 | args.distributed = True
216 |
217 | torch.cuda.set_device(args.gpu)
218 | args.dist_backend = 'nccl'
219 | print('| distributed init (rank {}): {}, gpu {}'.format(args.rank, args.dist_url, args.gpu), flush=True)
220 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
221 | torch.distributed.barrier()
222 |
223 | setup_for_distributed(args.rank == 0)
224 |
225 |
226 | class NativeScalerWithGradNormCount:
227 | state_dict_key = "amp_scaler"
228 |
229 | def __init__(self):
230 | self._scaler = torch.cuda.amp.GradScaler()
231 |
232 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
233 | self._scaler.scale(loss).backward(create_graph=create_graph)
234 | if update_grad:
235 | if clip_grad is not None:
236 | assert parameters is not None
237 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
238 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
239 | else:
240 | self._scaler.unscale_(optimizer)
241 | norm = get_grad_norm_(parameters)
242 | self._scaler.step(optimizer)
243 | self._scaler.update()
244 | else:
245 | norm = None
246 | return norm
247 |
248 | def state_dict(self):
249 | return self._scaler.state_dict()
250 |
251 | def load_state_dict(self, state_dict):
252 | self._scaler.load_state_dict(state_dict)
253 |
254 |
255 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
256 | if isinstance(parameters, torch.Tensor):
257 | parameters = [parameters]
258 | parameters = [p for p in parameters if p.grad is not None]
259 | norm_type = float(norm_type)
260 | if len(parameters) == 0:
261 | return torch.tensor(0.)
262 | device = parameters[0].grad.device
263 | if norm_type == inf:
264 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
265 | else:
266 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
267 | return total_norm
268 |
269 |
270 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, name):
271 | output_dir = Path(args.output_dir)
272 | epoch_name = str(epoch)
273 | if loss_scaler is not None:
274 | checkpoint_paths = [output_dir / (f'{name}.pth')]
275 |
276 | unfrozen_model = {}
277 | for n, p in model_without_ddp.named_parameters():
278 | if ('gate' in n) or ('adapter' in n) or ('temporal_emb' in n) or ('visual_proj' in n):
279 | unfrozen_model[n] = p
280 |
281 | for checkpoint_path in checkpoint_paths:
282 | to_save = {
283 | 'model': unfrozen_model,
284 | 'optimizer': optimizer.state_dict(),
285 | 'epoch': epoch,
286 | 'scaler': loss_scaler.state_dict(),
287 | 'args': args,
288 | }
289 |
290 | save_on_master(to_save, checkpoint_path)
291 | else:
292 | client_state = {'epoch': epoch}
293 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
294 |
295 |
296 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
297 | if args.resume:
298 | if args.resume.startswith('https'):
299 | checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)
300 | else:
301 | checkpoint = torch.load(args.resume, map_location='cpu')
302 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
303 | print("Resume checkpoint %s" % args.resume)
304 | if not args.finetune:
305 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
306 | optimizer.load_state_dict(checkpoint['optimizer'])
307 | args.start_epoch = checkpoint['epoch'] + 1
308 | if 'scaler' in checkpoint:
309 | loss_scaler.load_state_dict(checkpoint['scaler'])
310 | print("With optim & sched!")
311 |
312 |
313 | def all_reduce_mean(x):
314 | world_size = get_world_size()
315 | if world_size > 1:
316 | x_reduce = torch.tensor(x).cuda()
317 | dist.all_reduce(x_reduce)
318 | x_reduce /= world_size
319 | return x_reduce.item()
320 | else:
321 | return x
322 |
323 |
324 | def getCount(freq):
325 | count, total = freq[0], freq[1]
326 | return count / total if total != 0 else 0.0
327 |
328 | def log_qtype(data, eval, metric_logger, args):
329 | ep = 1e-10
330 |
331 | if args.dataset == 'nextqa':
332 | qtype2id= {'CH': 1, 'CW': 2, 'TN': 3, 'TC': 4, 'TP': 5, 'DL': 6, 'DC': 7, 'DO': 8}
333 | elif args.dataset == "star":
334 | qtype2id= {'In': 1, 'Seq': 2, 'Pre': 3, 'Feas': 4}
335 | else:
336 | return
337 |
338 | q_freq = {i : [0., 0.] for i in qtype2id.values()}
339 | q_freq[0] = [0., 0.]
340 | for i, v in enumerate(eval):
341 | qt = data['qtype'][i].item()
342 | q_freq[qt][0] += v.item()
343 | q_freq[qt][1] += 1
344 | q_freq[0][0] += v.item()
345 | q_freq[0][1] += 1
346 |
347 | if args.dataset == 'nextqa':
348 | metric_logger.update(n=(q_freq[1][1]+q_freq[2][1]+ ep), C=(q_freq[1][0]+q_freq[2][0]) / (q_freq[1][1]+q_freq[2][1]+ ep))
349 | metric_logger.update(n=(q_freq[3][1]+q_freq[4][1]+ q_freq[5][1]+ ep), T=(q_freq[3][0]+q_freq[4][0]+q_freq[5][0]) / (q_freq[3][1]+q_freq[4][1]+ q_freq[5][1]+ ep))
350 | metric_logger.update(n=(q_freq[6][1]+q_freq[7][1]+ q_freq[8][1]+ ep), D=(q_freq[6][0]+q_freq[7][0]+q_freq[8][0]) / (q_freq[6][1]+q_freq[7][1]+ q_freq[8][1]+ ep))
351 | metric_logger.update(n=q_freq[0][1]+ep, Total=getCount(q_freq[0]))
352 | elif args.dataset == "star":
353 | metric_logger.update(n=q_freq[1][1]+ep, In=getCount(q_freq[1]))
354 | metric_logger.update(n=q_freq[2][1]+ep, Seq=getCount(q_freq[2]))
355 | metric_logger.update(n=q_freq[3][1]+ep, Pre=getCount(q_freq[3]))
356 | metric_logger.update(n=q_freq[4][1]+ep, Feas=getCount(q_freq[4]))
357 | metric_logger.update(n=q_freq[0][1]+ep, Total=getCount(q_freq[0]))
358 |
--------------------------------------------------------------------------------
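
save_model in util/misc.py stores only the unfrozen parameters (any parameter whose name contains gate, adapter, temporal_emb, or visual_proj) together with the optimizer state, the AMP scaler state, the epoch, and the args; load_model later restores them with strict=False so the frozen LLaMA weights stay untouched. A small sketch for inspecting such a checkpoint (the path is a placeholder):

import torch

# Path is a placeholder; point it at any checkpoint written by util.misc.save_model.
ckpt = torch.load("vqa_checkpoint/checkpoint_pretrain/example/checkpoint_19.pth", map_location="cpu")

print(sorted(ckpt.keys()))  # ['args', 'epoch', 'model', 'optimizer', 'scaler']
print("epoch:", ckpt["epoch"])

# Only the adapter-side parameters are saved, so the file stays small compared to the full LLM.
for name, tensor in ckpt["model"].items():
    print(name, tuple(tensor.shape))
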
/vqa_checkpoint/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhg-wei/TOPA/609c48228bcacca2d72eee7fa3d1f39b261e7b7f/vqa_checkpoint/.gitkeep
--------------------------------------------------------------------------------