├── utils ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── __pycache__ │ ├── optims.cpython-39.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-39.pyc │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── optims.cpython-310.pyc │ ├── simple_tokenizer.cpython-39.pyc │ └── simple_tokenizer.cpython-310.pyc ├── optims.py ├── extract_frames.py ├── pad_frames.py ├── extract_features.py ├── simple_tokenizer.py ├── write_statements.py └── utils.py ├── datasets ├── __pycache__ │ └── nextqa.cpython-39.pyc └── nextqa.py ├── models ├── __pycache__ │ ├── Qformer.cpython-39.pyc │ ├── eva_vit.cpython-39.pyc │ ├── Transformer.cpython-39.pyc │ ├── modeling_t5.cpython-39.pyc │ ├── grounding_module.cpython-39.pyc │ └── eva_clip_branch_encoder.cpython-39.pyc ├── Transformer.py ├── grounding_module.py ├── eva_clip_branch_encoder.py ├── eva_vit.py └── blip2_t5_instruct.py ├── requirements.txt ├── README.md └── finetune_ans.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /utils/__pycache__/optims.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/optims.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/nextqa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/datasets/__pycache__/nextqa.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/Qformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/models/__pycache__/Qformer.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/eva_vit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/models/__pycache__/eva_vit.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/optims.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/optims.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/Transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/models/__pycache__/Transformer.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/modeling_t5.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/models/__pycache__/modeling_t5.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/simple_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/simple_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/grounding_module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/models/__pycache__/grounding_module.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/simple_tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/utils/__pycache__/simple_tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/eva_clip_branch_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHB139426/GCG/HEAD/models/__pycache__/eva_clip_branch_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /utils/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | class LinearWarmupCosineLRScheduler: 11 | def __init__( 12 | self, 13 | optimizer, 14 | max_epoch, 15 | min_lr, 16 | init_lr, 17 | warmup_steps=0, 18 | warmup_start_lr=-1, 19 | **kwargs 20 | ): 21 | self.optimizer = optimizer 22 | 23 | self.max_epoch = max_epoch 24 | self.min_lr = min_lr 25 | 26 | self.init_lr = init_lr 27 | self.warmup_steps = warmup_steps 28 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 29 | 30 | def step(self, cur_epoch, cur_step): 31 | # assuming the warmup iters less than one epoch 32 | if cur_epoch == 0: 33 | warmup_lr_schedule( 34 | step=cur_step, 35 | optimizer=self.optimizer, 36 | max_step=self.warmup_steps, 37 | init_lr=self.warmup_start_lr, 38 | max_lr=self.init_lr, 39 | ) 40 | else: 41 | cosine_lr_schedule( 42 | epoch=cur_epoch, 43 | optimizer=self.optimizer, 44 | max_epoch=self.max_epoch, 45 | init_lr=self.init_lr, 46 | min_lr=self.min_lr, 47 | ) 48 | 49 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 50 | """Decay the learning rate""" 51 | lr = (init_lr - min_lr) * 0.5 * ( 52 | 1.0 + math.cos(math.pi * epoch / max_epoch) 53 | ) + min_lr 54 | for param_group in optimizer.param_groups: 55 | param_group["lr"] = lr 56 | 57 | 58 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 59 | """Warmup the learning rate""" 60 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 61 | for param_group in optimizer.param_groups: 62 | param_group["lr"] = lr 63 | 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.26.1 2 | asttokens==2.4.1 3 | attrs==23.2.0 4 | backcall==0.2.0 5 | beautifulsoup4==4.12.3 6 | bleach==6.1.0 7 | cachetools==5.3.3 8 | certifi==2024.2.2 9 | charset-normalizer==3.3.2 10 | colorama==0.4.6 11 | contourpy==1.2.0 12 | cycler==0.12.1 13 | decorator==5.1.1 14 | defusedxml==0.7.1 15 | descartes==1.1.0 16 | docopt==0.6.2 17 | einops==0.7.0 18 | executing==2.0.1 19 | fastjsonschema==2.19.1 20 | filelock==3.13.1 21 | fire==0.6.0 22 | fonttools==4.50.0 23 | fsspec==2024.2.0 24 | ftfy==6.1.3 25 | h5py==3.10.0 26 | huggingface-hub==0.20.3 27 | icecream==2.1.3 28 | idna==3.6 29 | importlib_metadata==7.0.2 30 | importlib_resources==6.3.1 31 | ipython==8.12.3 32 | jedi==0.19.1 33 | Jinja2==3.1.3 34 | joblib==1.3.2 35 | jsonschema==4.21.1 36 | jsonschema-specifications==2023.12.1 37 | jupyter_client==8.6.1 38 | jupyter_core==5.7.2 39 | jupyterlab_pygments==0.3.0 40 | kiwisolver==1.4.5 41 | MarkupSafe==2.1.5 42 | matplotlib==3.5.3 43 | matplotlib-inline==0.1.6 44 | mistune==3.0.2 45 | nbclient==0.10.0 46 | nbconvert==7.16.2 47 | nbformat==5.10.3 48 | numpy==1.26.4 49 | nuscenes-devkit==1.1.11 50 | nvidia-cublas-cu11==11.10.3.66 51 | nvidia-cuda-nvrtc-cu11==11.7.99 52 | nvidia-cuda-runtime-cu11==11.7.99 53 | nvidia-cudnn-cu11==8.5.0.96 54 | opencv-python==4.9.0.80 55 | packaging==23.2 56 | pandas==2.2.0 57 | pandocfilters==1.5.1 58 | parso==0.8.3 59 | peft==0.3.0 60 | pexpect==4.9.0 61 | pickleshare==0.7.5 62 | pillow==10.2.0 63 | pipreqs==0.5.0 64 | platformdirs==4.2.0 65 | prompt-toolkit==3.0.43 66 | protobuf==3.20.0 67 | psutil==5.9.8 68 | ptyprocess==0.7.0 69 | pure-eval==0.2.2 70 | pycocotools==2.0.7 71 | Pygments==2.17.2 72 | 
pyparsing==3.1.2 73 | pyquaternion==0.9.9 74 | python-dateutil==2.8.2 75 | pytz==2024.1 76 | PyYAML==6.0.1 77 | pyzmq==25.1.2 78 | referencing==0.34.0 79 | regex==2023.12.25 80 | requests==2.31.0 81 | rouge==1.0.1 82 | rpds-py==0.18.0 83 | safetensors==0.4.2 84 | scikit-learn==1.4.1.post1 85 | scipy==1.12.0 86 | sentencepiece==0.1.99 87 | Shapely==1.8.5.post1 88 | six==1.16.0 89 | soupsieve==2.5 90 | stack-data==0.6.3 91 | termcolor==2.4.0 92 | threadpoolctl==3.4.0 93 | tiktoken==0.6.0 94 | timm==0.9.12 95 | tinycss2==1.2.1 96 | tokenizers==0.15.1 97 | torch==1.13.1 98 | torchvision==0.14.1 99 | tornado==6.4 100 | tqdm==4.66.1 101 | traitlets==5.14.2 102 | transformers==4.37.2 103 | transformers-stream-generator==0.0.5 104 | typing_extensions==4.9.0 105 | tzdata==2023.4 106 | urllib3==2.2.0 107 | wcwidth==0.2.13 108 | webencodings==0.5.1 109 | yarg==0.1.9 110 | zipp==3.18.1 111 | -------------------------------------------------------------------------------- /utils/extract_frames.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from tqdm import tqdm 4 | import sys 5 | import json 6 | import pandas as pd 7 | 8 | def load_json(path): 9 | with open(path) as f: 10 | data = json.load(f) 11 | return data 12 | 13 | def load_csv(path): 14 | file_list = [] 15 | data = pd.read_csv(path) 16 | columns = data.columns.tolist() 17 | for index, row in data.iterrows(): 18 | file_list.append({}) 19 | for column in columns: 20 | file_list[index][column] = row[column] 21 | return file_list 22 | 23 | def extract_frames(video_path, frame_path, frames=16): 24 | # open the video file 25 | cap = cv2.VideoCapture(video_path) 26 | os.makedirs(frame_path, exist_ok=True) 27 | # get the video frame rate 28 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 29 | # jump to the specified starting frame 30 | start_frame = 1 31 | # compute the frame-sampling interval 32 | total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)-start_frame 33 | frame_interval = max(total_frames // frames, 1) 34 | # print(fps, start_frame, total_frames, cap.get(cv2.CAP_PROP_FRAME_COUNT), frame_interval) 35 | 36 | if total_frames < frames: 37 | print(frame_path, f"<{frames} frames!!!!!!!") 38 | 39 | ret, frame = cap.read() 40 | # start extracting frames 41 | frame_count = 0 42 | current_frame = 0 43 | while True: 44 | ret, frame = cap.read() 45 | if not ret: 46 | break 47 | current_frame += 1 48 | # save one frame every frame_interval frames 49 | if (current_frame-start_frame) % frame_interval == 0 and current_frame >= start_frame: 50 | output_file_path = os.path.join(frame_path, f'frame_{frame_count}.jpg') 51 | cv2.imwrite(output_file_path, frame) 52 | frame_count += 1 53 | # once enough frames have been extracted, exit the loop early 54 | if frame_count >= frames: 55 | break 56 | # release the video capture object 57 | cap.release() 58 | 59 | if frame_count < frames: 60 | print(frame_path, f"<{frames} frames!!!!!!!") 61 | 62 | 63 | train_data = load_csv("../nextqa/annotations_mc/train.csv") 64 | val_data = load_csv("../nextqa/annotations_mc/val.csv") 65 | test_data = load_csv("../nextqa/annotations_mc/test.csv") 66 | mapper = load_json('../nextqa/map_vid_vidorID.json') 67 | data = train_data + val_data + test_data 68 | 69 | video_ids = [] 70 | for item in data: 71 | video_id = item['video'] 72 | video_ids.append(video_id) 73 | video_ids = list(set(video_ids)) 74 | 75 | for video_id in tqdm(video_ids): 76 | video_path = f"../nextqa/videos/{mapper[str(video_id)]}.mp4" 77 | frame_path = f"../nextqa/frames_32/{video_id}" 78 | extract_frames(video_path, frame_path, frames=32) 79 | -------------------------------------------------------------------------------- /utils/pad_frames.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | import torch 5 | from tqdm import tqdm 6 | 7 | def check_frames(source_dir, tgt_frame_num): 8 | all_num = 0 9 | count = 0 10 | for subdir in os.listdir(source_dir): 11 | subdir_path = os.path.join(source_dir, subdir) 12 | jpg_files = [f for f in os.listdir(subdir_path) if f.lower().endswith(".jpg")] 13 | if len(jpg_files) < tgt_frame_num: 14 | count += 1 15 | print(subdir_path) 16 | all_num += len(jpg_files) 17 | return count, all_num 18 | 19 | def generate_uniform_elements(T, W): 20 | return torch.linspace(0, T-1, W, dtype=torch.int) 21 | 22 | def move_files(source_dir, target_dir, tgt_frame_num): 23 | # create the target folder 24 | os.makedirs(target_dir, exist_ok=True) 25 | # iterate over every sub-folder under the source folder 26 | for subdir in tqdm(os.listdir(source_dir)): 27 | subdir_path = os.path.join(source_dir, subdir) 28 | target_subdir_path = os.path.join(target_dir, subdir) 29 | if not os.path.exists(target_subdir_path): 30 | os.makedirs(target_subdir_path, exist_ok=True) 31 | # list the jpg files in the current sub-folder 32 | jpg_files = [f for f in os.listdir(subdir_path) if f.lower().endswith(".jpg")] 33 | jpg_files = sorted(jpg_files, key=lambda x: int(x.split('_')[1].split('.')[0])) 34 | # if there are at least tgt_frame_num jpg files, uniformly sample tgt_frame_num of them; otherwise repeat existing files to pad up to tgt_frame_num 35 | selected_files = [] 36 | if len(jpg_files) >= tgt_frame_num: 37 | # uniformly select tgt_frame_num files 38 | selected_indices = generate_uniform_elements(len(jpg_files), tgt_frame_num) 39 | selected_files = [jpg_files[i] for i in selected_indices] 40 | for filename in selected_files: 41 | src_path = os.path.join(subdir_path, filename) 42 | dest_path = os.path.join(target_subdir_path, filename) 43 | shutil.copy(src_path, dest_path) 44 | else: 45 | # repeat existing files to pad up to tgt_frame_num files 46 | max_value = len(jpg_files)-1 47 | selected_files = jpg_files 48 | for filename in selected_files: 49 | src_path = os.path.join(subdir_path, filename) 50 | dest_path = os.path.join(target_subdir_path, filename) 51 | shutil.copy(src_path, dest_path) 52 | for i in range(tgt_frame_num - len(jpg_files)): 53 | src_path = os.path.join(subdir_path, random.choice(jpg_files)) 54 | dest_path = os.path.join(target_subdir_path, f"frame_{max_value+i+1}.jpg") 55 | shutil.copy(src_path, dest_path) 56 | print("Transfer complete.") 57 | 58 | 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GCG - ACM MM'24 2 | **Weakly Supervised Gaussian Contrastive Grounding with Large Multimodal Models for Video Question Answering [ACM MM'24]**. This is the official implementation of the [[Paper](https://arxiv.org/abs/2401.10711)] accepted by ACM MM'24. 3 | 4 | ## Install 5 | 6 | 1. Clone this repository and navigate to the GCG folder 7 | ```bash 8 | git clone https://github.com/WHB139426/GCG.git 9 | cd GCG 10 | mkdir experiments 11 | mkdir files 12 | ``` 13 | 14 | 2. Install Package 15 | ```Shell 16 | conda create -n gcg python=3.9.16 17 | conda activate gcg 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ## Pretrained Weights of InstructBLIP 22 | 23 | You can prepare the pretrained weights of InstructBLIP-T5-XL according to [[InstructBLIP](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)].
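If you prefer to fetch the released checkpoints programmatically, a minimal sketch like the one below pulls the `*.pth` files from the Hugging Face repo referenced in the next paragraph into `GCG/experiments`; it assumes `huggingface_hub` is installed and that only the checkpoint files are needed.

```python
# Minimal sketch (assumption: `pip install huggingface_hub` has been run).
# Downloads the released *.pth checkpoints from the HF repo linked in this
# README into the experiments/ folder expected by the code.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="WHB139426/GCG",   # Hugging Face repo referenced below
    allow_patterns=["*.pth"],  # fetch only the checkpoint files
    local_dir="experiments",   # matches the folder layout listed below
)
```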
24 | 25 | Since we have changed the structure of the code of the model, we recommend you download the pretrained weights of EVA-CLIP, and QFormer directly in [[🤗HF](https://huggingface.co/WHB139426/GCG/tree/main)]. The pretrained weights should be organized as follows, 26 | 27 | ``` 28 | ├── GCG 29 | │ └── experiments 30 | │ └── eva_vit_g.pth 31 | │ └── qformer_t5.pth 32 | │ └── query_tokens_t5.pth 33 | │ └── llm_proj_t5.pth 34 | │ └── eva_vit_post_layernorm.pth 35 | │ └── eva_clip_text_model.pth 36 | │ └── eva_clip_last_vision_head.pth 37 | │ └── eva_clip_last_vision_norm.pth 38 | │ └── eva_clip_last_vision_block.pth 39 | ``` 40 | 41 | ## Datasets 42 | You should download the videos of NExT-QA from https://github.com/doc-doc/NExT-QA?tab=readme-ov-file or directly with the link [[videos](https://drive.google.com/file/d/1jTcRCrVHS66ckOUfWRb-rXdzJ52XAWQH/view)]. The downloaded videos should be in the folder `nextqa/videos` 43 | 44 | We provide the annotation files in [[🤗HF](https://huggingface.co/WHB139426/GCG/tree/main)], and you should organize the data as follows, 45 | 46 | ``` 47 | ├── nextqa 48 | │ └── annotations_mc 49 | │ └── frames_32 50 | │ └── videos 51 | │ └── vision_features 52 | | └── map_vid_vidorID.json 53 | ├── GCG 54 | │ └── datasets 55 | │ └── models 56 | │ └──... 57 | ``` 58 | Then, you should extract 32 frames per video into the `nextqa/frames_32` folder with the python scripts 59 | 60 | ```Shell 61 | python utils/extract_frames.py 62 | ``` 63 | 64 | After that, you should extract the video features in advance into the `nextqa/vision_features` with the python scripts 65 | 66 | ```Shell 67 | python utils/extract_features.py 68 | ``` 69 | 70 | ## Training 71 | 72 | ```Shell 73 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=1111 finetune_ans.py 74 | ``` 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /utils/extract_features.py: -------------------------------------------------------------------------------- 1 | from transformers import Blip2Config 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from tqdm import tqdm 5 | from transformers import AutoTokenizer, Blip2Processor 6 | import os 7 | import sys 8 | import numpy as np 9 | import h5py 10 | import cv2 11 | import torch 12 | from transformers import AutoConfig, StoppingCriteria 13 | import random 14 | import re 15 | import requests 16 | from PIL import Image 17 | from io import BytesIO 18 | import json 19 | import pickle 20 | import pandas as pd 21 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 22 | from utils import * 23 | from models.eva_vit import Blip2VisionModel 24 | 25 | def load_image(image_file): 26 | if image_file.startswith('http') or image_file.startswith('https'): 27 | response = requests.get(image_file) 28 | image = Image.open(BytesIO(response.content)).convert('RGB') 29 | else: 30 | image = Image.open(image_file).convert('RGB') 31 | return image 32 | 33 | class NEXTQADataset(Dataset): 34 | def __init__( 35 | self, 36 | frame_path = "../nextqa/frames_32" 37 | ): 38 | self.frame_path = frame_path 39 | self.image_processor = image_transform(image_size=224) 40 | self.video_ids = os.listdir(self.frame_path) 41 | self.image_files = [] 42 | self.image_ids = [] 43 | for video_id in self.video_ids: 44 | frame_files = os.listdir(self.frame_path + f"/{video_id}") 45 | frame_files = sorted(frame_files, key=lambda x: 
int(x.split('_')[1].split('.')[0])) 46 | for frame_file in frame_files: 47 | self.image_files.append(self.frame_path + f"/{video_id}/{frame_file}") 48 | self.image_ids.append(f"{video_id}_{frame_file.replace('.jpg','')}") 49 | 50 | def __len__(self): 51 | """returns the length of dataframe""" 52 | return len(self.image_ids) 53 | 54 | def __getitem__(self, index): 55 | image_id = self.image_ids[index] 56 | image_file = self.image_files[index] 57 | pixel_values = self.image_processor(Image.open(image_file)) 58 | 59 | return { 60 | "image_ids": image_id, 61 | "image_files": image_file, 62 | "pixel_values": pixel_values, 63 | } 64 | 65 | 66 | dataset=NEXTQADataset(frame_path = "../nextqa/frames_32") 67 | print(len(dataset)) 68 | print(dataset[0]["image_ids"]) 69 | print(dataset[0]["image_files"]) 70 | print(dataset[0]["pixel_values"].shape) 71 | 72 | blip2_config = Blip2Config.from_pretrained('Salesforce/blip2-flan-t5-xl') 73 | blip2_config.vision_config.torch_dtype = torch.float16 74 | vision_model = Blip2VisionModel(blip2_config.vision_config) 75 | vision_model.load_state_dict(torch.load("experiments/eva_vit_g.pth", map_location='cpu')) 76 | 77 | data_loader = DataLoader(dataset=dataset, batch_size=768, shuffle=False, drop_last=False, num_workers=16) 78 | 79 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 80 | vision_model.to(device) 81 | vision_model.eval() 82 | 83 | f = h5py.File('../nextqa/vision_features/feats_wo_norm_32.h5', "w") 84 | 85 | for i, data in enumerate(tqdm(data_loader)): 86 | image_ids = data['image_ids'] 87 | pixel_values = data['pixel_values'].to(device) 88 | # 抽取视觉特征 89 | with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16): 90 | with torch.no_grad(): 91 | frame_features = vision_model(pixel_values).last_hidden_state_without_norm # shape: [bs, 257, 1408] 92 | frame_features = frame_features.cpu().numpy() 93 | if i==0: 94 | print(frame_features.shape) 95 | for j in range(frame_features.shape[0]): 96 | f.create_dataset(f"{image_ids[j]}", data=frame_features[j]) 97 | 98 | -------------------------------------------------------------------------------- /datasets/nextqa.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import random 3 | import numpy as np 4 | import torch 5 | from transformers import AutoTokenizer, InstructBlipProcessor 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import pickle 9 | import sys 10 | import os 11 | import requests 12 | from PIL import Image 13 | from collections import Counter 14 | from io import BytesIO 15 | import json 16 | import h5py 17 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 18 | from utils.utils import * 19 | 20 | class NEXTQADataset(Dataset): 21 | def __init__( 22 | self, 23 | anno_path = '../nextqa/annotations_mc/train.csv', 24 | mapper_path = '../nextqa/map_vid_vidorID.json', 25 | video_path = "../nextqa/videos", 26 | frame_path = "../nextqa/frames_32", 27 | feature_path = "../nextqa/vision_features/feats_wo_norm_32.h5", 28 | frame_count = 32 29 | ): 30 | 31 | self.data = load_csv(anno_path) 32 | self.mapper = load_json(mapper_path) 33 | self.video_path = video_path 34 | self.frame_path = frame_path 35 | self.frame_count = frame_count 36 | self.image_processor = image_transform(image_size=224) 37 | self.image_features = h5py.File(feature_path, "r") 38 | 39 | self.video_ids = [] 40 | self.videos = [] 41 | self.frames = [] 42 | self.questions = [] 43 | self.answers_option = [] 44 | 
self.answers_text = [] 45 | self.answers_ids = [] 46 | self.types = [] 47 | self.qids = [] 48 | self.options_a0 = [] 49 | self.options_a1 = [] 50 | self.options_a2 = [] 51 | self.options_a3 = [] 52 | self.options_a4 = [] 53 | 54 | for data in self.data: 55 | 56 | self.video_ids.append(data['video']) 57 | self.qids.append(data['qid']) 58 | self.types.append(data['type']) 59 | self.questions.append(data['question']+"?") 60 | self.options_a0.append(data['a0']) 61 | self.options_a1.append(data['a1']) 62 | self.options_a2.append(data['a2']) 63 | self.options_a3.append(data['a3']) 64 | self.options_a4.append(data['a4']) 65 | 66 | self.answers_ids.append(data['answer']) 67 | self.answers_text.append(data[f"a{str(data['answer'])}"] ) 68 | self.answers_option.append(["A", "B", "C", "D", "E"][data['answer']]) 69 | self.videos.append(self.video_path + f"/{self.mapper[str(data['video'])]}.mp4") 70 | self.frames.append(self.frame_path +f"/{str(data['video'])}") 71 | 72 | def __len__(self): 73 | """returns the length of dataframe""" 74 | return len(self.video_ids) 75 | 76 | def __getitem__(self, index): 77 | """return the input ids, attention masks and target ids""" 78 | video_id = str(self.video_ids[index]) 79 | qid = str(self.qids[index]) 80 | type = str(self.types[index]) 81 | question = str(self.questions[index]) 82 | option_a0 = str(self.options_a0[index]) 83 | option_a1 = str(self.options_a1[index]) 84 | option_a2 = str(self.options_a2[index]) 85 | option_a3 = str(self.options_a3[index]) 86 | option_a4 = str(self.options_a4[index]) 87 | answer_id = self.answers_ids[index] 88 | answer_text = str(self.answers_text[index]) 89 | answer_option = str(self.answers_option[index]) 90 | 91 | frame_files = os.listdir(str(self.frames[index])) 92 | frame_files = sorted(frame_files, key=lambda x: int(x.split('_')[1].split('.')[0])) 93 | frame_files = get_frames(frame_files, self.frame_count) 94 | 95 | frame_features = [] 96 | for i in range(len(frame_files)): 97 | frame_features.append(torch.from_numpy(self.image_features[f"{video_id}_{frame_files[i].replace('.jpg','')}"][:])) 98 | frame_features = torch.stack(frame_features, dim=0) # [frame_count, 257, 1408] 99 | 100 | return { 101 | "video_ids": video_id, 102 | "qids": qid, 103 | "types": type, 104 | 105 | "frame_features": frame_features, 106 | 107 | "questions": question, 108 | "options_a0": option_a0, 109 | "options_a1": option_a1, 110 | "options_a2": option_a2, 111 | "options_a3": option_a3, 112 | "options_a4": option_a4, 113 | "answers_id": answer_id, 114 | "answers_text": answer_text, 115 | "answers": answer_option, 116 | 117 | } 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /utils/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | import torch 6 | import ftfy 7 | import regex as re 8 | from typing import Any, Union, List 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils/bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 
127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | 134 | _tokenizer = SimpleTokenizer() 135 | 136 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> torch.LongTensor: 137 | """ 138 | Returns the tokenized representation of given input string(s) 139 | 140 | Parameters 141 | ---------- 142 | texts : Union[str, List[str]] 143 | An input string or a list of input strings to tokenize 144 | 145 | context_length : int 146 | The context length to use; all CLIP models use 77 as the context length 147 | 148 | truncate: bool 149 | Whether to truncate the text in case its encoding is longer than the context length 150 | 151 | Returns 152 | ------- 153 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 154 | """ 155 | if isinstance(texts, str): 156 | texts = [texts] 157 | 158 | sot_token = _tokenizer.encoder["<|startoftext|>"] 159 | eot_token = _tokenizer.encoder["<|endoftext|>"] 160 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 161 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 162 | 163 | for i, tokens in enumerate(all_tokens): 164 | if len(tokens) > context_length: 165 | if truncate: 166 | tokens = tokens[:context_length] 167 | tokens[-1] = eot_token 168 | else: 169 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 170 | result[i, :len(tokens)] = torch.tensor(tokens) 171 | 172 | return result -------------------------------------------------------------------------------- /utils/write_statements.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | from tqdm import tqdm 4 | import sys 5 | import os 6 | from torch.utils.data import Dataset 7 | import random 8 | import numpy as np 9 | import torch 10 | from transformers import AutoTokenizer, InstructBlipProcessor 11 | from tqdm import tqdm 12 | from PIL import Image 13 | import pickle 14 | import sys 15 | import os 16 | import requests 17 | from PIL import Image 18 | from collections import Counter 19 | from io import BytesIO 20 | import json 21 | import h5py 22 | from torch.backends import cudnn 23 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 24 | from utils import * 25 | 26 | def init_seeds(seed=42, cuda_deterministic=True): 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 31 | if cuda_deterministic: # slower, more reproducible 32 | cudnn.deterministic = True 33 | cudnn.benchmark = False 34 | else: # faster, less reproducible 35 | cudnn.deterministic = False 36 | cudnn.benchmark = True 37 | 38 | def add_template(questions, answers): 39 | messages = [ 40 | {"role": "user", "content": "Give you a question and corresponding answer, turn it into a declarative sentence. Remember, don't answer the question, add any additional information or make mistakes in grammar!"}, 41 | {"role": "assistant", "content": "Yes, I understand."}, 42 | 43 | {"role": "user", "content": "Q: What gender is the person with a face? 
A: male."}, # TC 44 | {"role": "assistant", "content": "The person with a face is a male."}, 45 | 46 | {"role": "user", "content": "Q: Is the person in white a sailboat in the river? A: no."}, # DO 47 | {"role": "assistant", "content": "The person in white is not a sailboat in the river."}, 48 | 49 | {"role": "user", "content": "Q: Why do the cowboy clothes get down quickly? A: catch cow."}, # CW 50 | {"role": "assistant", "content": "The cowboy clothes get down quickly to catch cow."}, 51 | 52 | {"role": "user", "content": "Q: Does the dancers dance indoors? A: yes."}, # DL 53 | {"role": "assistant", "content": "The dancers dance indoors."}, 54 | 55 | {"role": "user", "content": "Q: how many people declare that the end is nigh? A: four."}, # TN 56 | {"role": "assistant", "content": "Four people declare that the end is nigh."}, 57 | 58 | {"role": "user", "content": "Q: What is behind the person in the video? A: horse."}, # CH 59 | {"role": "assistant", "content": "Horse is behind the person in the video."}, 60 | 61 | {"role": "user", "content": f"Q: {questions} A: {answers}."}, 62 | ] 63 | return messages 64 | 65 | init_seeds(42) 66 | 67 | class MSVDQADataset(Dataset): 68 | def __init__( 69 | self, 70 | anno_path = '/home/whb/workspace/msvdQA/annotations/train_qa.json', 71 | mapper_path = "/home/whb/workspace/msvdQA/annotations/youtube_mapping.txt", 72 | video_path = "/home/whb/workspace/msvdQA/videos", 73 | frame_path = "/home/whb/workspace/msvdQA/frames_32", 74 | feature_path = "/home/whb/workspace/msvdQA/vision_features/feats_wo_norm.h5", 75 | frame_count = 32 76 | ): 77 | self.data = load_json('/home/whb/workspace/msvdQA/annotations/train_qa.json') + load_json('/home/whb/workspace/msvdQA/annotations/val_qa.json') + load_json('/home/whb/workspace/msvdQA/annotations/test_qa.json') 78 | self.mapper = {} 79 | f = open(mapper_path, 'r').read() 80 | lines = f.split('\n') 81 | for line in lines: 82 | if line.strip(): 83 | parts = line.split() 84 | if len(parts) == 2: 85 | vid_id = parts[1] 86 | video_name = parts[0] 87 | self.mapper[vid_id] = video_name 88 | self.video_path = video_path 89 | self.frame_path = frame_path 90 | self.frame_count = frame_count 91 | 92 | self.video_ids = [] 93 | self.videos = [] 94 | self.frames = [] 95 | self.questions = [] 96 | self.answers_text = [] 97 | self.answers = [] 98 | self.types = [] 99 | self.qids = [] 100 | 101 | for data in self.data: 102 | 103 | temp_id = f"vid{data['video_id']}" 104 | video_id = self.mapper[temp_id] 105 | self.video_ids.append(video_id) 106 | self.qids.append(data['id']) 107 | self.types.append('N/A') 108 | self.questions.append(data['question']) 109 | self.answers.append(data['answer']) 110 | self.answers_text.append(data['answer']) 111 | 112 | self.videos.append(self.video_path + f"/{video_id}.mp4") 113 | self.frames.append(self.frame_path +f"/{video_id}") 114 | 115 | def __len__(self): 116 | """returns the length of dataframe""" 117 | return len(self.video_ids) 118 | 119 | def __getitem__(self, index): 120 | """return the input ids, attention masks and target ids""" 121 | video_id = str(self.video_ids[index]) 122 | qid = str(self.qids[index]) 123 | type = str(self.types[index]) 124 | question = str(self.questions[index]) 125 | answer = str(self.answers[index]) 126 | answer_text = str(self.answers_text[index]) 127 | 128 | 129 | return { 130 | "video_ids": video_id, 131 | "qids": qid, 132 | "types": type, 133 | "questions": question, 134 | "answers": answer, 135 | "answers_text": answer_text, 136 | } 137 | 138 | class 
MSRVTTQADataset(Dataset): 139 | def __init__( 140 | self, 141 | anno_path = '/home/whb/workspace/msrvttQA/annotations/train_qa.json', 142 | video_path = "/home/whb/workspace/msrvttQA/videos", 143 | frame_path = "/home/whb/workspace/msrvttQA/frames_32", 144 | feature_path = "/home/whb/workspace/msrvttQA/vision_features/feats_wo_norm.h5", 145 | frame_count = 32 146 | ): 147 | self.data = load_json('/home/whb/workspace/msrvttQA/annotations/train_qa.json') + load_json('/home/whb/workspace/msrvttQA/annotations/val_qa.json') + load_json('/home/whb/workspace/msrvttQA/annotations/test_qa.json') 148 | self.video_path = video_path 149 | self.frame_path = frame_path 150 | self.frame_count = frame_count 151 | self.image_processor = image_transform(image_size=224) 152 | 153 | self.video_ids = [] 154 | self.videos = [] 155 | self.frames = [] 156 | self.questions = [] 157 | self.answers_text = [] 158 | self.answers = [] 159 | self.types = [] 160 | self.qids = [] 161 | 162 | for data in self.data: 163 | 164 | self.video_ids.append(data['video_id']) 165 | self.qids.append(data['id']) 166 | self.types.append(data['category_id']) 167 | self.questions.append(data['question']) 168 | self.answers.append(data['answer']) 169 | self.answers_text.append(data['answer']) 170 | 171 | self.videos.append(self.video_path + f"/video{data['video_id']}.mp4") 172 | self.frames.append(self.frame_path +f"/video{data['video_id']}") 173 | 174 | def __len__(self): 175 | """returns the length of dataframe""" 176 | return len(self.video_ids) 177 | 178 | def __getitem__(self, index): 179 | """return the input ids, attention masks and target ids""" 180 | video_id = str(self.video_ids[index]) 181 | qid = str(self.qids[index]) 182 | type = str(self.types[index]) 183 | question = str(self.questions[index]) 184 | answer = str(self.answers[index]) 185 | answer_text = str(self.answers_text[index]) 186 | 187 | return { 188 | "video_ids": video_id, 189 | "qids": qid, 190 | "types": type, 191 | 192 | 193 | "questions": question, 194 | "answers": answer, 195 | "answers_text": answer_text, 196 | } 197 | 198 | class ActivityQADataset(Dataset): 199 | def __init__( 200 | self, 201 | anno_path = '/home/whb/workspace/activityQA/annotations/train.json', 202 | video_path = "/home/whb/workspace/activityQA/videos", 203 | frame_path = "/home/whb/workspace/activityQA/frames_32", 204 | frame_count = 32 205 | ): 206 | self.data = load_json('/home/whb/workspace/activityQA/annotations/train.json') + load_json('/home/whb/workspace/activityQA/annotations/val.json') + load_json('/home/whb/workspace/activityQA/annotations/test.json') 207 | self.video_path = video_path 208 | self.frame_path = frame_path 209 | self.frame_count = frame_count 210 | self.image_processor = image_transform(image_size=224) 211 | 212 | self.video_ids = [] 213 | self.videos = [] 214 | self.frames = [] 215 | self.questions = [] 216 | self.answers = [] 217 | self.types = [] 218 | self.qids = [] 219 | 220 | for data in self.data: 221 | 222 | self.video_ids.append('v_'+data['video_name']) 223 | self.qids.append(data['question_id']) 224 | self.types.append(data['type']) 225 | self.questions.append(data['question'].capitalize()+"?") 226 | self.answers.append(data['answer']) 227 | 228 | self.videos.append(self.video_path + f"/v_{data['video_name']}.mp4") 229 | self.frames.append(self.frame_path +f"/v_{data['video_name']}") 230 | 231 | def __len__(self): 232 | """returns the length of dataframe""" 233 | return len(self.video_ids) 234 | 235 | def __getitem__(self, index): 236 | """return the input 
ids, attention masks and target ids""" 237 | video_id = str(self.video_ids[index]) 238 | qid = str(self.qids[index]) 239 | type = str(self.types[index]) 240 | question = str(self.questions[index]) 241 | answer = str(self.answers[index]) 242 | 243 | return { 244 | "video_ids": video_id, 245 | "qids": qid, 246 | "types": type, 247 | 248 | "questions": question, 249 | "answers": answer 250 | } 251 | 252 | dataset = ActivityQADataset() 253 | 254 | device = "cuda:6" # the device to load the model onto 255 | tokenizer = AutoTokenizer.from_pretrained("/home/whb/workspace/Mistral-7B-Instruct-v0.1/") 256 | model = AutoModelForCausalLM.from_pretrained("/home/whb/workspace/Mistral-7B-Instruct-v0.1/", torch_dtype=torch.bfloat16) 257 | model.to(device) 258 | 259 | activityqa_statement = [] 260 | for entry in tqdm(dataset): 261 | video_id = entry['video_ids'] 262 | qids = entry['qids'] 263 | types = entry['types'] 264 | questions = entry['questions'] 265 | answers = entry['answers'] 266 | 267 | encodeds = tokenizer.apply_chat_template(add_template(questions, answers), return_tensors="pt") 268 | model_inputs = encodeds.to(device) 269 | generated_ids = model.generate(model_inputs, max_new_tokens=256, do_sample=False) 270 | decoded_sequence = tokenizer.decode(generated_ids[0][len(model_inputs[0]):], skip_special_tokens=True).replace('[/INST]', '') 271 | 272 | activityqa_statement.append({ 273 | "video_ids": video_id, 274 | "qids": qids, 275 | "types": types, 276 | "questions": questions, 277 | "answers": answers, 278 | "statements": decoded_sequence, 279 | }) 280 | 281 | with open('activityqa_statement.json', 'w') as f: 282 | json.dump(activityqa_statement, f, indent=2) -------------------------------------------------------------------------------- /models/Transformer.py: -------------------------------------------------------------------------------- 1 | from transformers.activations import gelu 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch 5 | import math 6 | import copy 7 | from transformers.modeling_outputs import BaseModelOutput 8 | import torch.nn.functional as F 9 | 10 | def create_sinusoidal_embeddings(n_pos, dim, out): 11 | with torch.no_grad(): 12 | position_enc = np.array( 13 | [ 14 | [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] 15 | for pos in range(n_pos) 16 | ] 17 | ) 18 | out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) 19 | out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 20 | # out.detach_() 21 | # out.requires_grad = False 22 | 23 | class MultiHeadSelfAttention(nn.Module): 24 | def __init__(self, config): 25 | super().__init__() 26 | 27 | ################BERT & RoBERTa################# 28 | self.n_heads = config.num_attention_heads #config.n_heads 29 | self.dim = config.hidden_size #config.dim 30 | dp_rate = config.attention_probs_dropout_prob #config.attention_dropout 31 | 32 | ################DisVGT################# 33 | # self.n_heads = config.n_heads 34 | # self.dim = config.dim 35 | # dp_rate = config.attention_dropout 36 | 37 | self.dropout = nn.Dropout(p=dp_rate) 38 | 39 | assert self.dim % self.n_heads == 0 40 | 41 | self.q_lin = nn.Linear(in_features=self.dim, out_features=self.dim) 42 | self.k_lin = nn.Linear(in_features=self.dim, out_features=self.dim) 43 | self.v_lin = nn.Linear(in_features=self.dim, out_features=self.dim) 44 | self.out_lin = nn.Linear(in_features=self.dim, out_features=self.dim) 45 | 46 | self.pruned_heads = set() 47 | 48 | def forward(self, query, key, value, mask, head_mask=None, 
output_attentions=False, gauss_weight=None): 49 | """ 50 | Parameters 51 | ---------- 52 | query: torch.tensor(bs, seq_length, dim) 53 | key: torch.tensor(bs, seq_length, dim) 54 | value: torch.tensor(bs, seq_length, dim) 55 | mask: torch.tensor(bs, seq_length) 56 | 57 | Outputs 58 | ------- 59 | weights: torch.tensor(bs, n_heads, seq_length, seq_length) 60 | Attention weights 61 | context: torch.tensor(bs, seq_length, dim) 62 | Contextualized layer. Optional: only if `output_attentions=True` 63 | """ 64 | bs, q_length, dim = query.size() 65 | k_length = key.size(1) 66 | # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) 67 | # assert key.size() == value.size() 68 | 69 | dim_per_head = self.dim // self.n_heads 70 | 71 | mask_reshp = (bs, 1, 1, k_length) 72 | 73 | def shape(x): 74 | """ separate heads """ 75 | return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) 76 | 77 | def unshape(x): 78 | """ group heads """ 79 | return ( 80 | x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) 81 | ) 82 | 83 | q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) 84 | k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) 85 | v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) 86 | 87 | q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) 88 | scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) 89 | mask = ( 90 | (mask == 0).view(mask_reshp).expand_as(scores) 91 | ) # (bs, n_heads, q_length, k_length) 92 | scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, q_length, k_length) 93 | 94 | weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length) 95 | 96 | if gauss_weight is not None: 97 | # gauss_weight = gauss_weight.unsqueeze(1).repeat(self.num_heads, tgt_len, 1) 98 | gauss_weight = gauss_weight.unsqueeze(1).unsqueeze(1)\ 99 | .expand(-1, self.n_heads, q_length, -1).reshape(*weights.shape) 100 | weights = weights * (gauss_weight + 1e-10) 101 | weights = weights / weights.sum(dim=-1, keepdim=True) 102 | 103 | weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) 104 | 105 | # Mask heads if we want to 106 | if head_mask is not None: 107 | weights = weights * head_mask 108 | 109 | context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) 110 | context = unshape(context) # (bs, q_length, dim) 111 | context = self.out_lin(context) # (bs, q_length, dim) 112 | 113 | if output_attentions: 114 | return (context, weights) 115 | else: 116 | return (context,) 117 | 118 | class FFN(nn.Module): 119 | def __init__(self, config): 120 | super().__init__() 121 | dropout, dim, hidden_dim = config.attention_probs_dropout_prob, config.hidden_size, config.intermediate_size 122 | activation = config.hidden_act 123 | ##########DisVGT############### 124 | # dropout, dim, hidden_dim = config.attention_dropout, config.dim, config.hidden_dim 125 | # activation = config.activation 126 | 127 | self.dropout = nn.Dropout(p=dropout) 128 | self.lin1 = nn.Linear(in_features=dim, out_features=hidden_dim) 129 | self.lin2 = nn.Linear(in_features=hidden_dim, out_features=dim) 130 | assert activation in [ 131 | "relu", 132 | "gelu", 133 | ], "activation ({}) must be in ['relu', 'gelu']".format(activation) 134 | self.activation = gelu if activation == "gelu" else nn.ReLU() 135 | 136 | def forward(self, input): 137 | x = self.lin1(input) 138 | x = self.activation(x) 139 | x = self.lin2(x) 140 | x = self.dropout(x) 141 | 
return x 142 | 143 | class TransformerBlock(nn.Module): 144 | def __init__(self, config): 145 | super().__init__() 146 | dim = config.hidden_size 147 | assert config.hidden_size % config.num_attention_heads == 0 148 | #########DisVGT########## 149 | # dim = config.dim 150 | # assert config.dim % config.n_heads == 0 151 | 152 | self.attention = MultiHeadSelfAttention(config) 153 | self.sa_layer_norm = nn.LayerNorm(normalized_shape=dim, eps=1e-12) 154 | 155 | self.ffn = FFN(config) 156 | self.output_layer_norm = nn.LayerNorm(normalized_shape=dim, eps=1e-12) 157 | 158 | def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False, gauss_weight=None): 159 | """ 160 | Parameters 161 | ---------- 162 | x: torch.tensor(bs, seq_length, dim) 163 | attn_mask: torch.tensor(bs, seq_length) 164 | 165 | Outputs 166 | ------- 167 | sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) 168 | The attention weights 169 | ffn_output: torch.tensor(bs, seq_length, dim) 170 | The output of the transformer block contextualization. 171 | """ 172 | # Self-Attention 173 | residual_x = x 174 | x = self.sa_layer_norm(x) 175 | sa_output = self.attention( 176 | query=x, 177 | key=x, 178 | value=x, 179 | mask=attn_mask, 180 | head_mask=head_mask, 181 | output_attentions=output_attentions, 182 | gauss_weight=gauss_weight 183 | ) 184 | if output_attentions: 185 | ( 186 | sa_output, 187 | sa_weights, 188 | ) = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) 189 | else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples 190 | assert type(sa_output) == tuple 191 | sa_output = sa_output[0] 192 | sa_output = sa_output + residual_x # (bs, seq_length, dim) 193 | 194 | # Feed Forward Network 195 | residual_sa = sa_output 196 | sa_output = self.output_layer_norm(sa_output) 197 | ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) 198 | ffn_output = ffn_output + residual_sa # (bs, seq_length, dim) 199 | 200 | output = (ffn_output,) 201 | if output_attentions: 202 | output = (sa_weights,) + output 203 | return output 204 | 205 | class Transformer(nn.Module): 206 | def __init__(self, config): 207 | super().__init__() 208 | self.n_layers = config.num_hidden_layers 209 | ############DisBERT################ 210 | # self.n_layers = config.n_layers 211 | 212 | layer = TransformerBlock(config) 213 | self.layer = nn.ModuleList( 214 | [copy.deepcopy(layer) for _ in range(self.n_layers)] 215 | ) 216 | 217 | def forward( 218 | self, 219 | x, 220 | attn_mask=None, 221 | head_mask=None, 222 | output_attentions=False, 223 | output_hidden_states=False, 224 | return_dict=None, 225 | gauss_weight=None 226 | ): 227 | """ 228 | Parameters 229 | ---------- 230 | x: torch.tensor(bs, seq_length, dim) 231 | Input sequence embedded. 232 | attn_mask: torch.tensor(bs, seq_length) 233 | Attention mask on the sequence. 234 | 235 | Outputs 236 | ------- 237 | hidden_state: torch.tensor(bs, seq_length, dim) 238 | Sequence of hiddens states in the last (top) layer 239 | all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] 240 | Tuple of length n_layers with the hidden states from each layer. 
241 | Optional: only if output_hidden_states=True 242 | all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] 243 | Tuple of length n_layers with the attention weights from each layer 244 | Optional: only if output_attentions=True 245 | """ 246 | all_hidden_states = () if output_hidden_states else None 247 | all_attentions = () if output_attentions else None 248 | 249 | hidden_state = x 250 | for i, layer_module in enumerate(self.layer): 251 | if output_hidden_states: 252 | all_hidden_states = all_hidden_states + (hidden_state,) 253 | if head_mask is not None: 254 | layer_outputs = layer_module( 255 | x=hidden_state, 256 | attn_mask=attn_mask, 257 | head_mask=head_mask[i], 258 | output_attentions=output_attentions, 259 | gauss_weight=gauss_weight 260 | ) 261 | else: 262 | layer_outputs = layer_module( 263 | x=hidden_state, 264 | attn_mask=attn_mask, 265 | head_mask=None, 266 | output_attentions=output_attentions, 267 | gauss_weight=gauss_weight 268 | ) 269 | hidden_state = layer_outputs[-1] 270 | 271 | if output_attentions: 272 | assert len(layer_outputs) == 2 273 | attentions = layer_outputs[0] 274 | all_attentions = all_attentions + (attentions,) 275 | else: 276 | assert len(layer_outputs) == 1 277 | 278 | # Add last layer 279 | if output_hidden_states: 280 | all_hidden_states = all_hidden_states + (hidden_state,) 281 | 282 | if not return_dict: 283 | return tuple( 284 | v 285 | for v in [hidden_state, all_hidden_states, all_attentions] 286 | if v is not None 287 | ) 288 | return BaseModelOutput( 289 | last_hidden_state=hidden_state, 290 | hidden_states=all_hidden_states, 291 | attentions=all_attentions, 292 | ) 293 | 294 | class Embeddings(nn.Module): 295 | def __init__( 296 | self, d_model, language_len, vision_len, dropout, sinusoidal_pos_embds, d_pos=128 297 | ): 298 | super().__init__() 299 | max_position_embeddings = language_len + vision_len 300 | self.position_embeddings = nn.Embedding(max_position_embeddings, d_model) 301 | if sinusoidal_pos_embds: 302 | create_sinusoidal_embeddings( 303 | n_pos=max_position_embeddings, 304 | dim=d_model, 305 | out=self.position_embeddings.weight, 306 | ) 307 | # for name, param in self.position_embeddings.named_parameters(): 308 | # param.requires_grad = False 309 | self.language_len = language_len 310 | self.vision_len = vision_len 311 | # self.dropout = nn.Dropout(dropout) 312 | 313 | def forward(self, embeddings): 314 | seq_length = embeddings.size(1) 315 | position_ids = torch.arange(seq_length, dtype=torch.long, device=embeddings.device) # (max_seq_length) 316 | position_ids = position_ids.unsqueeze(0).expand_as(embeddings[:, :, 0]) # (bs, max_seq_length) 317 | position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) 318 | 319 | embeddings = embeddings + position_embeddings # (bs, max_seq_length, dim) 320 | # embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) 321 | 322 | return embeddings 323 | -------------------------------------------------------------------------------- /finetune_ans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from tqdm import tqdm 5 | import argparse 6 | from torch import cuda 7 | import time 8 | from utils import * 9 | import torch.distributed as dist 10 | from torch.cuda.amp import autocast as autocast 11 | import pickle 12 | import random 13 | import json 14 | from datasets.nextqa import NEXTQADataset 15 | from torch.backends import cudnn 16 | from 
utils.utils import * 17 | from utils.optims import * 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--experiment_path', type=str, default='experiments') 23 | parser.add_argument('--seed', type=int, default=3407) 24 | parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') 25 | 26 | parser.add_argument('--word_size', default=8, help="n_gpus") 27 | parser.add_argument('--bs', type=int, default=4) 28 | parser.add_argument('--eval_bs', type=int, default=4) 29 | parser.add_argument('--epoch', type=int, default=10) 30 | parser.add_argument('--lr', type=float, default=1e-5) 31 | parser.add_argument('--grounding_lr', type=float, default=7e-5) # 3e-5 32 | 33 | parser.add_argument('--use_schedule', type=bool, default=False) 34 | parser.add_argument('--warmup_start_lr', type=float, default=1e-8) 35 | parser.add_argument('--min_lr', type=float, default=5e-6, help='min_lr for consine annealing') 36 | parser.add_argument('--max_T', type=int, default=30, help='epoches for lr->min_lr / min_lr->lr') 37 | parser.add_argument('--eval_step', type=int, default=1, help="eval every 1/eval_step epoch") 38 | parser.add_argument('--save_ckpt', type=bool, default=False) 39 | 40 | parser.add_argument('--dataset', type=str, default='nextqa', choices=['nextqa']) 41 | parser.add_argument('--frame_count', type=int, default=32) 42 | parser.add_argument('--mode', type=str, default='grounding', choices=['grounding', 'uniform', 'oracle']) 43 | 44 | parser.add_argument('--window_size', type=int, default=4) 45 | parser.add_argument('--temperature', type=float, default=0.1) 46 | parser.add_argument('--width', type=float, default=0.2) 47 | parser.add_argument('--use_spatial', type=bool, default=True) 48 | parser.add_argument('--model', type=str, default='t5-xl', choices=['t5-xl']) 49 | parser.add_argument('--use_vit', type=bool, default=False) 50 | parser.add_argument('--use_lora', type=bool, default=False) 51 | 52 | args = parser.parse_args() 53 | return args 54 | 55 | def reduce_metric(metric): 56 | metric_tensor = torch.tensor(metric).cuda(args.local_rank) 57 | dist.all_reduce(metric_tensor, op=torch.distributed.ReduceOp.SUM) 58 | metric = metric_tensor.item() / dist.get_world_size() 59 | return metric 60 | 61 | def init_seeds(seed=42, cuda_deterministic=True): 62 | random.seed(seed) 63 | np.random.seed(seed) 64 | torch.manual_seed(seed) 65 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 66 | if cuda_deterministic: # slower, more reproducible 67 | cudnn.deterministic = True 68 | cudnn.benchmark = False 69 | else: # faster, less reproducible 70 | cudnn.deterministic = False 71 | cudnn.benchmark = True 72 | 73 | def prepare_inputs(args, data): 74 | 75 | video_ids = data["video_ids"] 76 | qids = data["qids"] 77 | types = data["types"] 78 | 79 | questions = data["questions"] 80 | answers = data["answers"] 81 | 82 | mc_prompt = "Considering the information presented in the frame, select the correct answer from the options." 
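# Note: the block below builds the multiple-choice prompt for each sample by
# concatenating the question with its five lettered options (A-E) and an
# "Answer: " cue; the ground-truth option letter is used as the target text.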
83 | 84 | if args.dataset == 'intentqa': 85 | options_a0 = data["options_a0"] 86 | options_a1 = data["options_a1"] 87 | options_a2 = data["options_a2"] 88 | options_a3 = data["options_a3"] 89 | options_a4 = data["options_a4"] 90 | text_input = ['Question: ' + question + f'\nOptions: \nA: {option_a0} \nB: {option_a1} \nC: {option_a2} \nD: {option_a3} \nE: {option_a4}' + '\nAnswer: ' for question, option_a0, option_a1, option_a2, option_a3, option_a4 in zip(questions, options_a0, options_a1, options_a2, options_a3, options_a4)] 91 | 92 | text_output = answers 93 | 94 | if args.dataset == 'nextqa': 95 | return text_input, text_output, questions, options_a0, options_a1, options_a2, options_a3, options_a4 96 | 97 | @torch.no_grad() 98 | def eval(args, val_loader, model): 99 | model.eval() 100 | 101 | val_loss = 0 102 | val_vqa_loss = 0 103 | val_reg_loss = 0 104 | val_info_loss = 0 105 | val_acc = 0 106 | overall_acc = 0 107 | 108 | acc_records = [] 109 | 110 | for step, data in enumerate(val_loader): 111 | 112 | if args.dataset == 'nextqa': 113 | text_input, text_output, questions, options_a0, options_a1, options_a2, options_a3, options_a4 = prepare_inputs(args, data) 114 | samples = { 115 | "text_input": text_input, 116 | "text_output": text_output, 117 | "questions": questions, 118 | "options_a0": options_a0, 119 | "options_a1": options_a1, 120 | "options_a2": options_a2, 121 | "options_a3": options_a3, 122 | "options_a4": options_a4, 123 | "frame_features": data["frame_features"].cuda(args.local_rank), 124 | "answers_text": data["answers_text"], 125 | "answers_id": data["answers_id"] 126 | } 127 | 128 | generate_kwargs = { 129 | "do_sample": True, 130 | "num_beams": 5, 131 | "min_length": 1, 132 | "num_return_sequences": 1, 133 | "max_new_tokens": 30, 134 | "temperature":1, 135 | "top_p":0.9, 136 | "repetition_penalty":1, 137 | "length_penalty":1 138 | } 139 | 140 | with torch.cuda.amp.autocast(enabled=True, dtype=model.module.dtype): # 前后开启autocast 141 | with torch.no_grad(): 142 | outputs = model(samples) 143 | pred_texts = model.module.generate(samples, **generate_kwargs) 144 | 145 | for i in range(args.eval_bs): 146 | qid = data['qids'][i] 147 | video_id = data['video_ids'][i] 148 | type = data['types'][i] 149 | input_text = text_input[i] 150 | label = text_output[i] 151 | pred = pred_texts[i] 152 | 153 | acc_records.append({ 154 | 'qid': qid, 155 | 'video_id': video_id, 156 | 'type': type, 157 | 'input': input_text, 158 | 'label': label, 159 | 'pred': pred 160 | }) 161 | 162 | loss = outputs['loss'] 163 | val_loss += loss.item() 164 | val_vqa_loss += outputs['vqa_loss'].item() 165 | val_reg_loss += outputs['regression_loss'].item() 166 | val_info_loss += outputs['infoNCE_loss'].item() 167 | val_acc += compute_acc(bs = args.eval_bs, labels = text_output, preds = pred_texts) 168 | 169 | if dist.get_rank() == 0 and step<=4: 170 | for i in range(len(text_input)): 171 | print() 172 | print("---------------------eval-------------------------") 173 | print("---------------------ids-------------------------") 174 | print("video_id: " + data["video_ids"][i] + " qid: " + data["qids"][i]) 175 | print("---------------------type-------------------------") 176 | print(data["types"][i]) 177 | print("---------------------input-------------------------") 178 | print(text_input[i]) 179 | print("---------------------preds-------------------------") 180 | print(pred_texts[i]) 181 | print("--------------------answers------------------------") 182 | print(text_output[i]) 183 | print() 184 | 185 | with 
open(f'files/{args.dataset}_records_{dist.get_rank()}.json', 'w') as f: 186 | json.dump(acc_records, f, indent=2) 187 | 188 | # 同步所有进程 189 | dist.barrier() 190 | 191 | for r in range(dist.get_world_size()): 192 | if dist.get_rank() == r: 193 | if len(os.listdir('files/')) >= dist.get_world_size(): 194 | if args.dataset == 'nextqa': 195 | overall_acc, class_acc = compute_acc_nextqa() 196 | if dist.get_rank() == 0: 197 | print('Overall Acc: ', overall_acc) 198 | print('Class Acc: ', class_acc) 199 | 200 | # 同步所有进程 201 | dist.barrier() 202 | if dist.get_rank() == 0: 203 | folder_path = 'files/' 204 | files = os.listdir(folder_path) 205 | for file in files: 206 | file_path = os.path.join(folder_path, file) 207 | if os.path.isfile(file_path): 208 | os.remove(file_path) 209 | 210 | # 对不同进程上的评价指标进行平均 211 | val_loss = round(reduce_metric(val_loss)/len(val_loader), 4) 212 | val_vqa_loss = round(reduce_metric(val_vqa_loss)/len(val_loader), 4) 213 | val_reg_loss = round(reduce_metric(val_reg_loss)/len(val_loader), 4) 214 | val_info_loss = round(reduce_metric(val_info_loss)/len(val_loader), 4) 215 | val_acc = round(reduce_metric(val_acc)/len(val_loader), 4) 216 | model.train() 217 | return val_loss, val_vqa_loss, val_reg_loss, val_info_loss, val_acc, overall_acc 218 | 219 | 220 | def train(args, train_dataset, val_dataset, model): 221 | 222 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 223 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, sampler=train_sampler, pin_memory=True, shuffle=False, drop_last=True, num_workers=4) 224 | 225 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 226 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.eval_bs, sampler=val_sampler, pin_memory=True, shuffle=False, drop_last=True, num_workers=4) 227 | 228 | 229 | if args.mode == 'grounding': 230 | ignored_params = list(map(id, model.module.grounding.parameters())) # 返回的是parameters的 内存地址 231 | base_params = filter(lambda p: p.requires_grad and id(p) not in ignored_params, model.parameters()) 232 | optimizer = torch.optim.AdamW([ 233 | {'params': base_params}, 234 | {'params': model.module.grounding.parameters(), 'lr': args.grounding_lr}], 235 | lr = args.lr, betas=(0.9, 0.999), weight_decay=0.02) 236 | else: 237 | optimizer = torch.optim.AdamW(filter(lambda p : p.requires_grad, model.parameters()), lr = args.lr, betas=(0.9, 0.999), weight_decay=0.02) 238 | lr_schedule = LinearWarmupCosineLRScheduler(optimizer, max_epoch=args.max_T, min_lr=args.min_lr, init_lr=args.lr, warmup_steps=int(len(train_loader)/4), warmup_start_lr=args.warmup_start_lr) 239 | 240 | max_acc = 0 241 | 242 | scaler = torch.cuda.amp.GradScaler() #训练前实例化一个GradScaler对象 243 | 244 | for epoch in range(args.epoch): 245 | 246 | model.train() 247 | # 设置sampler的epoch,DistributedSampler需要这个来维持各个进程之间的相同随机数种子 248 | train_loader.sampler.set_epoch(epoch) 249 | start = time.time() 250 | train_loss = 0 251 | train_vqa_loss = 0 252 | train_reg_loss = 0 253 | train_info_loss = 0 254 | train_acc = 0 255 | for step, data in enumerate(tqdm(train_loader, disable=not dist.get_rank() == 0)): 256 | model.train() 257 | if args.dataset == 'nextqa': 258 | text_input, text_output, questions, options_a0, options_a1, options_a2, options_a3, options_a4 = prepare_inputs(args, data) 259 | samples = { 260 | "text_input": text_input, 261 | "text_output": text_output, 262 | "questions": questions, 263 | "options_a0": options_a0, 264 | "options_a1": options_a1, 265 | 
"options_a2": options_a2, 266 | "options_a3": options_a3, 267 | "options_a4": options_a4, 268 | "frame_features": data["frame_features"].cuda(args.local_rank), 269 | "answers_text": data["answers_text"], 270 | "answers_id": data["answers_id"] 271 | } 272 | 273 | with torch.cuda.amp.autocast(enabled=True, dtype=model.module.dtype): # 前后开启autocast 274 | outputs = model(samples) 275 | with torch.no_grad(): 276 | # pred_texts = model.module.generate(samples, **generate_kwargs) 277 | pred_texts = ['N/A' for i in range(args.bs)] 278 | 279 | loss = outputs['loss'] 280 | train_loss += outputs['loss'].item() 281 | train_vqa_loss += outputs['vqa_loss'].item() 282 | train_reg_loss += outputs['regression_loss'].item() 283 | train_info_loss += outputs['infoNCE_loss'].item() 284 | train_acc += compute_acc(bs = args.bs, labels = text_output, preds = pred_texts) 285 | 286 | scaler.scale(loss).backward() #为了梯度放大 287 | scaler.step(optimizer) 288 | scaler.update() #准备着,看是否要增大scaler 289 | 290 | if args.use_schedule: 291 | lr_schedule.step(cur_epoch=epoch, cur_step=step) 292 | 293 | optimizer.zero_grad() 294 | 295 | if step % int(len(train_loader)/args.eval_step) == 0 and epoch > 0 and step >= int(len(train_loader)/args.eval_step) and step < len(train_loader)*0.9: 296 | val_loss, val_acc, overall_acc = eval(args, val_loader, model) 297 | if dist.get_rank() == 0: 298 | print('epoch:{}/{} step:{} val_loss:{} val_acc:{}' 299 | .format(epoch + 1, args.epoch, step, val_loss, val_acc)) 300 | if (overall_acc >= max_acc): 301 | max_acc = overall_acc 302 | if args.save_ckpt: 303 | torch.save(model.module.state_dict(), './{}/{}_{}_{}.pth'.format(args.experiment_path, f'{args.model}_{args.dataset}', epoch+1, overall_acc)) 304 | 305 | # 对不同进程上的评价指标进行平均 306 | train_loss = round(reduce_metric(train_loss)/len(train_loader), 4) 307 | train_vqa_loss = round(reduce_metric(train_vqa_loss)/len(train_loader), 4) 308 | train_reg_loss = round(reduce_metric(train_reg_loss)/len(train_loader), 4) 309 | train_info_loss = round(reduce_metric(train_info_loss)/len(train_loader), 4) 310 | train_acc = round(reduce_metric(train_acc)/len(train_loader), 4) 311 | val_loss, val_vqa_loss, val_reg_loss, val_info_loss, val_acc, overall_acc = eval(args, val_loader, model) 312 | 313 | end = time.time() 314 | if dist.get_rank() == 0: 315 | print('epoch:{}/{} time:{}h lr:{} batchsize:{} train_loss:{} val_loss:{} train_acc: {} val_acc:{}' 316 | .format(epoch + 1, args.epoch, str(round((end-start)/3600, 2)), args.lr, args.bs, train_loss, val_loss, train_acc, val_acc)) 317 | print('train_vqa_loss:{} train_reg_loss:{} train_info_loss: {} val_vqa_loss:{} val_reg_loss:{} val_info_loss: {}' 318 | .format(train_vqa_loss, train_reg_loss, train_info_loss, val_vqa_loss, val_reg_loss, val_info_loss)) 319 | if (overall_acc >= max_acc): 320 | max_acc = overall_acc 321 | if args.save_ckpt: 322 | torch.save(model.module.state_dict(), './{}/{}_{}_{}.pth'.format(args.experiment_path, f'{args.model}_{args.dataset}', epoch+1, overall_acc)) 323 | 324 | dist.destroy_process_group() 325 | 326 | if __name__ == '__main__': 327 | args = parse_args() 328 | 329 | if args.dataset == 'nextqa': 330 | train_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/train.csv', frame_count=args.frame_count) 331 | val_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/val.csv', frame_count=args.frame_count) 332 | test_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/test.csv', frame_count=args.frame_count) 333 | 334 | 335 | from models.blip2_t5_instruct import 
Blip2T5Instruct 336 | 337 | if 't5' in args.model: 338 | model = Blip2T5Instruct( 339 | dtype=torch.bfloat16, 340 | frame_num=args.frame_count, 341 | mode = args.mode, 342 | window_size = args.window_size, 343 | use_spatial = args.use_spatial, 344 | model = args.model, 345 | temperature = args.temperature, 346 | width = args.width, 347 | use_vit = args.use_vit, 348 | use_lora = args.use_lora 349 | ) 350 | 351 | 352 | device = torch.device('cuda', args.local_rank) 353 | dist.init_process_group(backend='nccl',rank=args.local_rank, world_size=args.word_size) 354 | init_seeds(args.seed + torch.distributed.get_rank()) 355 | torch.cuda.set_device(device) 356 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(args.local_rank), 357 | device_ids=[args.local_rank], 358 | output_device=args.local_rank, 359 | ) 360 | 361 | if dist.get_rank() == 0: 362 | print(get_parameter_number(model)) 363 | print("trian_num: ", len(train_dataset), " val_num: ", len(val_dataset), " test_num: ", len(test_dataset)) 364 | print(args) 365 | 366 | train(args, train_dataset, test_dataset, model) 367 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, StoppingCriteria 3 | import random 4 | import re 5 | import os 6 | import json 7 | import requests 8 | from PIL import Image 9 | from io import BytesIO 10 | import json 11 | import pickle 12 | import pandas as pd 13 | from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize, CenterCrop 14 | from typing import Optional, Tuple, Any, Union, List 15 | 16 | def _convert_to_rgb(image): 17 | return image.convert('RGB') 18 | 19 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 20 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 21 | 22 | def image_transform( 23 | image_size: int, 24 | mean: Optional[Tuple[float, ...]] = None, 25 | std: Optional[Tuple[float, ...]] = None, 26 | ): 27 | mean = mean or OPENAI_DATASET_MEAN 28 | if not isinstance(mean, (list, tuple)): 29 | mean = (mean,) * 3 30 | 31 | std = std or OPENAI_DATASET_STD 32 | if not isinstance(std, (list, tuple)): 33 | std = (std,) * 3 34 | 35 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 36 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 37 | image_size = image_size[0] 38 | 39 | normalize = Normalize(mean=mean, std=std) 40 | 41 | transforms = [ 42 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 43 | CenterCrop(image_size), 44 | ] 45 | transforms.extend([ 46 | _convert_to_rgb, 47 | ToTensor(), 48 | normalize, 49 | ]) 50 | return Compose(transforms) 51 | 52 | def get_frames(lst, M): 53 | 54 | frame_num = len(lst) 55 | 56 | if frame_num == 32: 57 | if M==16: 58 | result = [lst[0], lst[2], lst[4], lst[6], lst[8], lst[10], lst[12], lst[14], lst[16], lst[18], lst[20], lst[22], lst[24], lst[26], lst[28], lst[30]] 59 | elif M==8: 60 | result = [lst[2], lst[6], lst[10], lst[14], lst[18], lst[22], lst[26], lst[30]] 61 | elif M==4: 62 | result = [lst[4], lst[12], lst[20], lst[28]] 63 | elif M==1: 64 | result = [lst[16]] 65 | else: 66 | result = lst 67 | 68 | elif frame_num == 16: 69 | if M==8: 70 | result = [lst[0], lst[2], lst[4], lst[6], lst[8], lst[10], lst[12], lst[15]] 71 | elif M==4: 72 | result = [lst[2], lst[6], lst[10], lst[14]] 73 | elif M==1: 74 | result = [lst[7]] 75 | else: 76 | result 
= lst 77 | 78 | return result 79 | 80 | def load_image(image_file): 81 | if image_file.startswith('http') or image_file.startswith('https'): 82 | response = requests.get(image_file) 83 | image = Image.open(BytesIO(response.content)).convert('RGB') 84 | else: 85 | image = Image.open(image_file).convert('RGB') 86 | return image 87 | 88 | def save_json(file, path): 89 | with open(path, 'w') as f: 90 | json.dump(file, f, indent=2) 91 | 92 | def load_json(path): 93 | with open(path) as f: 94 | data = json.load(f) 95 | return data 96 | 97 | def load_jsonl(path): 98 | data = [] 99 | with open(path, 'r') as file: 100 | for line in file: 101 | json_object = json.loads(line) 102 | data.append(json_object) 103 | return data 104 | 105 | def load_csv(path): 106 | file_list = [] 107 | data = pd.read_csv(path) 108 | columns = data.columns.tolist() 109 | for index, row in data.iterrows(): 110 | file_list.append({}) 111 | for column in columns: 112 | file_list[index][column] = row[column] 113 | return file_list 114 | 115 | def load_pkl(path): 116 | with open(path, 'rb') as f: 117 | data = pickle.load(f) 118 | return data 119 | 120 | def get_parameter_number(model): 121 | total_num = sum(p.numel() for p in model.parameters()) 122 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 123 | return {'Total': total_num, 'Trainable': trainable_num} 124 | 125 | def compute_acc(bs, labels, preds): 126 | acc = 0 127 | for i in range(bs): 128 | label = labels[i] 129 | pred = preds[i] 130 | 131 | if pred.lower() == label.lower(): 132 | acc += 1 133 | return acc/bs 134 | 135 | def compute_acc_nextqa(): 136 | folder_path = 'files/' 137 | target_prefix = 'nextqa_records_' 138 | # 初始化一个空的列表来存储合并后的数据 139 | merged_data = [] 140 | # 遍历目标文件夹中的所有文件 141 | for filename in os.listdir(folder_path): 142 | if filename.startswith(target_prefix) and filename.endswith('.json'): 143 | file_path = os.path.join(folder_path, filename) 144 | with open(file_path, 'r') as file: 145 | data = json.load(file) 146 | merged_data += data 147 | 148 | total_samples = len(merged_data) 149 | correct_predictions = 0 150 | class_counts = {'C': 0, 'T': 0, 'D': 0} 151 | class_correct = {'C': 0, 'T': 0, 'D': 0} 152 | 153 | for sample in merged_data: 154 | if sample['pred'] == sample['label']: 155 | correct_predictions += 1 156 | class_correct[sample['type'][0]] += 1 157 | class_counts[sample['type'][0]] += 1 158 | 159 | overall_accuracy = correct_predictions / total_samples 160 | class_accuracies = {cls: class_correct[cls] / class_counts[cls] for cls in class_counts} 161 | 162 | return overall_accuracy, class_accuracies 163 | 164 | def compute_acc_intentqa(): 165 | folder_path = 'files/' 166 | target_prefix = 'intentqa_records_' 167 | # 初始化一个空的列表来存储合并后的数据 168 | merged_data = [] 169 | # 遍历目标文件夹中的所有文件 170 | for filename in os.listdir(folder_path): 171 | if filename.startswith(target_prefix) and filename.endswith('.json'): 172 | file_path = os.path.join(folder_path, filename) 173 | with open(file_path, 'r') as file: 174 | data = json.load(file) 175 | merged_data += data 176 | 177 | total_samples = len(merged_data) 178 | correct_predictions = 0 179 | class_counts = {'CW': 0, 'CH': 0, 'TN&TP': 0} 180 | class_correct = {'CW': 0, 'CH': 0, 'TN&TP': 0} 181 | 182 | for sample in merged_data: 183 | if sample['pred'] == sample['label']: 184 | correct_predictions += 1 185 | if sample['type'] == 'CW': 186 | class_correct['CW'] += 1 187 | elif sample['type'] == 'CH': 188 | class_correct['CH'] += 1 189 | elif 'T' in sample['type']: 190 | 
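                # any temporal type (a type string containing 'T') is pooled into the single 'TN&TP' bucket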
class_correct['TN&TP'] += 1 191 | if sample['type'] == 'CW': 192 | class_counts['CW'] += 1 193 | elif sample['type'] == 'CH': 194 | class_counts['CH'] += 1 195 | elif 'T' in sample['type']: 196 | class_counts['TN&TP'] += 1 197 | 198 | overall_accuracy = correct_predictions / total_samples 199 | class_accuracies = {cls: class_correct[cls] / class_counts[cls] for cls in class_counts} 200 | 201 | return overall_accuracy, class_accuracies 202 | 203 | def compute_acc_starqa(): 204 | folder_path = 'files/' 205 | target_prefix = 'starqa_records_' 206 | # 初始化一个空的列表来存储合并后的数据 207 | merged_data = [] 208 | # 遍历目标文件夹中的所有文件 209 | for filename in os.listdir(folder_path): 210 | if filename.startswith(target_prefix) and filename.endswith('.json'): 211 | file_path = os.path.join(folder_path, filename) 212 | with open(file_path, 'r') as file: 213 | data = json.load(file) 214 | merged_data += data 215 | 216 | total_samples = len(merged_data) 217 | correct_predictions = 0 218 | class_counts = {'Int': 0, 'Fea': 0, 'Pre': 0, 'Seq': 0} 219 | class_correct = {'Int': 0, 'Fea': 0, 'Pre': 0, 'Seq': 0} 220 | 221 | for sample in merged_data: 222 | if sample['pred'] == sample['label']: 223 | correct_predictions += 1 224 | class_correct[sample['type']] += 1 225 | class_counts[sample['type']] += 1 226 | 227 | class_accuracies = {cls: class_correct[cls] / class_counts[cls] for cls in class_counts} 228 | overall_accuracy = (class_accuracies['Int'] + class_accuracies['Fea'] + class_accuracies['Pre'] + class_accuracies['Seq']) / 4 229 | 230 | return overall_accuracy, class_accuracies 231 | 232 | def compute_acc_trafficqa(): 233 | folder_path = 'files/' 234 | target_prefix = 'trafficqa_records_' 235 | # 初始化一个空的列表来存储合并后的数据 236 | merged_data = [] 237 | # 遍历目标文件夹中的所有文件 238 | for filename in os.listdir(folder_path): 239 | if filename.startswith(target_prefix) and filename.endswith('.json'): 240 | file_path = os.path.join(folder_path, filename) 241 | with open(file_path, 'r') as file: 242 | data = json.load(file) 243 | merged_data += data 244 | 245 | total_samples = len(merged_data) 246 | correct_predictions = 0 247 | class_counts = {'U': 0, 'A': 0, 'F': 0, 'R': 0, 'C': 0, 'I': 0} 248 | class_correct = {'U': 0, 'A': 0, 'F': 0, 'R': 0, 'C': 0, 'I': 0} 249 | 250 | for sample in merged_data: 251 | if sample['pred'] == sample['label']: 252 | correct_predictions += 1 253 | class_correct[sample['type']] += 1 254 | class_counts[sample['type']] += 1 255 | 256 | class_accuracies = {cls: class_correct[cls] / class_counts[cls] for cls in class_counts} 257 | overall_accuracy = correct_predictions / total_samples 258 | 259 | return overall_accuracy, class_accuracies 260 | 261 | def compute_acc_vlep(): 262 | folder_path = 'files/' 263 | target_prefix = 'vlep_records_' 264 | # 初始化一个空的列表来存储合并后的数据 265 | merged_data = [] 266 | # 遍历目标文件夹中的所有文件 267 | for filename in os.listdir(folder_path): 268 | if filename.startswith(target_prefix) and filename.endswith('.json'): 269 | file_path = os.path.join(folder_path, filename) 270 | with open(file_path, 'r') as file: 271 | data = json.load(file) 272 | merged_data += data 273 | 274 | total_samples = len(merged_data) 275 | correct_predictions = 0 276 | class_counts = {'N/A.': 0} 277 | class_correct = {'N/A.': 0} 278 | 279 | for sample in merged_data: 280 | if sample['pred'] == sample['label']: 281 | correct_predictions += 1 282 | class_correct[sample['type']] += 1 283 | class_counts[sample['type']] += 1 284 | 285 | class_accuracies = {cls: class_correct[cls] / class_counts[cls] for cls in class_counts} 286 | 
overall_accuracy = correct_predictions / total_samples 287 | 288 | return overall_accuracy, class_accuracies 289 | 290 | def compute_acc_msrvttqa(): 291 | folder_path = 'files/' 292 | target_prefix = 'msrvttqa_records_' 293 | # 初始化一个空的列表来存储合并后的数据 294 | merged_data = [] 295 | # 遍历目标文件夹中的所有文件 296 | for filename in os.listdir(folder_path): 297 | if filename.startswith(target_prefix) and filename.endswith('.json'): 298 | file_path = os.path.join(folder_path, filename) 299 | with open(file_path, 'r') as file: 300 | data = json.load(file) 301 | merged_data += data 302 | 303 | total_samples = len(merged_data) 304 | correct_predictions = 0 305 | for sample in merged_data: 306 | if sample['pred'] == sample['label']: 307 | correct_predictions += 1 308 | overall_accuracy = correct_predictions / total_samples 309 | 310 | return overall_accuracy 311 | 312 | def compute_acc_msvdqa(): 313 | folder_path = 'files/' 314 | target_prefix = 'msvdqa_records_' 315 | # 初始化一个空的列表来存储合并后的数据 316 | merged_data = [] 317 | # 遍历目标文件夹中的所有文件 318 | for filename in os.listdir(folder_path): 319 | if filename.startswith(target_prefix) and filename.endswith('.json'): 320 | file_path = os.path.join(folder_path, filename) 321 | with open(file_path, 'r') as file: 322 | data = json.load(file) 323 | merged_data += data 324 | 325 | total_samples = len(merged_data) 326 | correct_predictions = 0 327 | for sample in merged_data: 328 | if sample['pred'] == sample['label']: 329 | correct_predictions += 1 330 | overall_accuracy = correct_predictions / total_samples 331 | 332 | return overall_accuracy 333 | 334 | def compute_acc_activityqa(): 335 | folder_path = 'files/' 336 | target_prefix = 'activityqa_records_' 337 | # 初始化一个空的列表来存储合并后的数据 338 | merged_data = [] 339 | # 遍历目标文件夹中的所有文件 340 | for filename in os.listdir(folder_path): 341 | if filename.startswith(target_prefix) and filename.endswith('.json'): 342 | file_path = os.path.join(folder_path, filename) 343 | with open(file_path, 'r') as file: 344 | data = json.load(file) 345 | merged_data += data 346 | 347 | total_samples = len(merged_data) 348 | correct_predictions = 0 349 | for sample in merged_data: 350 | if sample['pred'] == sample['label']: 351 | correct_predictions += 1 352 | overall_accuracy = correct_predictions / total_samples 353 | 354 | return overall_accuracy 355 | 356 | def compute_acc_causalqa(): 357 | folder_path = 'files/' 358 | target_prefix = 'causalqa_records_' 359 | # 初始化一个空的列表来存储合并后的数据 360 | merged_data = [] 361 | # 遍历目标文件夹中的所有文件 362 | for filename in os.listdir(folder_path): 363 | if filename.startswith(target_prefix) and filename.endswith('.json'): 364 | file_path = os.path.join(folder_path, filename) 365 | with open(file_path, 'r') as file: 366 | data = json.load(file) 367 | merged_data += data 368 | 369 | merged_predictions = {} 370 | for sample in merged_data: 371 | video_id = sample['video_id'] 372 | pred_type = sample['type'] 373 | 374 | if video_id not in merged_predictions: 375 | merged_predictions[video_id] = { 376 | 'video_id': video_id, 377 | 'descriptive': {}, 378 | 'explanatory': {}, 379 | 'predictive_answer': {}, 380 | 'predictive_reason': {}, 381 | 'counterfactual_answer': {}, 382 | 'counterfactual_reason': {} 383 | } 384 | 385 | merged_predictions[video_id][pred_type]['input'] = sample['input'] 386 | merged_predictions[video_id][pred_type]['label'] = sample['label'] 387 | merged_predictions[video_id][pred_type]['pred'] = sample['pred'] 388 | 389 | all_num = 0 390 | acc_descriptive = 0 391 | acc_explanatory = 0 392 | acc_predictive_answer = 0 393 | 
acc_predictive_reason = 0 394 | acc_counterfactual_answer = 0 395 | acc_counterfactual_reason = 0 396 | acc_predictive = 0 397 | acc_counterfactual = 0 398 | 399 | for key in merged_predictions.keys(): 400 | if (merged_predictions[key]['descriptive'] != {}) and (merged_predictions[key]['explanatory'] != {}) and (merged_predictions[key]['predictive_answer'] != {}) and (merged_predictions[key]['predictive_reason'] != {}) and (merged_predictions[key]['counterfactual_answer'] != {}) and (merged_predictions[key]['counterfactual_reason'] != {}): 401 | all_num += 1 402 | predictive_answer = False 403 | predictive_reason = False 404 | counterfactual_answer = False 405 | counterfactual_reason = False 406 | if merged_predictions[key]['descriptive']['pred'] == merged_predictions[key]['descriptive']['label']: 407 | acc_descriptive += 1 408 | if merged_predictions[key]['explanatory']['pred'] == merged_predictions[key]['explanatory']['label']: 409 | acc_explanatory += 1 410 | if merged_predictions[key]['predictive_answer']['pred'] == merged_predictions[key]['predictive_answer']['label']: 411 | acc_predictive_answer += 1 412 | predictive_answer = True 413 | if merged_predictions[key]['predictive_reason']['pred'] == merged_predictions[key]['predictive_reason']['label']: 414 | acc_predictive_reason += 1 415 | predictive_reason = True 416 | if merged_predictions[key]['counterfactual_answer']['pred'] == merged_predictions[key]['counterfactual_answer']['label']: 417 | acc_counterfactual_answer += 1 418 | counterfactual_answer = True 419 | if merged_predictions[key]['counterfactual_reason']['pred'] == merged_predictions[key]['counterfactual_reason']['label']: 420 | acc_counterfactual_reason += 1 421 | counterfactual_reason = True 422 | if predictive_answer and predictive_reason: 423 | acc_predictive += 1 424 | if counterfactual_answer and counterfactual_reason: 425 | acc_counterfactual += 1 426 | 427 | class_accuracies = { 428 | 'D': acc_descriptive/all_num, 429 | 'E': acc_explanatory/all_num, 430 | 'PA': acc_predictive_answer/all_num, 431 | 'PR': acc_predictive_reason/all_num, 432 | 'PAR': acc_predictive/all_num, 433 | 'CA': acc_counterfactual_answer/all_num, 434 | 'CR': acc_counterfactual_reason/all_num, 435 | 'CAR': acc_counterfactual/all_num 436 | } 437 | overall_acc = (class_accuracies['D'] + class_accuracies['E'] + class_accuracies['PAR'] + class_accuracies['CAR'])/4 438 | return overall_acc, class_accuracies 439 | -------------------------------------------------------------------------------- /models/grounding_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import rearrange, repeat 6 | import sys 7 | import os 8 | import contextlib 9 | import torch 10 | import torch.nn as nn 11 | from torch.cuda.amp import autocast as autocast 12 | import sys 13 | import math 14 | import os 15 | from transformers.activations import gelu 16 | from transformers.modeling_outputs import BaseModelOutput 17 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 18 | from utils.utils import * 19 | import torch.nn.functional as F 20 | from models.Transformer import Transformer, Embeddings 21 | import torch.distributed as dist 22 | 23 | class TokenTypeEmbeddings(nn.Module): 24 | def __init__( 25 | self, d_model, token_type_num 26 | ): 27 | super().__init__() 28 | self.modality_embedding = nn.Embedding(token_type_num, d_model) 29 | self.type2id = 
{'question': 0, 'video': 1, 'object': 2} 30 | nn.init.zeros_(self.modality_embedding.weight) 31 | 32 | def forward(self, embeddings, token_type): 33 | seq_length = embeddings.size(1) 34 | token_type_id = self.type2id[token_type] 35 | modality_embeddings = self.modality_embedding(torch.tensor([token_type_id] * seq_length, dtype=torch.long).to(embeddings.device)) 36 | return modality_embeddings 37 | 38 | class PerturbedTopKFunction(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, x, k: int, num_samples: int = 1000, sigma: float = 0.05): 41 | b, d = x.shape 42 | # for Gaussian: noise and gradient are the same. 43 | noise = torch.normal(mean=0.0, std=1.0, size=(b, num_samples, d)).to(x.device) 44 | perturbed_x = x[:, None, :] + noise * sigma # b, nS, d 45 | topk_results = torch.topk(perturbed_x, k=k, dim=-1, sorted=False) 46 | indices = topk_results.indices # b, nS, k 47 | indices = torch.sort(indices, dim=-1).values # b, nS, k 48 | 49 | perturbed_output = torch.nn.functional.one_hot(indices, num_classes=d).float() 50 | indicators = perturbed_output.mean(dim=1) # b, k, d 51 | 52 | # constants for backward 53 | ctx.k = k 54 | ctx.num_samples = num_samples 55 | ctx.sigma = sigma 56 | 57 | # tensors for backward 58 | ctx.perturbed_output = perturbed_output 59 | ctx.noise = noise 60 | return indicators 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | if grad_output is None: 65 | return tuple([None] * 5) 66 | 67 | noise_gradient = ctx.noise 68 | if ctx.sigma <= 1e-20: 69 | b, _, k, d = ctx.perturbed_output.size() 70 | expected_gradient = torch.zeros(b, k, d).to(grad_output.device) 71 | else: 72 | expected_gradient = ( 73 | torch.einsum("bnkd,bnd->bkd", ctx.perturbed_output, noise_gradient) 74 | / ctx.num_samples 75 | / (ctx.sigma) 76 | ) 77 | 78 | grad_input = torch.einsum("bkd,bkd->bd", grad_output, expected_gradient) 79 | 80 | return (grad_input,) + tuple([None] * 5) 81 | 82 | class Grounding(nn.Module): 83 | def __init__(self, dim=1024, heads=4, dropout=0.3, window_size=4, frame_num=32, width=0.15, temperature=0.1): 84 | super(Grounding, self).__init__() 85 | 86 | self.dtype = torch.float32 87 | self.frame_num = frame_num 88 | self.window_size = window_size 89 | self.width = width 90 | self.temperature = temperature 91 | self.sigma = 9 92 | self.mult = 4 93 | self.num_hidden_layers = 2 94 | self.inner_dim = dim // 4 95 | self.heads = heads 96 | self.activation = 'gelu' 97 | self.use_proj = True 98 | 99 | self.embedding = nn.Sequential( 100 | nn.Dropout(dropout), 101 | nn.Linear(dim, self.inner_dim, bias=False), 102 | ) 103 | self.modality_embedding = TokenTypeEmbeddings(d_model=self.inner_dim, token_type_num=2) 104 | self.position_v = Embeddings(self.inner_dim, 0, self.frame_num, dropout, True) 105 | 106 | self.encoder = Transformer( 107 | AutoConfig.from_pretrained( 108 | "FacebookAI/roberta-base", 109 | num_hidden_layers=self.num_hidden_layers, 110 | hidden_size=self.inner_dim, 111 | attention_probs_dropout_prob=dropout, 112 | intermediate_size=self.mult*self.inner_dim, 113 | num_attention_heads=self.heads, 114 | hidden_act = self.activation, 115 | )) 116 | self.satt_pool_frame = nn.Sequential( 117 | nn.Linear(self.inner_dim, self.inner_dim // 2), 118 | nn.Tanh(), 119 | nn.Linear(self.inner_dim // 2, 1), 120 | nn.Softmax(dim=-2)) 121 | self.logit_gauss_c = nn.Sequential( 122 | nn.Dropout(dropout), 123 | nn.Linear(self.inner_dim, self.window_size), 124 | nn.Sigmoid(), 125 | ) 126 | if self.use_proj: 127 | self.proj_head = nn.Sequential( 128 | 
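                # optional projection head; it is only applied inside calculate_contrastive_loss
                # to map features into a smaller space before the InfoNCE logits are computed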
nn.Dropout(dropout), 129 | nn.Linear(dim, self.inner_dim // 2, bias=False), 130 | nn.ReLU() if self.activation == 'relu' else nn.GELU(), 131 | nn.Linear(self.inner_dim // 2, self.inner_dim // 2) 132 | ) 133 | 134 | @property 135 | def device(self): 136 | return list(self.parameters())[0].device 137 | 138 | def maybe_autocast(self): 139 | # if on cpu, don't use autocast 140 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 141 | enable_autocast = self.device != torch.device("cpu") 142 | if enable_autocast: 143 | return torch.cuda.amp.autocast(dtype=self.dtype) 144 | else: 145 | return contextlib.nullcontext() 146 | 147 | def HardTopK(self, k, x): 148 | topk_results = torch.topk(x, k=k, dim=-1, sorted=False) 149 | indices = topk_results.indices # b, k 150 | indices = torch.sort(indices, dim=-1).values 151 | return indices 152 | 153 | def get_indicator(self, scores, k, sigma=0.05): 154 | indicator = PerturbedTopKFunction.apply(scores, k, 500, sigma) 155 | return indicator 156 | 157 | def calculate_kl_divergence(self, pred_scores, label_scores): 158 | """ 159 | scores: [bs, option_num] 160 | labels: [bs] 161 | """ 162 | pred_scores = F.log_softmax(pred_scores, dim=-1) 163 | label_scores = F.softmax(label_scores, dim=-1) 164 | kl_divergence = F.kl_div(pred_scores, label_scores, reduction='batchmean') 165 | return kl_divergence 166 | 167 | def mmt_encode(self, video_embeds, text_embeds, gaussian_weight=None): 168 | """ 169 | video_embeds: [bs, frame_num, dim] 170 | text_embeds: [bs, seq, dim] 171 | gaussian_weight: [bs, frame_num] / None 172 | """ 173 | bs = video_embeds.shape[0] 174 | video_embeds = self.position_v(self.embedding(video_embeds)) 175 | text_embeds = self.embedding(text_embeds) 176 | 177 | input_embeds = torch.cat([video_embeds, text_embeds], dim=1) 178 | input_embeds[:,:self.frame_num,:] = input_embeds[:,:self.frame_num,:] + self.modality_embedding(input_embeds[:,:self.frame_num,:], "video") 179 | input_embeds[:,self.frame_num:,:] = input_embeds[:,self.frame_num:,:] + self.modality_embedding(input_embeds[:,self.frame_num:,:], "question") 180 | hidden_states = self.encoder(x=input_embeds, 181 | attn_mask=torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(self.device), 182 | gauss_weight=gaussian_weight)[0] # [bs, frame_num+seq_len, dim] 183 | hidden_states = hidden_states[:,:self.frame_num,:] # [bs, frame_num, dim] 184 | hidden_states = self.position_v(hidden_states) 185 | fatt_gs = self.satt_pool_frame(hidden_states) # [bs, frame_num, 1] 186 | pooled_qv_feat = torch.sum(hidden_states*fatt_gs, dim=1) # [bs, dim] 187 | 188 | gauss_c = self.logit_gauss_c(pooled_qv_feat) # [bs, window_size] 189 | gauss_w = torch.full((gauss_c.shape[0], self.window_size), self.width).to(self.device) # [bs, window_size] 190 | pred_gaussians = self.generate_gmm_weight(gauss_c, gauss_w) # [bs, frame_num] 191 | 192 | return pred_gaussians, gauss_c, gauss_w 193 | 194 | def generate_gauss_weight(self, center, width): 195 | # code copied from https://github.com/minghangz/cpl 196 | weight = torch.linspace(0, 1, self.frame_num) 197 | weight = weight.view(1, -1).expand(center.size(0), -1).to(center.device) 198 | center = center.unsqueeze(-1) 199 | width = width.unsqueeze(-1).clamp(1e-2) / self.sigma 200 | w = 0.3989422804014327 #1/(math.sqrt(2*math.pi)) 201 | weight = w/width*torch.exp(-(weight-center)**2/(2*width**2)) 202 | return weight 203 | # return weight/weight.max(dim=-1, keepdim=True)[0] 204 | 205 | def generate_gmm_weight(self, centers, widths): 206 | """ 207 | 
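        Sum one Gaussian per predicted center (widths: [bs, window_size]) over the frame axis,
        then max-normalize per sample to obtain a frame-level weight of shape [bs, frame_num].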
centers: [bs, window_size] 208 | """ 209 | centers = rearrange(centers, "b w -> (b w)") # [bs*window_size] 210 | widths = rearrange(widths, "b w -> (b w)") # [bs*window_size] 211 | gaussians = self.generate_gauss_weight(centers, widths) 212 | gaussians = rearrange(gaussians, "(b w) t -> b w t", w=self.window_size) # [bs, window_size, frame_num] 213 | gaussians = torch.sum(gaussians, dim=1) # [bs, frame_num] 214 | gaussians = gaussians/gaussians.max(dim=-1, keepdim=True)[0] # [bs, frame_num] 215 | return gaussians 216 | 217 | def calculate_ce_loss(self, scores, labels): 218 | """ 219 | scores: [bs, option_num] 220 | labels: [bs] 221 | """ 222 | loss = nn.CrossEntropyLoss()(scores, labels) 223 | preds = torch.argmax(scores, dim=-1) 224 | return loss, preds 225 | 226 | def calculate_regression_loss(self, pred_center, label_probs): 227 | """ 228 | pred_center: [bs, window_size] 229 | label_probs: [bs, frame_num] 230 | """ 231 | pred_center = pred_center*(self.frame_num-1) 232 | positive_center = self.HardTopK(self.window_size, label_probs) 233 | zero = torch.tensor(0.0).to(self.device) 234 | one = torch.tensor(1.0).to(self.device) 235 | 236 | oreder_loss = torch.tensor(0.0).to(self.device) 237 | for i in range(pred_center.shape[0]): 238 | temp_loss = torch.tensor(0.0).to(self.device) 239 | for j in range(self.window_size-1): 240 | temp_loss += torch.max(one + pred_center[i,j] - pred_center[i,j+1], zero) 241 | oreder_loss += temp_loss/(self.window_size-1) 242 | oreder_loss = oreder_loss/(pred_center.shape[0]) 243 | loss_bbox = F.smooth_l1_loss(pred_center, positive_center) 244 | return loss_bbox + oreder_loss 245 | 246 | def calculate_contrastive_loss(self, video_embeds, options_embeds, answer_embeds, pred_gaussians, answers_id): 247 | """ 248 | video_embeds: [bs, frame_num, dim] 249 | options_embeds: [bs, 5, dim] 250 | answer_embeds: [bs, dim] 251 | label_probs: [bs, frame_num] 252 | pred_gaussians: [bs, frame_num] 253 | answers_id: [bs] 254 | """ 255 | bs = video_embeds.shape[0] 256 | intra_neg_num = self.frame_num - self.window_size 257 | inter_neg_num_per_video = self.window_size 258 | inter_neg_num = (bs-1)*inter_neg_num_per_video 259 | 260 | # pred_video_embeds 261 | pred_video_embeds = torch.einsum("b k t, b t d -> b k d", self.get_indicator(pred_gaussians, self.window_size), video_embeds) # [bs, window_size, dim] 262 | # intra_negative_video_embeds 263 | negative_gaussians = torch.tensor(1).to(self.device)-pred_gaussians 264 | intra_negative_video_embeds = torch.einsum("b k t, b t d -> b k d", self.get_indicator(negative_gaussians, intra_neg_num), video_embeds) # [bs, intra_neg_num, dim] 265 | # inter_negative_video_embeds 266 | def del_element(index, x): 267 | return torch.cat((x[:index], x[index+1:])) 268 | shuffle_video_embeds = video_embeds[:,torch.randperm(self.frame_num),:][:,:inter_neg_num_per_video,:] # [bs, inter_neg_num_per_video, dim] 打乱第二个维度 269 | inter_negative_video_embeds = [] 270 | for i in range(bs): 271 | temp = del_element(i, shuffle_video_embeds) # [bs-1, inter_neg_num_per_video, dim] 272 | temp = rearrange(temp, "b w o -> (b w) o") # [(bs-1)*inter_neg_num_per_video, dim] 273 | inter_negative_video_embeds.append(temp) 274 | inter_negative_video_embeds = torch.stack(inter_negative_video_embeds, dim=0) # [bs, (bs-1)*inter_neg_num_per_video, dim] 275 | # inter_negative_answer_embeds 276 | inter_negative_answer_embeds = [] 277 | for i in range(bs): 278 | temp = del_element(i, answer_embeds) # [bs-1, dim] 279 | inter_negative_answer_embeds.append(temp) 280 | 
inter_negative_answer_embeds = torch.stack(inter_negative_answer_embeds, dim=0) # [bs, (bs-1), dim] 281 | # options_negative_embeds 282 | if options_embeds != None: 283 | options_negative_embeds = [] 284 | for i in range(bs): 285 | temp = [] 286 | for j in range(options_embeds.shape[1]): 287 | if j != answers_id[i]: 288 | temp.append(options_embeds[i,j]) 289 | temp = torch.stack(temp, dim=0).to(self.device) # [4, dim] 290 | options_negative_embeds.append(temp) 291 | options_negative_embeds = torch.stack(options_negative_embeds, dim=0) # [bs, 4, dim] 292 | 293 | # proj_head 294 | if self.use_proj: 295 | pred_video_embeds = self.proj_head(pred_video_embeds) 296 | intra_negative_video_embeds = self.proj_head(intra_negative_video_embeds) 297 | inter_negative_video_embeds = self.proj_head(inter_negative_video_embeds) 298 | answer_embeds = self.proj_head(answer_embeds) 299 | inter_negative_answer_embeds = self.proj_head(inter_negative_answer_embeds) 300 | if options_embeds != None: 301 | options_negative_embeds = self.proj_head(options_negative_embeds) 302 | 303 | def l2_norm(x): 304 | return x/x.norm(dim=-1, keepdim=True) 305 | # compute positive_logits 306 | positive_logits = torch.einsum("b w d, b d -> b w", l2_norm(pred_video_embeds), l2_norm(answer_embeds)) # [bs, window_size] 307 | # compute intra_negative_video_logits 308 | intra_negative_video_logits = torch.einsum("b w d, b d -> b w", l2_norm(intra_negative_video_embeds), l2_norm(answer_embeds)) # [bs, intra_neg_num] 309 | # compute inter_negative_video_logits 310 | inter_negative_video_logits = torch.einsum("b w d, b d -> b w", l2_norm(inter_negative_video_embeds), l2_norm(answer_embeds)) # [bs, inter_neg_num] 311 | # compute inter_negative_answer_logits 312 | inter_negative_answer_logits = torch.einsum("b w d, b o d -> b w o", l2_norm(pred_video_embeds), l2_norm(inter_negative_answer_embeds)) # [bs, window_size, bs-1] 313 | # compute option_negative_logits 314 | if options_embeds != None: 315 | options_negative_logits = torch.einsum("b w d, b o d -> b w o", l2_norm(pred_video_embeds), l2_norm(options_negative_embeds)) # [bs, window_size, 4] 316 | 317 | # compute infoNCE loss 318 | infoNCE_loss = 0 319 | labels = torch.zeros(bs, dtype=torch.long, device=self.device) 320 | for i in range(self.window_size): 321 | if options_embeds != None: 322 | logits = torch.cat([positive_logits[:,i].unsqueeze(-1), # [bs, 1] 323 | intra_negative_video_logits, # [bs, intra_neg_num] 324 | options_negative_logits[:, i, :], # [bs, 4] 325 | # inter_negative_video_logits, # [bs, inter_neg_num] 326 | inter_negative_answer_logits[:, i, :], # [bs, bs-1] 327 | ], dim=1) 328 | else: 329 | logits = torch.cat([positive_logits[:,i].unsqueeze(-1), # [bs, 1] 330 | intra_negative_video_logits, # [bs, intra_neg_num] 331 | # inter_negative_video_logits, # [bs, inter_neg_num] 332 | inter_negative_answer_logits[:, i, :], # [bs, bs-1] 333 | ], dim=1) 334 | infoNCE_loss += F.cross_entropy(logits/self.temperature, labels) 335 | infoNCE_loss = infoNCE_loss/self.window_size 336 | return infoNCE_loss 337 | 338 | def forward(self, Q, K, V, answer_embeds, label_probs=None, answers_id=None): 339 | """ 340 | baseline:72.64 341 | oracle: 79.09 342 | """ 343 | 344 | ''' 345 | Q: [bs, seq_len, dim] 346 | K: [bs, frame_num, dim] 347 | V: [bs, frame_num, query_num, dim] 348 | answer_embeds: [bs, dim] 349 | label_probs: [bs, frame_num] 350 | answers_id: [bs] 351 | ''' 352 | video_embeds = K # [bs, frame_num, dim] 353 | question_embeds = Q # [bs, seq_len, dim] 354 | if 
question_embeds.shape[1] > 1: 355 | option_embeds = Q[:,1:,:] # [bs, 5, dim] 356 | else: 357 | option_embeds = None 358 | answer_embeds = answer_embeds # [bs, dim] 359 | 360 | pred_gaussians, gauss_c, gauss_w = self.mmt_encode(video_embeds, question_embeds) # [bs, frame_num] 361 | 362 | # pos_centers = self.HardTopK(self.window_size, label_probs)/(self.frame_num-1) 363 | # pos_gaussians = self.generate_gmm_weight(pos_centers, gauss_w) # [bs, frame_num] 364 | # kl_loss = self.calculate_kl_divergence(pred_gaussians, pos_gaussians) 365 | regression_loss = self.calculate_regression_loss(gauss_c, label_probs) 366 | infoNCE_loss = self.calculate_contrastive_loss(video_embeds, option_embeds, answer_embeds, pred_gaussians, answers_id) 367 | 368 | if self.training: 369 | selection_mask = self.get_indicator(pred_gaussians, self.window_size) # [bs, window_size, frame_num] 370 | selected_V = torch.einsum("b k t, b t q d -> b k q d", selection_mask, V) # [bs, window_size, query_num, dim] 371 | else: 372 | indicators = self.HardTopK(self.window_size, pred_gaussians) # [bs, window_size] 373 | selection_mask = torch.zeros(K.shape[0], self.window_size, K.shape[1]).to(self.device) # [bs, window_size, frame_num] 374 | for i in range(K.shape[0]): 375 | for j in range(self.window_size): 376 | selection_mask[i][j][indicators[i][j]] = 1 377 | selected_V = torch.einsum("b k t, b t q d -> b k q d", selection_mask, V) # [bs, window_size, query_num, dim] 378 | return selected_V, regression_loss, infoNCE_loss 379 | 380 | # bs = 10 381 | # model = Grounding(window_size=4, dim=1024, heads=8, dropout=0.1, frame_num=32, width=0.2, temperature=0.1) 382 | # print(get_parameter_number(model)) 383 | # model(Q=torch.randn(bs, 6, 1024), 384 | # K=torch.randn(bs, 32, 1024), 385 | # V=torch.randn(bs, 32, 257, 1408), 386 | # answer_embeds=torch.randn(bs, 1024), 387 | # label_probs=torch.randn(bs, 32), 388 | # answers_id=torch.tensor([0 for i in range(bs)])) 389 | -------------------------------------------------------------------------------- /models/eva_clip_branch_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import math 4 | from dataclasses import dataclass 5 | from typing import Tuple, Union, Callable, Optional 6 | import numpy as np 7 | from collections import OrderedDict 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import drop_path 13 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 14 | 15 | class LayerNorm(nn.LayerNorm): 16 | """Subclass torch's LayerNorm to handle fp16.""" 17 | 18 | def forward(self, x: torch.Tensor): 19 | orig_type = x.dtype 20 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 21 | return x.to(orig_type) 22 | 23 | 24 | class QuickGELU(nn.Module): 25 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 26 | def forward(self, x: torch.Tensor): 27 | return x * torch.sigmoid(1.702 * x) 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | num_heads=8, 35 | qkv_bias=True, 36 | scaled_cosine=False, 37 | scale_heads=False, 38 | logit_scale_max=math.log(1. / 0.01), 39 | attn_drop=0., 40 | proj_drop=0. 
41 | ): 42 | super().__init__() 43 | self.scaled_cosine = scaled_cosine 44 | self.scale_heads = scale_heads 45 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 46 | self.num_heads = num_heads 47 | self.head_dim = dim // num_heads 48 | self.scale = self.head_dim ** -0.5 49 | self.logit_scale_max = logit_scale_max 50 | 51 | # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original 52 | self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) 53 | if qkv_bias: 54 | self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) 55 | else: 56 | self.in_proj_bias = None 57 | 58 | if self.scaled_cosine: 59 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) 60 | else: 61 | self.logit_scale = None 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | if self.scale_heads: 64 | self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) 65 | else: 66 | self.head_scale = None 67 | self.out_proj = nn.Linear(dim, dim) 68 | self.out_drop = nn.Dropout(proj_drop) 69 | 70 | def forward(self, x, attn_mask: Optional[torch.Tensor] = None): 71 | L, N, C = x.shape 72 | q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) 73 | q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 74 | k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 75 | v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 76 | 77 | if self.logit_scale is not None: 78 | attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) 79 | logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() 80 | attn = attn.view(N, self.num_heads, L, L) * logit_scale 81 | attn = attn.view(-1, L, L) 82 | else: 83 | q = q * self.scale 84 | attn = torch.bmm(q, k.transpose(-1, -2)) 85 | 86 | if attn_mask is not None: 87 | if attn_mask.dtype == torch.bool: 88 | new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) 89 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 90 | attn_mask = new_attn_mask 91 | attn += attn_mask 92 | 93 | attn = attn.softmax(dim=-1) 94 | attn = self.attn_drop(attn) 95 | 96 | x = torch.bmm(attn, v) 97 | if self.head_scale is not None: 98 | x = x.view(N, self.num_heads, L, C) * self.head_scale 99 | x = x.view(-1, L, C) 100 | x = x.transpose(0, 1).reshape(L, N, C) 101 | x = self.out_proj(x) 102 | x = self.out_drop(x) 103 | return x 104 | 105 | 106 | class ResidualAttentionBlock(nn.Module): 107 | def __init__( 108 | self, 109 | d_model: int, 110 | n_head: int, 111 | mlp_ratio: float = 4.0, 112 | act_layer: Callable = nn.GELU, 113 | scale_cosine_attn: bool = False, 114 | scale_heads: bool = False, 115 | scale_attn: bool = False, 116 | scale_fc: bool = False, 117 | ): 118 | super().__init__() 119 | 120 | self.ln_1 = LayerNorm(d_model) 121 | # FIXME torchscript issues need to be resolved for custom attention 122 | # if scale_cosine_attn or scale_heads: 123 | # self.attn = Attention( 124 | # d_model, n_head, 125 | # scaled_cosine=scale_cosine_attn, 126 | # scale_heads=scale_heads, 127 | # ) 128 | self.attn = nn.MultiheadAttention(d_model, n_head) 129 | self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity() 130 | 131 | self.ln_2 = LayerNorm(d_model) 132 | mlp_width = int(d_model * mlp_ratio) 133 | self.mlp = nn.Sequential(OrderedDict([ 134 | ("c_fc", nn.Linear(d_model, mlp_width)), 135 | ('ln', LayerNorm(mlp_width) if scale_fc else nn.Identity()), 136 | ("gelu", act_layer()), 137 | ("c_proj", nn.Linear(mlp_width, d_model)) 138 | ])) 
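    # Note: nn.MultiheadAttention defaults to batch_first=False, so these blocks expect inputs of
    # shape (seq_len, batch, d_model); see the "NLD -> LND" permute in TextTransformer.forward_features.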
139 | 140 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 141 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 142 | # FIXME torchscript issues need resolving for custom attention option to work 143 | # if self.use_torch_attn: 144 | # return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 145 | # else: 146 | # return self.attn(x, attn_mask=attn_mask) 147 | 148 | def cross_attention(self, x: torch.Tensor, context: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 149 | return self.attn(x, context, context, need_weights=False, attn_mask=attn_mask)[0] 150 | 151 | 152 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 153 | x = x + self.ln_attn(self.attention(self.ln_1(x), attn_mask=attn_mask)) 154 | x = x + self.mlp(self.ln_2(x)) 155 | return x 156 | 157 | class Transformer(nn.Module): 158 | def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU): 159 | super().__init__() 160 | self.width = width 161 | self.layers = layers 162 | 163 | self.resblocks = nn.ModuleList([ 164 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer) 165 | for _ in range(layers) 166 | ]) 167 | 168 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 169 | for r in self.resblocks: 170 | x = r(x, attn_mask=attn_mask) 171 | return x 172 | 173 | class TextTransformer(nn.Module): 174 | def __init__( 175 | self, 176 | vocab_size: int, 177 | width: int, 178 | layers: int, 179 | heads: int, 180 | context_length: int, 181 | embed_dim: int, 182 | act_layer: Callable = nn.GELU, 183 | ): 184 | super().__init__() 185 | self.transformer = Transformer( 186 | width=width, 187 | layers=layers, 188 | heads=heads, 189 | act_layer=act_layer, 190 | ) 191 | self.context_length = context_length 192 | self.vocab_size = vocab_size 193 | self.token_embedding = nn.Embedding(vocab_size, width) 194 | self.positional_embedding = nn.Parameter(torch.empty(context_length, width)) 195 | self.ln_final = LayerNorm(width) 196 | 197 | self.text_projection = nn.Parameter(torch.empty(width, embed_dim)) 198 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 199 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) 200 | 201 | self.init_parameters() 202 | 203 | def init_parameters(self): 204 | nn.init.normal_(self.token_embedding.weight, std=0.02) 205 | nn.init.normal_(self.positional_embedding, std=0.01) 206 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) 207 | 208 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 209 | attn_std = self.transformer.width ** -0.5 210 | fc_std = (2 * self.transformer.width) ** -0.5 211 | for block in self.transformer.resblocks: 212 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 213 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 214 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 215 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 216 | 217 | if self.text_projection is not None: 218 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 219 | 220 | def build_attention_mask(self): 221 | # lazily create causal attention mask, with full attention between the vision tokens 222 | # pytorch uses additive attention mask; fill with -inf 223 | mask = torch.empty(self.context_length, self.context_length) 224 | mask.fill_(float("-inf")) 225 | mask.triu_(1) # zero out the lower diagonal 
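        # e.g. with context_length = 3 the additive mask is
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # so each position attends only to itself and to earlier positions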
226 | return mask 227 | 228 | def forward_features(self, text: torch.Tensor): 229 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 230 | 231 | x = x + self.positional_embedding 232 | x = x.permute(1, 0, 2) # NLD -> LND 233 | x = self.transformer(x, attn_mask=self.attn_mask) 234 | x = x.permute(1, 0, 2) # LND -> NLD 235 | x = self.ln_final(x) # [batch_size, 77, 768] 236 | 237 | # take features from the eot embedding (eot_token is the highest number in each sequence) 238 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] # [batch_size, 768] 239 | return x 240 | 241 | def forward(self, x: torch.Tensor): 242 | x = self.forward_features(x) 243 | if self.text_projection is not None: 244 | x = x @ self.text_projection 245 | # print(x.shape) 246 | return x 247 | 248 | class DropPath(nn.Module): 249 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 250 | """ 251 | def __init__(self, drop_prob=None): 252 | super(DropPath, self).__init__() 253 | self.drop_prob = drop_prob 254 | 255 | def forward(self, x): 256 | return drop_path(x, self.drop_prob, self.training) 257 | 258 | def extra_repr(self) -> str: 259 | return 'p={}'.format(self.drop_prob) 260 | 261 | 262 | class Mlp(nn.Module): 263 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 264 | super().__init__() 265 | out_features = out_features or in_features 266 | hidden_features = hidden_features or in_features 267 | self.fc1 = nn.Linear(in_features, hidden_features) 268 | self.act = act_layer() 269 | self.fc2 = nn.Linear(hidden_features, out_features) 270 | self.drop = nn.Dropout(drop) 271 | 272 | def forward(self, x): 273 | x = self.fc1(x) 274 | x = self.act(x) 275 | # x = self.drop(x) 276 | # commit this for the orignal BERT implement 277 | x = self.fc2(x) 278 | x = self.drop(x) 279 | return x 280 | 281 | 282 | class Vision_Attention(nn.Module): 283 | def __init__( 284 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 285 | proj_drop=0., window_size=None, attn_head_dim=None): 286 | super().__init__() 287 | self.num_heads = num_heads 288 | head_dim = dim // num_heads 289 | if attn_head_dim is not None: 290 | head_dim = attn_head_dim 291 | all_head_dim = head_dim * self.num_heads 292 | self.scale = qk_scale or head_dim ** -0.5 293 | 294 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 295 | if qkv_bias: 296 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 297 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 298 | else: 299 | self.q_bias = None 300 | self.v_bias = None 301 | 302 | if window_size: 303 | self.window_size = window_size 304 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 305 | self.relative_position_bias_table = nn.Parameter( 306 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 307 | # cls to token & token 2 cls & cls to cls 308 | 309 | # get pair-wise relative position index for each token inside the window 310 | coords_h = torch.arange(window_size[0]) 311 | coords_w = torch.arange(window_size[1]) 312 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 313 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 314 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 315 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 316 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 317 | relative_coords[:, 
:, 1] += window_size[1] - 1 318 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 319 | relative_position_index = \ 320 | torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) 321 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 322 | relative_position_index[0, 0:] = self.num_relative_distance - 3 323 | relative_position_index[0:, 0] = self.num_relative_distance - 2 324 | relative_position_index[0, 0] = self.num_relative_distance - 1 325 | 326 | self.register_buffer("relative_position_index", relative_position_index) 327 | else: 328 | self.window_size = None 329 | self.relative_position_bias_table = None 330 | self.relative_position_index = None 331 | 332 | self.attn_drop = nn.Dropout(attn_drop) 333 | self.proj = nn.Linear(all_head_dim, dim) 334 | self.proj_drop = nn.Dropout(proj_drop) 335 | 336 | def forward(self, x, rel_pos_bias=None): 337 | B, N, C = x.shape 338 | qkv_bias = None 339 | if self.q_bias is not None: 340 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 341 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 342 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 343 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 344 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 345 | 346 | q = q * self.scale 347 | attn = (q @ k.transpose(-2, -1)) 348 | 349 | if self.relative_position_bias_table is not None: 350 | relative_position_bias = \ 351 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 352 | self.window_size[0] * self.window_size[1] + 1, 353 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 354 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 355 | attn = attn + relative_position_bias.unsqueeze(0) 356 | 357 | if rel_pos_bias is not None: 358 | attn = attn + rel_pos_bias 359 | 360 | attn = attn.softmax(dim=-1) 361 | attn = self.attn_drop(attn) 362 | 363 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 364 | x = self.proj(x) 365 | x = self.proj_drop(x) 366 | return x 367 | 368 | 369 | class Vision_Block(nn.Module): 370 | 371 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 372 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 373 | window_size=None, attn_head_dim=None): 374 | super().__init__() 375 | self.norm1 = norm_layer(dim) 376 | self.attn = Vision_Attention( 377 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 378 | attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) 379 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 380 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 381 | self.norm2 = norm_layer(dim) 382 | mlp_hidden_dim = int(dim * mlp_ratio) 383 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 384 | 385 | if init_values is not None and init_values > 0: 386 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 387 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 388 | else: 389 | self.gamma_1, self.gamma_2 = None, None 390 | 391 | def forward(self, x, rel_pos_bias=None): 392 | if self.gamma_1 is None: 393 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 394 | x = x + self.drop_path(self.mlp(self.norm2(x))) 395 | else: 396 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 397 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 398 | return x 399 | 400 | class Clip_Branch_Encoder(nn.Module): 401 | def __init__( 402 | self 403 | ): 404 | super().__init__() 405 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 406 | # memory efficient in recent PyTorch releases (>= 1.10). 407 | 408 | self.last_vision_block = Vision_Block( 409 | dim=1408, num_heads=16, mlp_ratio=4.3637, qkv_bias=True, qk_scale=None, 410 | drop=0.0, attn_drop=0.0, drop_path=0.4000000059604645, norm_layer=nn.LayerNorm, 411 | init_values=None, window_size=None) 412 | self.norm = nn.LayerNorm(1408) 413 | self.head = nn.Linear(1408, 1024) 414 | 415 | self.text = TextTransformer( 416 | vocab_size=49408, 417 | width=768, 418 | layers=12, 419 | heads=12, 420 | context_length=77, 421 | embed_dim=1024, 422 | act_layer = nn.GELU 423 | ) 424 | 425 | self.last_vision_block.load_state_dict(torch.load("experiments/eva_clip_last_vision_block.pth", map_location='cpu')) 426 | self.norm.load_state_dict(torch.load("experiments/eva_clip_last_vision_norm.pth", map_location='cpu')) 427 | self.head.load_state_dict(torch.load("experiments/eva_clip_last_vision_head.pth", map_location='cpu')) 428 | self.text.load_state_dict(torch.load("experiments/eva_clip_text_model.pth", map_location='cpu')) 429 | 430 | def encode_image(self, image_features): 431 | image_features = self.last_vision_block(image_features) 432 | image_features = image_features[:, 0, :] 433 | image_features = self.norm(image_features) 434 | image_features = self.head(image_features) 435 | return image_features 436 | 437 | def encode_text(self, text): 438 | return self.text(text) 439 | -------------------------------------------------------------------------------- /models/eva_vit.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, Union 2 | from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig 3 | from transformers import InstructBlipConfig 4 | import torch 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | from transformers.activations import ACT2FN 8 | from transformers.modeling_outputs import ( 9 | BaseModelOutput, 10 | BaseModelOutputWithPooling, 11 | ) 12 | from einops import rearrange, repeat 13 | from transformers.modeling_utils import PreTrainedModel 14 | from transformers.utils import ( 15 | add_start_docstrings_to_model_forward, 16 | replace_return_docstrings, 17 | ModelOutput 18 | ) 19 | from dataclasses import dataclass 20 | import sys 21 | import os 22 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 23 | 24 | @dataclass 25 | class 
BaseModelOutputWithPooling(ModelOutput): 26 | """ 27 | Base class for model's outputs that also contains a pooling of the last hidden states. 28 | 29 | Args: 30 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 31 | Sequence of hidden-states at the output of the last layer of the model. 32 | pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): 33 | Last layer hidden-state of the first token of the sequence (classification token) after further processing 34 | through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns 35 | the classification token after processing through a linear layer and a tanh activation function. The linear 36 | layer weights are trained from the next sentence prediction (classification) objective during pretraining. 37 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 38 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 39 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 40 | 41 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 42 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 43 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 44 | sequence_length)`. 45 | 46 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 47 | heads. 48 | """ 49 | 50 | last_hidden_state: torch.FloatTensor = None 51 | last_hidden_state_without_norm: torch.FloatTensor = None 52 | pooler_output: torch.FloatTensor = None 53 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 54 | attentions: Optional[Tuple[torch.FloatTensor]] = None 55 | 56 | # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 57 | class Blip2VisionEmbeddings(nn.Module): 58 | def __init__(self, config: Blip2VisionConfig): 59 | super().__init__() 60 | self.config = config 61 | self.embed_dim = config.hidden_size 62 | self.image_size = config.image_size 63 | self.patch_size = config.patch_size 64 | 65 | self.class_embedding = nn.Parameter( 66 | torch.randn(1, 1, self.embed_dim), 67 | ) 68 | 69 | self.patch_embedding = nn.Conv2d( 70 | in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size 71 | ) 72 | 73 | self.num_patches = (self.image_size // self.patch_size) ** 2 # (224/14)^2 = 256 74 | self.num_positions = self.num_patches + 1 # 257 75 | 76 | self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) 77 | 78 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 79 | # pixel_values: [8, 3, 224, 224] 80 | batch_size = pixel_values.shape[0] 81 | target_dtype = self.patch_embedding.weight.dtype 82 | patch_embeds = self.patch_embedding(pixel_values) # [8, 1408, 16, 16] 83 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) # [bs, 256, 1408] 84 | 85 | class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) 86 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) # [bs, 257, 1408] 87 | embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) 88 | return embeddings 89 
| 90 | class Blip2Attention(nn.Module): 91 | """Multi-headed attention from 'Attention Is All You Need' paper""" 92 | 93 | def __init__(self, config): 94 | super().__init__() 95 | self.config = config 96 | self.embed_dim = config.hidden_size # 1408 97 | self.num_heads = config.num_attention_heads # 16 98 | self.head_dim = self.embed_dim // self.num_heads # 88 99 | if self.head_dim * self.num_heads != self.embed_dim: 100 | raise ValueError( 101 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 102 | f" {self.num_heads})." 103 | ) 104 | self.scale = self.head_dim**-0.5 105 | self.dropout = nn.Dropout(config.attention_dropout) 106 | 107 | # small tweak here compared to CLIP, no bias here 108 | self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) 109 | 110 | if config.qkv_bias: 111 | q_bias = nn.Parameter(torch.zeros(self.embed_dim)) 112 | v_bias = nn.Parameter(torch.zeros(self.embed_dim)) 113 | else: 114 | q_bias = None 115 | v_bias = None 116 | 117 | if q_bias is not None: 118 | qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)) 119 | self.qkv.bias = nn.Parameter(qkv_bias) 120 | 121 | self.projection = nn.Linear(self.embed_dim, self.embed_dim) 122 | 123 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 124 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 125 | 126 | def forward( 127 | self, 128 | hidden_states: torch.Tensor, 129 | head_mask: Optional[torch.Tensor] = None, 130 | output_attentions: Optional[bool] = False, 131 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 132 | """Input shape: Batch x Time x Channel""" 133 | 134 | bsz, tgt_len, embed_dim = hidden_states.size() 135 | 136 | mixed_qkv = self.qkv(hidden_states) 137 | 138 | mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( 139 | 2, 0, 3, 1, 4 140 | ) 141 | query_states, key_states, value_states = ( 142 | mixed_qkv[0], 143 | mixed_qkv[1], 144 | mixed_qkv[2], 145 | ) 146 | 147 | # Take the dot product between "query" and "key" to get the raw attention scores. 148 | attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) 149 | 150 | attention_scores = attention_scores * self.scale 151 | 152 | # Normalize the attention scores to probabilities. 153 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 154 | 155 | # This is actually dropping out entire tokens to attend to, which might 156 | # seem a bit unusual, but is taken from the original Transformer paper. 
157 | attention_probs = self.dropout(attention_probs) 158 | 159 | # Mask heads if we want to 160 | if head_mask is not None: 161 | attention_probs = attention_probs * head_mask 162 | 163 | context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) 164 | 165 | new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) 166 | context_layer = context_layer.reshape(new_context_layer_shape) 167 | 168 | output = self.projection(context_layer) 169 | 170 | outputs = (output, attention_probs) if output_attentions else (output, None) 171 | 172 | return outputs 173 | 174 | # Copied from transformers.models.blip.modeling_blip.BlipMLP 175 | class Blip2MLP(nn.Module): 176 | def __init__(self, config): 177 | super().__init__() 178 | self.config = config 179 | self.activation_fn = ACT2FN[config.hidden_act] 180 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 181 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 182 | 183 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 184 | hidden_states = self.fc1(hidden_states) 185 | hidden_states = self.activation_fn(hidden_states) 186 | hidden_states = self.fc2(hidden_states) 187 | return hidden_states 188 | 189 | # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2 190 | class Blip2EncoderLayer(nn.Module): 191 | def __init__(self, config: Blip2Config): 192 | super().__init__() 193 | self.embed_dim = config.hidden_size 194 | self.self_attn = Blip2Attention(config) 195 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 196 | self.mlp = Blip2MLP(config) 197 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 198 | 199 | def forward( 200 | self, 201 | hidden_states: torch.Tensor, 202 | attention_mask: torch.Tensor, 203 | output_attentions: Optional[bool] = False, 204 | ) -> Tuple[torch.FloatTensor]: 205 | """ 206 | Args: 207 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 208 | attention_mask (`torch.FloatTensor`): attention mask of size 209 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 210 | `(config.encoder_attention_heads,)`. 211 | output_attentions (`bool`, *optional*): 212 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 213 | returned tensors for more detail. 214 | """ 215 | residual = hidden_states 216 | 217 | hidden_states = self.layer_norm1(hidden_states) 218 | hidden_states, attn_weights = self.self_attn( 219 | hidden_states=hidden_states, 220 | head_mask=attention_mask, 221 | output_attentions=output_attentions, 222 | ) 223 | hidden_states = hidden_states + residual 224 | residual = hidden_states 225 | hidden_states = self.layer_norm2(hidden_states) 226 | hidden_states = self.mlp(hidden_states) 227 | 228 | hidden_states = hidden_states + residual 229 | 230 | outputs = (hidden_states,) 231 | 232 | if output_attentions: 233 | outputs += (attn_weights,) 234 | 235 | return outputs 236 | 237 | class Blip2Encoder(nn.Module): 238 | """ 239 | Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a 240 | [`Blip2EncoderLayer`]. 241 | 242 | Args: 243 | config (`Blip2Config`): 244 | The corresponding vision configuration for the `Blip2Encoder`. 
245 | """ 246 | 247 | def __init__(self, config: Blip2Config): 248 | super().__init__() 249 | self.config = config 250 | self.layers = nn.ModuleList([Blip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) 251 | self.gradient_checkpointing = False 252 | 253 | def forward( 254 | self, 255 | inputs_embeds, 256 | attention_mask: Optional[torch.Tensor] = None, 257 | output_attentions: Optional[bool] = None, 258 | output_hidden_states: Optional[bool] = None, 259 | return_dict: Optional[bool] = None, 260 | ) -> Union[Tuple, BaseModelOutput]: 261 | r""" 262 | Args: 263 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 264 | Embedded representation of the inputs. Should be float, not int tokens. 265 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 266 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 267 | 268 | - 1 for tokens that are **not masked**, 269 | - 0 for tokens that are **masked**. 270 | 271 | [What are attention masks?](../glossary#attention-mask) 272 | output_attentions (`bool`, *optional*): 273 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 274 | returned tensors for more detail. 275 | output_hidden_states (`bool`, *optional*): 276 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 277 | for more detail. 278 | return_dict (`bool`, *optional*): 279 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 280 | """ 281 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 282 | output_hidden_states = ( 283 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 284 | ) 285 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 286 | 287 | encoder_states = () if output_hidden_states else None 288 | all_attentions = () if output_attentions else None 289 | 290 | hidden_states = inputs_embeds 291 | for idx, encoder_layer in enumerate(self.layers): 292 | if output_hidden_states: 293 | encoder_states = encoder_states + (hidden_states,) 294 | if self.gradient_checkpointing and self.training: 295 | 296 | def create_custom_forward(module): 297 | def custom_forward(*inputs): 298 | return module(*inputs, output_attentions) 299 | 300 | return custom_forward 301 | 302 | layer_outputs = torch.utils.checkpoint.checkpoint( 303 | create_custom_forward(encoder_layer), 304 | hidden_states, 305 | attention_mask, 306 | ) 307 | else: 308 | layer_outputs = encoder_layer( 309 | hidden_states, 310 | attention_mask, 311 | output_attentions=output_attentions, 312 | ) 313 | 314 | hidden_states = layer_outputs[0] 315 | 316 | if output_attentions: 317 | all_attentions = all_attentions + (layer_outputs[1],) 318 | 319 | if output_hidden_states: 320 | encoder_states = encoder_states + (hidden_states,) 321 | 322 | if not return_dict: 323 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 324 | return BaseModelOutput( 325 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 326 | ) 327 | 328 | class Blip2PreTrainedModel(PreTrainedModel): 329 | """ 330 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 331 | models. 
332 | """ 333 | 334 | config_class = Blip2Config 335 | base_model_prefix = "blip" 336 | supports_gradient_checkpointing = True 337 | _keys_to_ignore_on_load_missing = [ 338 | r"position_ids", 339 | r"language_model.encoder.embed_tokens.weight", 340 | r"language_model.decoder.embed_tokens.weight", 341 | r"language_model.lm_head.weight", 342 | ] 343 | _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"] 344 | _skip_keys_device_placement = "past_key_values" 345 | _keep_in_fp32_modules = ["wo"] 346 | 347 | def _init_weights(self, module): 348 | """Initialize the weights""" 349 | factor = self.config.initializer_range 350 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear): 351 | module.weight.data.normal_(mean=0.0, std=factor) 352 | if hasattr(module, "bias") and module.bias is not None: 353 | module.bias.data.zero_() 354 | 355 | if isinstance(module, Blip2VisionEmbeddings): 356 | if hasattr(self.config, "vision_config"): 357 | factor = self.config.vision_config.initializer_range 358 | nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) 359 | nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) 360 | 361 | elif isinstance(module, nn.LayerNorm): 362 | module.bias.data.zero_() 363 | module.weight.data.fill_(1.0) 364 | elif isinstance(module, nn.Linear) and module.bias is not None: 365 | module.bias.data.zero_() 366 | 367 | def _set_gradient_checkpointing(self, module, value=False): 368 | if isinstance(module, Blip2Encoder): 369 | module.gradient_checkpointing = value 370 | 371 | BLIP_2_VISION_INPUTS_DOCSTRING = r""" 372 | Args: 373 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 374 | Pixel values. Pixel values can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for 375 | details. 376 | output_attentions (`bool`, *optional*): 377 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 378 | tensors for more detail. 379 | output_hidden_states (`bool`, *optional*): 380 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 381 | more detail. 382 | return_dict (`bool`, *optional*): 383 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
384 | """ 385 | 386 | class Blip2VisionModel(Blip2PreTrainedModel): 387 | main_input_name = "pixel_values" 388 | config_class = Blip2VisionConfig 389 | 390 | def __init__(self, config: Blip2VisionConfig): 391 | super().__init__(config) 392 | self.config = config 393 | embed_dim = config.hidden_size 394 | 395 | self.embeddings = Blip2VisionEmbeddings(config) 396 | self.encoder = Blip2Encoder(config) 397 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 398 | 399 | self.post_init() 400 | 401 | @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) 402 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) 403 | def forward( 404 | self, 405 | pixel_values: Optional[torch.FloatTensor] = None, 406 | output_attentions: Optional[bool] = None, 407 | output_hidden_states: Optional[bool] = None, 408 | return_dict: Optional[bool] = None, 409 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 410 | r""" 411 | Returns: 412 | 413 | """ 414 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 415 | output_hidden_states = ( 416 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 417 | ) 418 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 419 | 420 | if pixel_values is None: 421 | raise ValueError("You have to specify pixel_values") 422 | 423 | hidden_states = self.embeddings(pixel_values) 424 | 425 | encoder_outputs = self.encoder( 426 | inputs_embeds=hidden_states, 427 | output_attentions=output_attentions, 428 | output_hidden_states=output_hidden_states, 429 | return_dict=return_dict, 430 | ) 431 | 432 | last_hidden_state = encoder_outputs[0] 433 | pooled_output = last_hidden_state[:, 0, :] 434 | 435 | last_hidden_state = self.post_layernorm(last_hidden_state) 436 | pooled_output = self.post_layernorm(pooled_output) 437 | 438 | if not return_dict: 439 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 440 | 441 | return BaseModelOutputWithPooling( 442 | last_hidden_state=last_hidden_state, 443 | pooler_output=pooled_output, 444 | last_hidden_state_without_norm = encoder_outputs[0], 445 | hidden_states=encoder_outputs.hidden_states, 446 | attentions=encoder_outputs.attentions, 447 | ) 448 | 449 | def get_input_embeddings(self): 450 | return self.embeddings 451 | 452 | -------------------------------------------------------------------------------- /models/blip2_t5_instruct.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import contextlib 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch.cuda.amp import autocast as autocast 8 | from transformers import T5Tokenizer, BertTokenizer 9 | from transformers import T5Config, Blip2Config, BertConfig 10 | from transformers.modeling_outputs import BaseModelOutput 11 | from einops import rearrange, repeat 12 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 13 | from utils.utils import * 14 | import torch.nn.functional as F 15 | from models.modeling_t5 import T5ForConditionalGeneration 16 | from models.Qformer import BertLMHeadModel 17 | from models.eva_vit import Blip2VisionModel 18 | from models.eva_clip_branch_encoder import Clip_Branch_Encoder 19 | from utils.simple_tokenizer import tokenize as clip_text_tokenizer 20 | 21 | class Blip2T5Instruct(nn.Module): 22 | def __init__( 23 | self, 24 | frame_num = 32, 
25 | dtype=torch.bfloat16, 26 | mode = 'grounding', 27 | window_size = 4, 28 | use_spatial = False, 29 | model = 't5-xl', 30 | temperature = 0.1, 31 | use_vit = False, 32 | width = 0.2, 33 | use_lora=False 34 | ): 35 | super().__init__() 36 | 37 | self.dtype = dtype 38 | self.frame_num = frame_num 39 | self.mode = mode 40 | self.window_size = window_size 41 | self.use_spatial = use_spatial 42 | self.model = model 43 | self.max_input_txt_len = 512 44 | self.max_output_txt_len = 512 45 | self.width = width 46 | self.temperature = temperature 47 | self.use_vit = use_vit 48 | self.use_lora = use_lora 49 | 50 | print('loading ViT') 51 | blip2_config = Blip2Config.from_pretrained('Salesforce/blip2-flan-t5-xl') 52 | self.eva_vit_post_layer_norm = nn.LayerNorm(blip2_config.vision_config.hidden_size, eps=blip2_config.vision_config.layer_norm_eps) 53 | self.eva_vit_post_layer_norm.load_state_dict(torch.load("experiments/eva_vit_post_layernorm.pth", map_location='cpu')) 54 | if self.use_vit: 55 | blip2_config.vision_config.torch_dtype = self.dtype 56 | self.vision_model = Blip2VisionModel(blip2_config.vision_config) 57 | self.vision_model.load_state_dict(torch.load("experiments/eva_vit_g.pth", map_location='cpu')) 58 | for name, param in self.vision_model.named_parameters(): 59 | param.requires_grad = False 60 | self.vision_model.eval() 61 | 62 | print('loading Qformer') 63 | self.tokenizer = self.init_tokenizer(truncation_side="left") 64 | self.Qformer = self.init_Qformer(num_query_token=32, vision_width=blip2_config.vision_config.hidden_size, cross_attention_freq=2) 65 | self.Qformer.resize_token_embeddings(len(self.tokenizer)) 66 | self.Qformer.cls = None 67 | if self.model == 't5-xl': 68 | self.Qformer.load_state_dict(torch.load("experiments/qformer_t5.pth", map_location='cpu')) 69 | self.query_tokens = nn.Parameter(torch.load("experiments/query_tokens_t5.pth", map_location='cpu')) 70 | 71 | if self.mode == 'grounding' or self.mode == 'oracle': 72 | print('loading eva clip branch encoder') 73 | self.branch_encoder = Clip_Branch_Encoder() 74 | for name, param in self.branch_encoder.named_parameters(): 75 | param.requires_grad = False 76 | self.branch_encoder.eval() 77 | 78 | from models.Transformer import create_sinusoidal_embeddings 79 | print('loading frame_embeds') 80 | self.frame_embeds = nn.Embedding(self.frame_num, 768) 81 | create_sinusoidal_embeddings( 82 | n_pos=self.frame_num, 83 | dim=768, 84 | out=self.frame_embeds.weight, 85 | ) 86 | 87 | if self.mode == 'grounding': 88 | print('loading Grounding') 89 | from models.grounding_module import Grounding 90 | self.grounding = Grounding(dim=1024, heads=4, dropout=0.3, window_size=self.window_size, frame_num=self.frame_num, width=self.width, temperature=self.temperature) 91 | 92 | print('loading T5') 93 | if self.model == 't5-xl': 94 | self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl", truncation_side='left') 95 | self.t5_output_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl", truncation_side='right') 96 | t5_config = T5Config.from_pretrained("google/flan-t5-xl") 97 | t5_config.dense_act_fn = "gelu" 98 | self.t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", config=t5_config, torch_dtype=self.dtype) 99 | 100 | print('loading llm_proj') 101 | self.t5_proj = nn.Linear(768, self.t5_model.config.hidden_size) 102 | if self.model == 't5-xl': 103 | self.t5_proj.load_state_dict(torch.load("experiments/llm_proj_t5.pth", map_location='cpu')) 104 | 105 | print("Frozen ViT") 106 | for name, param in 
self.eva_vit_post_layer_norm.named_parameters(): 107 | param.requires_grad = False 108 | self.eva_vit_post_layer_norm.eval() 109 | 110 | if self.use_lora: 111 | from peft import get_peft_model, LoraConfig, TaskType 112 | print("LORA LLM") 113 | peft_config = LoraConfig( 114 | task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False if self.training else True, r=16, lora_alpha=32, lora_dropout=0.05, 115 | target_modules = ["q", "v"], 116 | ) 117 | self.t5_model = get_peft_model(self.t5_model, peft_config) 118 | else: 119 | print("Frozen t5") 120 | for name, param in self.t5_model.named_parameters(): 121 | param.requires_grad = False 122 | param.data = param.data.bfloat16() 123 | self.t5_model.eval() 124 | 125 | @property 126 | def device(self): 127 | return list(self.parameters())[0].device 128 | 129 | def maybe_autocast(self): 130 | # if on cpu, don't use autocast 131 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 132 | enable_autocast = self.device != torch.device("cpu") 133 | if enable_autocast: 134 | return torch.cuda.amp.autocast(dtype=self.dtype) 135 | else: 136 | return contextlib.nullcontext() 137 | 138 | def init_tokenizer(self, truncation_side="right"): 139 | tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", truncation_side=truncation_side) 140 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 141 | return tokenizer 142 | 143 | def init_Qformer(self, num_query_token, vision_width, cross_attention_freq=2): 144 | encoder_config = BertConfig.from_pretrained("google-bert/bert-base-uncased") 145 | encoder_config.encoder_width = vision_width 146 | encoder_config.add_cross_attention = True 147 | encoder_config.cross_attention_freq = cross_attention_freq 148 | encoder_config.query_length = num_query_token 149 | encoder_config.torch_dtype = self.dtype 150 | Qformer = BertLMHeadModel(config=encoder_config) 151 | return Qformer 152 | 153 | def HardTopK(self, k, x): 154 | topk_results = torch.topk(x, k=k, dim=-1, sorted=False) 155 | indices = topk_results.indices # b, k 156 | indices = torch.sort(indices, dim=-1).values 157 | return indices 158 | 159 | def generate_gauss_weight(self, center, width): 160 | # code copied from https://github.com/minghangz/cpl 161 | weight = torch.linspace(0, 1, self.frame_num) 162 | weight = weight.view(1, -1).expand(center.size(0), -1).to(center.device) 163 | center = center.unsqueeze(-1) 164 | width = width.unsqueeze(-1).clamp(1e-2) / 9 165 | w = 0.3989422804014327 #1/(math.sqrt(2*math.pi)) 166 | weight = w/width*torch.exp(-(weight-center)**2/(2*width**2)) 167 | return weight 168 | 169 | def return_gmm_scores(self, label_probs): 170 | centers = self.HardTopK(self.window_size, label_probs) # [bs, window_size] 171 | centers = centers/(self.frame_num-1) # [bs, window_size] 172 | centers = rearrange(centers, "b w -> (b w)") # [bs*window_size] 173 | gaussians = self.generate_gauss_weight(centers, torch.tensor([self.width for i in range(centers.shape[0])]).to(self.device)) 174 | gaussians = rearrange(gaussians, "(b w) t -> b w t", w=self.window_size) # [bs, window_size, frame_num] 175 | gaussians = torch.sum(gaussians, dim=1) # [bs, frame_num] 176 | gaussians = gaussians/gaussians.max(dim=-1, keepdim=True)[0] # [bs, frame_num] 177 | return gaussians 178 | 179 | def spatial_augmented(self, spatial_image_embeds, video_query_tokens, Qformer_atts, text_Qformer): 180 | bs = spatial_image_embeds.shape[0] // self.frame_num 181 | spatial_query_tokens = self.query_tokens.expand(bs, -1, -1) # [bs, 32, 768] 182 | 
spatial_image_embeds = rearrange(spatial_image_embeds, "(b t) n d -> b (t n) d", t=self.frame_num) # [bs, 257*frame_count, 1408] 183 | spatial_image_atts = torch.ones(spatial_image_embeds.size()[:-1], dtype=torch.long).to(spatial_image_embeds.device) # [bs, 257*frame_count] 184 | with self.maybe_autocast(): 185 | query_output = self.Qformer.bert( 186 | text_Qformer.input_ids, 187 | attention_mask = Qformer_atts, 188 | query_embeds=spatial_query_tokens, 189 | encoder_hidden_states=spatial_image_embeds, 190 | encoder_attention_mask=spatial_image_atts, 191 | return_dict=True, 192 | ) 193 | spatial_query_tokens = query_output.last_hidden_state[:,:spatial_query_tokens.size(1),:] # [bs, 32, 768] 194 | video_query_tokens = torch.cat([video_query_tokens, spatial_query_tokens], dim=1) # [bs, (4+1)*32, 768] 195 | return video_query_tokens 196 | 197 | def uniform_concat(self, samples): 198 | if self.use_vit: 199 | pixel_values = samples["pixel_values"] # [bs, frame_num, 3, 224, 224] 200 | bs, framecount, _, _, _ = pixel_values.shape 201 | pixel_values = rearrange(pixel_values, "b t c h w -> (b t) c h w") 202 | frame_features_wo_norm = self.vision_model(pixel_values=pixel_values).last_hidden_state_without_norm 203 | frame_features_wo_norm = rearrange(frame_features_wo_norm, "(b t) n d -> b t n d", t=self.frame_num) # [bs, frame_num, 257, 1408] 204 | else: 205 | frame_features_wo_norm = samples["frame_features"] # [bs, frame_num, 257, 1408] 206 | bs, framecount, _, _ = frame_features_wo_norm.shape 207 | frame_features_wo_norm = rearrange(frame_features_wo_norm, "b t n d -> (b t) n d") # [bs*frame_num, 257, 1408] 208 | image_embeds = self.eva_vit_post_layer_norm(frame_features_wo_norm) # [bs*frame_num, 257, 1408] 209 | spatial_image_embeds = image_embeds # [bs*frame_num, 257, 1408] 210 | 211 | if self.window_size < self.frame_num: 212 | image_embeds = rearrange(image_embeds, "(b t) n d -> b t n d", t=self.frame_num) # [bs, frame_count, 257, 1408] 213 | def generate_uniform_elements(T, W): 214 | return torch.linspace(0, T-1, W, dtype=torch.int) 215 | indicators = generate_uniform_elements(self.frame_num, self.window_size).repeat(bs, 1) # [bs, window_size] 216 | selection_mask = torch.zeros(bs, self.window_size, self.frame_num).to(self.device) # [bs, window_size, frame_num] 217 | for i in range(bs): 218 | for j in range(self.window_size): 219 | selection_mask[i][j][indicators[i][j]] = 1 220 | image_embeds = torch.einsum("b k t, b t n d -> b k n d", selection_mask, image_embeds) # [bs, window_size, 257, 1408] 221 | image_embeds = rearrange(image_embeds, "b w n d -> (b w) n d") # [bs*window_size, 257, 1408] 222 | 223 | query_tokens = self.query_tokens.expand(bs*self.window_size, -1, -1) # [bs*window_size, 32, 768] 224 | text_Qformer = self.tokenizer( 225 | samples["text_input"], 226 | padding='longest', 227 | truncation=True, 228 | max_length=self.max_input_txt_len, 229 | return_tensors="pt", 230 | ).to(image_embeds.device) 231 | query_atts = torch.ones((bs, self.query_tokens.shape[1]), dtype=torch.long).to(image_embeds.device) # [bs, 32] 232 | Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1) # [bs, 32+seq_len] 233 | 234 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) 235 | with self.maybe_autocast(): 236 | query_output = self.Qformer.bert( 237 | text_Qformer.input_ids.repeat(self.window_size, 1), 238 | attention_mask = Qformer_atts.repeat(self.window_size, 1), 239 | query_embeds=query_tokens, 240 | encoder_hidden_states=image_embeds, 
241 | encoder_attention_mask=image_atts, 242 | return_dict=True, 243 | ) 244 | query_tokens = query_output.last_hidden_state[:,:query_tokens.size(1),:] # [bs*window_size, 32, 768] 245 | video_query_tokens = rearrange(query_tokens, "(b t) n d -> b (t n) d", t=self.window_size) # [bs, window_size*32, 768] 246 | 247 | if self.use_spatial: 248 | video_query_tokens = self.spatial_augmented(spatial_image_embeds, video_query_tokens, Qformer_atts, text_Qformer) 249 | inputs_llm = self.t5_proj(video_query_tokens) 250 | atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(inputs_llm.device) 251 | 252 | regression_loss = torch.tensor(0).to(self.device) 253 | infoNCE_loss = torch.tensor(0).to(self.device) 254 | 255 | return inputs_llm, atts_llm, regression_loss, infoNCE_loss 256 | 257 | def oracle_concat(self, samples): 258 | if self.use_vit: 259 | pixel_values = samples["pixel_values"] # [bs, frame_count, 3, 224, 224] 260 | bs, framecount, _, _, _ = pixel_values.shape 261 | pixel_values = rearrange(pixel_values, "b t c h w -> (b t) c h w") 262 | frame_features_wo_norm = self.vision_model(pixel_values=pixel_values).last_hidden_state_without_norm 263 | frame_features_wo_norm = rearrange(frame_features_wo_norm, "(b t) n d -> b t n d", t=framecount) # [bs, frame_count, 257, 1408] 264 | else: 265 | frame_features_wo_norm = samples["frame_features"] # [bs, frame_count, 257, 1408] 266 | bs, framecount, _, _ = frame_features_wo_norm.shape 267 | frame_features_wo_norm = rearrange(frame_features_wo_norm, "b t n d -> (b t) n d") # [bs*frame_count, 257, 1408] 268 | image_embeds = self.eva_vit_post_layer_norm(frame_features_wo_norm) # [bs*frame_count, 257, 1408] 269 | spatial_image_embeds = image_embeds 270 | 271 | image_embeds_for_selection = self.branch_encoder.encode_image(frame_features_wo_norm) # [bs*frame_count, 1024] 272 | image_embeds_for_selection = rearrange(image_embeds_for_selection, "(b t) d -> b t d", t=framecount) # [bs, frame_count, 1024] 273 | label_embeds_for_selection = self.branch_encoder.encode_text(clip_text_tokenizer(samples["answers_text"]).to(image_embeds.device)) # [bs,1024] # answers_text, questions 274 | def l2_norm(x): 275 | return x/x.norm(dim=-1, keepdim=True) 276 | 277 | label_probs = torch.einsum("b t d, b d -> b t", l2_norm(image_embeds_for_selection), l2_norm(label_embeds_for_selection)) 278 | image_embeds = rearrange(image_embeds, "(b t) n d -> b t n d", t=framecount) # [bs, frame_count, 257, 1408] 279 | 280 | label_probs = self.return_gmm_scores(label_probs) 281 | indicators = self.HardTopK(self.window_size, label_probs) # [bs, window_size] 282 | selection_mask = torch.zeros(bs, self.window_size, framecount).to(self.device) # [bs, window_size, frame_num] 283 | for i in range(bs): 284 | for j in range(self.window_size): 285 | selection_mask[i][j][indicators[i][j]] = 1 286 | image_embeds = torch.einsum("b k t, b t n d -> b k n d", selection_mask, image_embeds) # [bs, window_size, 257, 1408] 287 | 288 | image_embeds = rearrange(image_embeds, "b t n d -> (b t) n d") # [bs*4, 257, 1408] 289 | query_tokens = self.query_tokens.expand(bs*self.window_size, -1, -1) # [bs*frame_count, 32, 768] 290 | text_Qformer = self.tokenizer( 291 | samples["text_input"], 292 | padding='longest', 293 | truncation=True, 294 | max_length=self.max_input_txt_len, 295 | return_tensors="pt", 296 | ).to(image_embeds.device) 297 | query_atts = torch.ones((bs, self.query_tokens.shape[1]), dtype=torch.long).to(image_embeds.device) # [bs, 32] 298 | Qformer_atts = torch.cat([query_atts, 
text_Qformer.attention_mask],dim=1) # [bs, 32+seq_len] 299 | 300 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) 301 | with self.maybe_autocast(): 302 | query_output = self.Qformer.bert( 303 | text_Qformer.input_ids.repeat(self.window_size, 1), 304 | attention_mask = Qformer_atts.repeat(self.window_size, 1), 305 | query_embeds=query_tokens, 306 | encoder_hidden_states=image_embeds, 307 | encoder_attention_mask=image_atts, 308 | return_dict=True, 309 | ) 310 | query_tokens = query_output.last_hidden_state[:,:query_tokens.size(1),:] # [bs*frame_count, 32, 768] 311 | query_tokens = rearrange(query_tokens, "(b w) n d -> b w n d", w=self.window_size) # [bs, frame_count, 32, 768] 312 | 313 | position_ids = torch.arange(self.window_size, dtype=torch.long, device=query_tokens.device) 314 | position_ids = position_ids.unsqueeze(0).expand(bs, -1) 315 | frame_embedding = self.frame_embeds(position_ids) 316 | frame_embedding = frame_embedding.unsqueeze(-2) 317 | query_tokens = query_tokens + frame_embedding # [bs, frame_count, 32, 768] 318 | video_query_tokens = rearrange(query_tokens, "b w n d -> b (w n) d") # [bs, 4*32, 768] 319 | 320 | if self.use_spatial: 321 | video_query_tokens = self.spatial_augmented(spatial_image_embeds, video_query_tokens, Qformer_atts, text_Qformer) 322 | inputs_llm = self.t5_proj(video_query_tokens) 323 | atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(inputs_llm.device) 324 | 325 | regression_loss = torch.tensor(0).to(self.device) 326 | infoNCE_loss = torch.tensor(0).to(self.device) 327 | 328 | return inputs_llm, atts_llm, regression_loss, infoNCE_loss 329 | 330 | def grounding_concat(self, samples): 331 | if self.use_vit: 332 | pixel_values = samples["pixel_values"] # [bs, frame_count, 3, 224, 224] 333 | bs, framecount, _, _, _ = pixel_values.shape 334 | pixel_values = rearrange(pixel_values, "b t c h w -> (b t) c h w") 335 | frame_features_wo_norm = self.vision_model(pixel_values=pixel_values).last_hidden_state_without_norm 336 | frame_features_wo_norm = rearrange(frame_features_wo_norm, "(b t) n d -> b t n d", t=framecount) # [bs, frame_count, 257, 1408] 337 | else: 338 | frame_features_wo_norm = samples["frame_features"] # [bs, frame_count, 257, 1408] 339 | bs, framecount, _, _ = frame_features_wo_norm.shape 340 | frame_features_wo_norm = rearrange(frame_features_wo_norm, "b t n d -> (b t) n d") # [bs*frame_count, 257, 1408] 341 | image_embeds = self.eva_vit_post_layer_norm(frame_features_wo_norm) # [bs*frame_count, 257, 1408] 342 | spatial_image_embeds = image_embeds 343 | 344 | image_embeds_for_selection = self.branch_encoder.encode_image(frame_features_wo_norm) # [bs*frame_count, 1024] 345 | image_embeds_for_selection = rearrange(image_embeds_for_selection, "(b t) d -> b t d", t=framecount) # [bs, frame_count, 1024] 346 | label_embeds_for_selection = self.branch_encoder.encode_text(clip_text_tokenizer(samples["answers_text"]).to(image_embeds.device)) # [bs,1024] 347 | 348 | if 'options_a0' not in samples.keys(): 349 | question_embeds_for_selection = self.branch_encoder.encode_text(clip_text_tokenizer(samples["questions"]).to(image_embeds.device)) # [bs,1024] 350 | question_embeds_for_selection = question_embeds_for_selection.unsqueeze(1) # [bs,1, 1024] 351 | else: 352 | question_embeds_for_selection = self.branch_encoder.encode_text(clip_text_tokenizer(samples["questions"]).to(image_embeds.device)) # [bs,1024] 353 | options_a0 = 
self.branch_encoder.encode_text(clip_text_tokenizer(samples["options_a0"]).to(image_embeds.device)) # [bs,1024] 354 | options_a1 = self.branch_encoder.encode_text(clip_text_tokenizer(samples["options_a1"]).to(image_embeds.device)) # [bs,1024] 355 | if 'options_a2' not in samples.keys(): 356 | question_embeds_for_selection = torch.stack([question_embeds_for_selection,options_a0,options_a1], dim=1) # [bs,3, 1024] 357 | else: 358 | options_a2 = self.branch_encoder.encode_text(clip_text_tokenizer(samples["options_a2"]).to(image_embeds.device)) # [bs,1024] 359 | options_a3 = self.branch_encoder.encode_text(clip_text_tokenizer(samples["options_a3"]).to(image_embeds.device)) # [bs,1024] 360 | if 'options_a4' not in samples.keys(): 361 | question_embeds_for_selection = torch.stack([question_embeds_for_selection,options_a0,options_a1,options_a2,options_a3], dim=1) # [bs,5, 1024] 362 | else: 363 | options_a4 = self.branch_encoder.encode_text(clip_text_tokenizer(samples["options_a4"]).to(image_embeds.device)) # [bs,1024] 364 | question_embeds_for_selection = torch.stack([question_embeds_for_selection,options_a0,options_a1,options_a2,options_a3,options_a4], dim=1) # [bs,6, 1024] 365 | 366 | def l2_norm(x): 367 | return x/x.norm(dim=-1, keepdim=True) 368 | 369 | label_probs = torch.einsum("b t d, b d -> b t", l2_norm(image_embeds_for_selection), l2_norm(label_embeds_for_selection)) 370 | image_embeds = rearrange(image_embeds, "(b t) n d -> b t n d", t=framecount) # [bs, frame_count, 257, 1408] 371 | image_embeds, regression_loss, infoNCE_loss = self.grounding(Q=question_embeds_for_selection, K=image_embeds_for_selection, V=image_embeds, answer_embeds=label_embeds_for_selection, label_probs=label_probs, answers_id=samples["answers_id"].to(self.device) if 'answers_id' in samples.keys() else None) # [bs, 4, 257, 1408] 372 | 373 | image_embeds = rearrange(image_embeds, "b w n d -> (b w) n d") # [bs*4, 257, 1408] 374 | query_tokens = self.query_tokens.expand(bs*self.window_size, -1, -1) # [bs*frame_count, 32, 768] 375 | text_Qformer = self.tokenizer( 376 | samples["text_input"], 377 | padding='longest', 378 | truncation=True, 379 | max_length=self.max_input_txt_len, 380 | return_tensors="pt", 381 | ).to(image_embeds.device) 382 | query_atts = torch.ones((bs, self.query_tokens.shape[1]), dtype=torch.long).to(image_embeds.device) # [bs, 32] 383 | Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1) # [bs, 32+seq_len] 384 | 385 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) 386 | with self.maybe_autocast(): 387 | query_output = self.Qformer.bert( 388 | text_Qformer.input_ids.repeat(self.window_size, 1), 389 | attention_mask = Qformer_atts.repeat(self.window_size, 1), 390 | query_embeds=query_tokens, 391 | encoder_hidden_states=image_embeds, 392 | encoder_attention_mask=image_atts, 393 | return_dict=True, 394 | ) 395 | query_tokens = query_output.last_hidden_state[:,:query_tokens.size(1),:] # [bs*frame_count, 32, 768] 396 | query_tokens = rearrange(query_tokens, "(b w) n d -> b w n d", w=self.window_size) # [bs, frame_count, 32, 768] 397 | 398 | position_ids = torch.arange(self.window_size, dtype=torch.long, device=query_tokens.device) 399 | position_ids = position_ids.unsqueeze(0).expand(bs, -1) 400 | frame_embedding = self.frame_embeds(position_ids) 401 | frame_embedding = frame_embedding.unsqueeze(-2) 402 | query_tokens = query_tokens + frame_embedding # [bs, frame_count, 32, 768] 403 | video_query_tokens = rearrange(query_tokens, "b w n d 
-> b (w n) d") # [bs, 4*32, 768] 404 | 405 | if self.use_spatial: 406 | video_query_tokens = self.spatial_augmented(spatial_image_embeds, video_query_tokens, Qformer_atts, text_Qformer) 407 | inputs_llm = self.t5_proj(video_query_tokens) 408 | atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(inputs_llm.device) 409 | 410 | return inputs_llm, atts_llm, regression_loss, infoNCE_loss 411 | 412 | def forward(self, samples): 413 | # print('-----------------') 414 | # print(samples["text_input"]) 415 | # print(samples["text_output"]) 416 | # print('-----------------') 417 | 418 | if self.mode == 'grounding': 419 | inputs_t5, atts_t5, regression_loss, infoNCE_loss = self.grounding_concat(samples) 420 | elif self.mode == 'uniform': 421 | inputs_t5, atts_t5, regression_loss, infoNCE_loss = self.uniform_concat(samples) 422 | elif self.mode == 'oracle': 423 | inputs_t5, atts_t5, regression_loss, infoNCE_loss = self.oracle_concat(samples) 424 | 425 | with self.maybe_autocast(): 426 | input_tokens = self.t5_tokenizer( 427 | samples["text_input"], 428 | padding="longest", 429 | truncation=True, 430 | max_length=self.max_input_txt_len, 431 | return_tensors="pt", 432 | ).to(inputs_t5.device) 433 | output_tokens = self.t5_output_tokenizer( 434 | samples["text_output"], 435 | padding="longest", 436 | truncation=True, 437 | max_length=self.max_output_txt_len, 438 | return_tensors="pt", 439 | ).to(inputs_t5.device) 440 | 441 | encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) 442 | inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) 443 | inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) 444 | 445 | targets = output_tokens.input_ids.masked_fill( 446 | output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 447 | ) 448 | outputs = self.t5_model( 449 | inputs_embeds=inputs_embeds, 450 | attention_mask=encoder_atts, 451 | decoder_attention_mask=output_tokens.attention_mask, 452 | return_dict=True, 453 | labels=targets, 454 | ) 455 | vqa_loss = outputs.loss 456 | 457 | return { 458 | "loss": vqa_loss+regression_loss+infoNCE_loss, 459 | "vqa_loss": vqa_loss, 460 | "regression_loss": regression_loss, 461 | "infoNCE_loss": infoNCE_loss, 462 | } 463 | 464 | @torch.no_grad() 465 | def generate( 466 | self, 467 | samples, 468 | **generate_kwargs 469 | ): 470 | 471 | if self.mode == 'grounding': 472 | inputs_t5, atts_t5, _, _ = self.grounding_concat(samples) 473 | elif self.mode == 'uniform': 474 | inputs_t5, atts_t5, _, _ = self.uniform_concat(samples) 475 | elif self.mode == 'oracle': 476 | inputs_t5, atts_t5, _, _ = self.oracle_concat(samples) 477 | 478 | with self.maybe_autocast(): 479 | input_tokens = self.t5_tokenizer( 480 | samples["text_input"], 481 | padding="longest", 482 | return_tensors="pt" 483 | ).to(inputs_t5.device) 484 | 485 | encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) 486 | inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) 487 | inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) 488 | 489 | outputs = self.t5_model.generate( 490 | inputs_embeds=inputs_embeds, 491 | attention_mask=encoder_atts, 492 | **generate_kwargs 493 | ) 494 | output_text = self.t5_tokenizer.batch_decode( 495 | outputs, skip_special_tokens=True 496 | ) 497 | 498 | return output_text 499 | 500 | --------------------------------------------------------------------------------
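Usage sketch (editorial addition, not part of the repository): a minimal, untested example of how the `Blip2T5Instruct` model defined in models/blip2_t5_instruct.py might be loaded and queried in its 'uniform' frame-sampling mode. It assumes the checkpoint files referenced in `__init__` (the `experiments/*.pth` weights) are present, that the Hugging Face models it pulls (`Salesforce/blip2-flan-t5-xl` config, `google/flan-t5-xl`, `google-bert/bert-base-uncased`) can be downloaded, that a bfloat16-capable CUDA GPU is available (so `maybe_autocast()` is active), and that EVA-ViT frame features of shape [bs, frame_num, 257, 1408] were precomputed (hence `use_vit=False`; cf. utils/extract_features.py). The question string and generation arguments below are illustrative only.

import torch
from models.blip2_t5_instruct import Blip2T5Instruct

# 'uniform' mode avoids the grounding branch, so only the Q-Former/T5 checkpoints are needed.
model = Blip2T5Instruct(frame_num=32, window_size=4, mode='uniform', model='t5-xl', use_vit=False)
model = model.to('cuda').eval()  # on GPU, maybe_autocast() runs the forward pass in bfloat16

samples = {
    # placeholder for precomputed EVA-ViT features of one 32-frame video: [bs, frame_num, 257, 1408]
    "frame_features": torch.randn(1, 32, 257, 1408, device='cuda'),
    "text_input": ["Question: what is the person doing in the video? Answer:"],
}

with torch.no_grad():
    answers = model.generate(samples, num_beams=5, max_new_tokens=10)
print(answers)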