├── vebench ├── blip_models │ ├── __init__.py │ ├── blip_itm.py │ ├── blip_nlvr.py │ ├── blip_vqa.py │ ├── blip.py │ ├── blip_retrieval.py │ ├── blip_pretrain.py │ └── vit.py ├── models │ ├── backbone │ │ ├── __init__.py │ │ ├── BLIP_configs │ │ │ ├── retrieval_msrvtt.yaml │ │ │ ├── nocaps.yaml │ │ │ ├── nlvr.yaml │ │ │ ├── med_config.json │ │ │ ├── bert_config.json │ │ │ ├── pretrain.yaml │ │ │ ├── vqa.yaml │ │ │ ├── caption_coco.yaml │ │ │ ├── retrieval_coco.yaml │ │ │ └── retrieval_flickr.yaml │ │ ├── blip.py │ │ ├── uniformer_backbone.py │ │ └── conv_backbone.py │ ├── __init__.py │ ├── network.py │ ├── tradition.py │ ├── text_alignment.py │ ├── fidelity.py │ └── head.py ├── __init__.py ├── README.md ├── configs │ ├── text.yaml │ ├── doublestream.yaml │ └── dover.yaml ├── infer.py ├── evaluator.py └── preprocess.py ├── MANIFEST.in ├── assets ├── .DS_Store ├── dst.mp4 ├── dst2.mp4 ├── src.mp4 ├── scores.jpg └── overview.jpg ├── requirements.txt ├── .gitignore ├── test.py ├── setup.py ├── LICENSE └── README.md /vebench/blip_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vebench/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vebench/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import VEBenchModel -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include config/*.yaml 2 | include models/backbone/BLIP_configs 3 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/.DS_Store -------------------------------------------------------------------------------- /assets/dst.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/dst.mp4 -------------------------------------------------------------------------------- /assets/dst2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/dst2.mp4 -------------------------------------------------------------------------------- /assets/src.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/src.mp4 -------------------------------------------------------------------------------- /assets/scores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/scores.jpg -------------------------------------------------------------------------------- /assets/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/overview.jpg -------------------------------------------------------------------------------- /vebench/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .network import EvalEditModel 2 | 3 | __all__=['EvalEditModel'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord 2 | einops 3 | fairscale 4 | numpy 5 | timm 6 | transformers 7 | scikit-video 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ckpts 2 | __pycache__ 3 | data 4 | */__pycache__ 5 | */*/__pycache__ 6 | */*/*/__pycache__ 7 | build 8 | dist 9 | *egg-info 10 | .DS_Store 11 | */.DS_Store 12 | */*/.DS_Store -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from vebench import VEBenchModel 2 | 3 | evaluator = VEBenchModel() 4 | 5 | score1 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst.mp4') 6 | score2 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst2.mp4') 7 | print(score1, score2) -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/retrieval_msrvtt.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 6 | 7 | # size of vit model; base or large 8 | vit: 'base' 9 | batch_size: 64 10 | k_test: 128 11 | image_size: 384 12 | num_frm_test: 8 -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/nocaps.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/nocaps/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 6 | 7 | vit: 'base' 8 | batch_size: 32 9 | 10 | image_size: 384 11 | 12 | max_length: 20 13 | min_length: 5 14 | num_beams: 3 15 | prompt: 'a picture of ' -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/nlvr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/NLVR2/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth' 6 | 7 | #size of vit model; base or large 8 | vit: 'base' 9 | batch_size_train: 16 10 | batch_size_test: 64 11 | vit_grad_ckpt: False 12 | vit_ckpt_layer: 0 13 | max_epoch: 15 14 | 15 | image_size: 384 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-5 20 | min_lr: 0 21 | 22 | -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | 
"attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json', 2 | '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json', 3 | ] 4 | laion_path: '' 5 | 6 | # size of vit model; base or large 7 | vit: 'base' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 224 12 | batch_size: 75 13 | 14 | queue_size: 57600 15 | alpha: 0.4 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-4 20 | min_lr: 1e-6 21 | warmup_lr: 1e-6 22 | lr_decay_rate: 0.9 23 | max_epoch: 20 24 | warmup_steps: 3000 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /vebench/README.md: -------------------------------------------------------------------------------- 1 | ## Easy Use 2 | VE-Bench can be installed with a single ``pip`` command. Since the model employs normalization during training, its output does not represent absolute scores. 
We **recommend performing comparisons between video pairs**, as demonstrated below: 3 | ``` 4 | pip install vebench 5 | ``` 6 | When comparing videos: 7 | ``` 8 | from vebench import VEBenchModel 9 | 10 | evaluator = VEBenchModel() 11 | 12 | score1 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst.mp4') 13 | score2 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst2.mp4') 14 | print(score1, score2) # Score1: 1.3563, Score2: 0.66194 15 | ``` -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/ 2 | vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/ 3 | train_files: ['vqa_train','vqa_val','vg_qa'] 4 | ann_root: 'annotation' 5 | 6 | # set pretrained as a file path or an url 7 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 8 | 9 | # size of vit model; base or large 10 | vit: 'base' 11 | batch_size_train: 16 12 | batch_size_test: 32 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | init_lr: 2e-5 16 | 17 | image_size: 480 18 | 19 | k_test: 128 20 | inference: 'rank' 21 | 22 | # optimizer 23 | weight_decay: 0.05 24 | min_lr: 0 25 | max_epoch: 10 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | name = 'vebench' 4 | 5 | long_description='Please refer to https://github.com/littlespray/VE-Bench' 6 | 7 | setup( 8 | name=name, # 包名同工程名,这样导入包的时候更有对应性 9 | version='1.0.0', 10 | author="Shangkun Sun", 11 | license="MIT Licence", 12 | author_email='sunshk@stu.pku.edu.cn', 13 | description="Evaluator for Text-driven Video Editing", 14 | packages=find_packages(), 15 | python_requires='>=3', 16 | long_description=long_description, 17 | # 设置依赖包 18 | install_requires=['torch', 'decord', 'einops', 'fairscale', 'numpy', 'timm', 'transformers', 'sk-video'], 19 | include_package_data=True, # 包含额外的非Python文件 20 | package_data={ 21 | '': ['configs/*.yaml', 'models/backbone/BLIP_configs/*'], # 匹配目录下的所有文件 22 | }, 23 | ) 24 | -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | coco_gt_root: 'annotation/coco_gt' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 7 | 8 | # size of vit model; base or large 9 | vit: 'base' 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | batch_size: 32 13 | init_lr: 1e-5 14 | 15 | # vit: 'large' 16 | # vit_grad_ckpt: True 17 | # vit_ckpt_layer: 5 18 | # batch_size: 16 19 | # init_lr: 2e-6 20 | 21 | image_size: 384 22 | 23 | # generation configs 24 | max_length: 20 25 | min_length: 5 26 | num_beams: 3 27 | prompt: 'a picture of ' 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 5 33 | 34 | -------------------------------------------------------------------------------- 
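The ``vebench/README.md`` above recommends comparing edited videos in pairs rather than reading single scores as absolute values. Below is a minimal sketch, not part of the repository, of ranking several edited candidates of the same source clip with the published ``VEBenchModel.evaluate`` API; the candidate paths reuse the sample clips under ``assets/``, any set of edited outputs of one source video works the same way, and a CUDA-capable GPU is required since the model is moved to ``cuda``.

```python
from vebench import VEBenchModel

evaluator = VEBenchModel()

prompt = "A black-haired boy is turning his head"
# Candidate edits of the same source clip; here the two sample outputs shipped in assets/.
candidates = ["assets/dst.mp4", "assets/dst2.mp4"]

# Only the relative order of the scores is meaningful (the model normalizes during training);
# we assume higher means better, in line with the MOS-style scores it was trained on.
scores = {path: evaluator.evaluate(prompt, "assets/src.mp4", path) for path in candidates}
best = max(scores, key=scores.get)
print(scores, "->", best)
```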
/vebench/models/backbone/BLIP_configs/retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | dataset: 'coco' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 12 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 256 28 | negative_all_rank: True 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /vebench/models/backbone/BLIP_configs/retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/flickr30k/' 2 | ann_root: 'annotation' 3 | dataset: 'flickr' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 10 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 128 28 | negative_all_rank: False 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 sunshk1227 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /vebench/configs/text.yaml: -------------------------------------------------------------------------------- 1 | name: e-bench-blip-test 2 | num_epochs: 0 3 | l_num_epochs: 20 4 | warmup_epochs: 2.5 5 | ema: true 6 | save_model: true 7 | batch_size: 8 8 | num_workers: 6 9 | split_seed: 42 #useless 10 | 11 | wandb: 12 | project_name: e-bench-blip-test 13 | 14 | data: 15 | videoQA: 16 | type: ViewDecompositionDataset 17 | args: 18 | weight: 0.443 19 | phase: train 20 | anno_file: ../e-bench-db/label.txt 21 | data_prefix: ../e-bench-db/edited 22 | sample_types: 23 | time: 24 | size_h: 224 25 | size_w: 224 26 | clip_len: 16 27 | frame_interval: 1 28 | t_frag: 16 29 | num_clips: 1 30 | 31 | model: 32 | type: VideoTextAlignmentModel 33 | args: 34 | backbone: 35 | time: 36 | #in_channels: 768 37 | type: blip 38 | pretrained: true 39 | checkpoint: true 40 | blip_type: multimodal_text 41 | 42 | backbone_preserve_keys: time 43 | divide_head: true 44 | use_tn: true 45 | vqa_head: 46 | #in_channels: 768 47 | hidden_channels: 64 48 | attn_pool3d: true # 代码里默认为false 49 | text_pool3d: false 50 | 51 | optimizer: 52 | lr: !!float 6.25e-4 53 | backbone_lr_mult: !!float 1e-1 54 | wd: 0.05 55 | 56 | test_load_path: [./ckpts/e-bench-blip_head_videoQA_9_eval_s_finetuned.pth] -------------------------------------------------------------------------------- /vebench/configs/doublestream.yaml: -------------------------------------------------------------------------------- 1 | name: e-bench-uniformer-src-edit-test 2 | num_epochs: 10 3 | l_num_epochs: 20 4 | warmup_epochs: 2.5 5 | ema: true 6 | save_model: true 7 | batch_size: 8 8 | num_workers: 6 9 | split_seed: 42 #useless 10 | 11 | wandb: 12 | project_name: e-bench-uniformer-src-edit-test 13 | 14 | data: 15 | videoQA: 16 | type: ViewDecompositionDataset 17 | args: 18 | weight: 0.443 19 | phase: train 20 | anno_file: ../e-bench-db/label.txt 21 | data_prefix: ../e-bench-db/src/ 22 | sample_types: 23 | technical: 24 | fragments_h: 7 25 | fragments_w: 7 26 | fsize_h: 32 27 | fsize_w: 32 28 | aligned: 32 29 | clip_len: 32 30 | frame_interval: 1 31 | num_clips: 1 32 | aesthetic: 33 | size_h: 224 34 | size_w: 224 35 | clip_len: 32 36 | frame_interval: 1 37 | t_frag: 32 38 | num_clips: 1 39 | 40 | model: 41 | type: DoubleStreamModel 42 | args: 43 | backbone: 44 | technical: 45 | type: uniformerv2_b16 46 | checkpoint: true 47 | pretrained: 48 | aesthetic: 49 | type: uniformerv2_b16 50 | pretrained: true # 代码里默认为True 51 | in22k: false # 代码里默认为False 52 | 53 | backbone_preserve_keys: technical,aesthetic 54 | divide_head: true 55 | use_tn: true 56 | vqa_head: 57 | #in_channels: 768 58 | hidden_channels: 64 59 | attn_pool3d: true # 代码里默认为false 60 | text_pool3d: false 61 | 62 | optimizer: 63 | lr: !!float 6.25e-4 64 | backbone_lr_mult: !!float 1e-1 65 | wd: 0.05 66 | 67 | test_load_path: [./ckpts/e-bench-uniformer-src-edit_head_videoQA_3_eval_s_finetuned.pth] -------------------------------------------------------------------------------- /vebench/models/network.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | 7 | from .tradition import DOVER 8 | from .fidelity import DoubleStreamModel 9 | from .text_alignment import VideoTextAlignmentModel 10 | 11 | from huggingface_hub import snapshot_download 12 | 13 | class EvalEditModel(nn.Module): 14 | def __init__(self, dover_opt, 
doublestream_opt, text_opt, model_path='ckpts'): 15 | super().__init__() 16 | 17 | if not os.path.isdir(model_path): 18 | model_path = snapshot_download('sunshk/vebench') 19 | 20 | # build model 21 | self.traditional_branch = DOVER(**dover_opt['model']['args'],model_path=model_path).eval() 22 | self.fidelity_branch = DoubleStreamModel(**doublestream_opt['model']['args'], model_path=model_path).eval() 23 | self.text_branch = VideoTextAlignmentModel(**text_opt['model']['args'], model_path=model_path).eval() 24 | 25 | # load_weight 26 | self.load_ckpt(model_path) 27 | 28 | 29 | def load_ckpt(self, model_path): 30 | # print('111') 31 | self.traditional_branch.load_state_dict(torch.load(os.path.join(model_path, 'e-bench-dover_head_videoQA_0_eval_n_finetuned.pth'), map_location='cpu')['state_dict']) 32 | self.fidelity_branch.load_state_dict(torch.load(os.path.join(model_path, 'e-bench-uniformer-src-edit_head_videoQA_3_eval_s_finetuned.pth'),map_location='cpu')['state_dict'],strict=False) 33 | self.text_branch.load_state_dict(torch.load(os.path.join(model_path, 'e-bench-blip_head_videoQA_9_eval_s_finetuned.pth'), map_location='cpu')['state_dict'],strict=False) 34 | 35 | def forward(self, src_video, edit_video, prompt): 36 | traditional_score = self.traditional_branch(edit_video,reduce_scores=True) 37 | fidelity_score = self.fidelity_branch(src_video, edit_video) 38 | text_score = self.text_branch(edit_video,prompts=prompt) 39 | # the weight of each score is pre-computed within each branch 40 | return (traditional_score + fidelity_score[0] + text_score[0]).item() 41 | 42 | 43 | 44 | if __name__ == "__main__": 45 | eval_model=EvalEditModel() 46 | -------------------------------------------------------------------------------- /vebench/configs/dover.yaml: -------------------------------------------------------------------------------- 1 | name: e-bench-dover-test 2 | num_epochs: 20 3 | l_num_epochs: 40 4 | warmup_epochs: 2.5 5 | ema: true 6 | save_model: true 7 | batch_size: 8 8 | num_workers: 6 9 | split_seed: 42 #useless 10 | 11 | wandb: 12 | project_name: e-bench-dover-test 13 | 14 | data: 15 | videoQA: 16 | type: ViewDecompositionDataset 17 | args: 18 | weight: 0.443 19 | phase: train 20 | anno_file: ../e-bench-db/label.txt 21 | data_prefix: ../e-bench-db/edited 22 | sample_types: 23 | technical: 24 | fragments_h: 7 25 | fragments_w: 7 26 | fsize_h: 32 27 | fsize_w: 32 28 | aligned: 32 29 | clip_len: 32 30 | frame_interval: 1 31 | num_clips: 1 32 | aesthetic: 33 | size_h: 224 34 | size_w: 224 35 | clip_len: 32 36 | frame_interval: 1 37 | t_frag: 32 38 | num_clips: 1 39 | # time: 40 | # size_h: 224 41 | # size_w: 224 42 | # clip_len: 16 43 | # frame_interval: 1 44 | # t_frag: 16 45 | # num_clips: 1 46 | 47 | model: 48 | type: DOVER 49 | args: 50 | backbone: 51 | technical: 52 | type: swin_tiny_grpb 53 | checkpoint: true 54 | pretrained: 55 | aesthetic: 56 | type: conv_tiny 57 | pretrained: true # 代码里默认为True 58 | in22k: false # 代码里默认为False 59 | 60 | backbone_preserve_keys: technical,aesthetic 61 | divide_head: true 62 | vqa_head: 63 | #in_channels: 768 64 | hidden_channels: 64 65 | attn_pool3d: true # 代码里默认为false 66 | text_pool3d: false 67 | 68 | optimizer: 69 | lr: !!float 6.25e-4 70 | backbone_lr_mult: !!float 1e-1 71 | wd: 0.05 72 | 73 | test_load_path: [./ckpts/e-bench-dover_head_videoQA_0_eval_n_finetuned.pth] -------------------------------------------------------------------------------- /vebench/blip_models/blip_itm.py: 
-------------------------------------------------------------------------------- 1 | from .med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from .blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | ): 19 | """ 20 | Args: 21 | med_config (str): path for the mixture of encoder-decoder model's configuration file 22 | image_size (int): input image size 23 | vit (str): model size of vision transformer 24 | """ 25 | super().__init__() 26 | 27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 28 | self.tokenizer = init_tokenizer() 29 | med_config = BertConfig.from_json_file(med_config) 30 | med_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 32 | 33 | text_width = self.text_encoder.config.hidden_size 34 | 35 | self.vision_proj = nn.Linear(vision_width, embed_dim) 36 | self.text_proj = nn.Linear(text_width, embed_dim) 37 | 38 | self.itm_head = nn.Linear(text_width, 2) 39 | 40 | 41 | def forward(self, image, caption, match_head='itm'): 42 | 43 | image_embeds = self.visual_encoder(image) 44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 45 | 46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 47 | return_tensors="pt").to(image.device) 48 | 49 | 50 | if match_head=='itm': 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = image_embeds, 54 | encoder_attention_mask = image_atts, 55 | return_dict = True, 56 | ) 57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 58 | return itm_output 59 | 60 | elif match_head=='itc': 61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 62 | return_dict = True, mode = 'text') 63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 65 | 66 | sim = image_feat @ text_feat.t() 67 | return sim 68 | 69 | 70 | def blip_itm(pretrained='',**kwargs): 71 | model = BLIP_ITM(**kwargs) 72 | if pretrained: 73 | model,msg = load_checkpoint(model,pretrained) 74 | assert(len(msg.missing_keys)==0) 75 | return model 76 | -------------------------------------------------------------------------------- /vebench/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models import EvalEditModel 5 | from preprocess import Processor 6 | import yaml 7 | import argparse 8 | 9 | #fixed seed 10 | seed_n = 42 11 | print('seed is ' + str(seed_n)) 12 | torch.manual_seed(seed_n) 13 | 14 | device='cuda' 15 | class EBenchModel(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | dover_config = 'configs/dover.yaml' 19 | doublestream_config = 'configs/doublestream.yaml' 20 | text_config = 'configs/text.yaml' 21 | 22 | with open(dover_config, "r") as f: 23 | dover_opt = yaml.safe_load(f) 24 | with open(doublestream_config, "r") as f: 25 | doublestream_opt = yaml.safe_load(f) 26 | with open(text_config, "r") as f: 27 | text_opt = yaml.safe_load(f) 28 | 
self.model = EvalEditModel(dover_opt, doublestream_opt, text_opt).cuda()  # the three branch configs are required (cf. vebench/evaluator.py)
29 |         self.traditional_processor=Processor(dover_opt['data']['videoQA']['args'])
30 |         self.text_pocessor=Processor(text_opt['data']['videoQA']['args'])
31 |         self.doublestream_processor=Processor(doublestream_opt['data']['videoQA']['args'])
32 | 
33 | 
34 |     def read_data(self, path):
35 |         traditional_data=self.traditional_processor.preprocess(path)
36 |         text_data=self.text_pocessor.preprocess(path)
37 |         doublestream_data = self.doublestream_processor.preprocess(path)
38 |         data={}
39 |         for branch_data in [traditional_data,text_data,doublestream_data]:
40 |             for key in branch_data.keys():
41 |                 data[key]=branch_data[key]
42 |         return data
43 | 
44 | 
45 |     @torch.no_grad()
46 |     def evaluate(self, prompt, src_path, dst_path):
47 |         src_video = self.read_data(src_path)
48 |         dst_video = self.read_data(dst_path)
49 |         result = self.model(src_video, dst_video, prompt)
50 |         return result
51 | 
52 | if __name__ == "__main__":
53 | 
54 |     parser = argparse.ArgumentParser(description='Process video files with EBenchModel.')
55 | 
56 | 
57 |     parser.add_argument('--single_test', action='store_true', help='Run a single test with specified paths and prompt.')
58 |     parser.add_argument('--src_path', type=str, help='Source video path for single test.')
59 |     parser.add_argument('--dst_path', type=str, help='Destination video path for single test.')
60 |     parser.add_argument('--prompt', type=str, help='Prompt for single test.')
61 |     parser.add_argument('--data_path', type=str, help='Data path for batch processing.')
62 |     parser.add_argument('--label_path', type=str, help='Label path for batch processing.')
63 | 
64 | 
65 |     args = parser.parse_args()
66 | 
67 | 
68 |     if args.single_test:
69 |         if args.src_path and args.dst_path and args.prompt:
70 |             src_path = args.src_path
71 |             dst_path = args.dst_path
72 |             prompt = args.prompt
73 |             ebench = EBenchModel()
74 |             result = ebench.evaluate(prompt, src_path, dst_path)
75 |             print(f"The result is {result}")
76 |         else:
77 |             print("Error: For single test, --src_path, --dst_path, and --prompt must be provided.")
78 |     else:
79 |         if args.data_path and args.label_path:
80 |             data_path = args.data_path
81 |             label_path = args.label_path
82 |             src=[]
83 |             dst=[]
84 |             prompts=[]
85 |             with open(label_path,'r') as file:
86 |                 for line in file:
87 |                     video_name,_,prompt=line.split('|')
88 |                     src+=[data_path+"src/"+video_name]
89 |                     dst += [data_path + "edited/" + video_name]
90 |                     prompts+=[prompt]
91 |             ebench = EBenchModel()
92 |             results=[]
93 |             for src_path,dst_path,prompt in zip(src,dst,prompts):
94 |                 result = ebench.evaluate(prompt, src_path, dst_path)
95 |                 results+=[result]
96 |             print(len(results))
97 |             with open("label.txt","w") as file:
98 |                 for src_path,result in zip(src,results):
99 |                     file.write(f"{src_path.split('/')[-1]},{result}\n")
100 |         else:
101 |             print("Error: For batch test, --data_path, --label_path must be provided.")
102 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [\[AAAI 25\] VE-Bench: Subjective-Aligned Benchmark Suite for Text-Driven Video Editing Quality Assessment](https://arxiv.org/abs/2408.11481)
2 | 
3 | 
4 | Shangkun Sun, Xiaoyu Liang, Songlin Fan, Wenxu Gao, Wei Gao*
5 | 6 | (* Corresponding author)
7 | 8 | from MMCAL, Peking University 9 |
10 | 11 | 14 | 15 | ## 🎦 Introduction 16 | TL;DR: VE-Bench is an evaluation suite for text-driven video editing, consisting of a quality assessment model to provide a human-aligned metric for edited videos, and a database containing rich video-prompt pairs and the corresponding human scores. 17 | 18 |
19 | <img src="assets/overview.jpg">
21 | Overview of the VE-Bench Suite 22 |
23 | 24 | VE-Bench DB contains a rich collection of source videos, including real-world videos, AIGC videos, and CG videos, covering various aspects such as people, objects, animals, and landscapes. It also includes a variety of editing instructions across different categories, including semantic editing like addition, removal, replacement, etc., as well as structural changes in size, shape, etc., and stylizations such as color, texture, etc. Additionally, it features editing results based on different video editing models. We conducted a subjective experiment involving 24 participants from diverse backgrounds, resulting in 28,080 score samples. We further trained VE-Bench QA model based on this data. The left image below shows the box plot of average scores obtained by each model during the subjective experiment, while the right image illustrates the scores for each model across different types of prompts. 25 | 26 |
27 | 28 |
29 | Left: Average score distributions of 8 editing methods.     Right: Performance on different types of prompts from previous video-editing methods. 30 |
31 | 32 | ## Easy Use 33 | VE-Bench can be installed with a single ``pip`` command. 34 | ``` 35 | pip install vebench 36 | ``` 37 | When comparing videos, you can use ``python test.py``, namely: 38 | ``` 39 | from vebench import VEBenchModel 40 | 41 | evaluator = VEBenchModel() 42 | 43 | score1 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst.mp4') 44 | score2 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst2.mp4') 45 | print(score1, score2) # Score1: 1.3563, Score2: 0.66194 46 | ``` 47 | Since the model employs normalization during training, its output does not represent exactly absolute 1 \~ 10 scores, as demonstrated above. 48 | 49 | ## Database 50 | VE-Bench DB is available here. [baidu netdisk](https://pan.baidu.com/s/1D5y6ADXgz8PPHGCxROlNIQ?pwd=sggc) | [google drive](https://drive.google.com/file/d/1SBmXK6XKuyGTaV9LUQXfy5w82bsA3Nve/view?usp=sharing) 51 | 52 | 53 | ## Local Inference 54 | 55 | ### 💼 Preparation 56 | `` 57 | cd vebench 58 | `` 59 | 60 | You can also download all checkpoints from [google drive](https://drive.google.com/drive/folders/1kD82Ex90VP9A_AqjYV1J5DYvBQW-hkXa?usp=sharing) and put them into ``ckpts``. 61 | 62 | ### ✨ Usage 63 | To evaluate one single video: 64 | ``` 65 | python -m infer.py --single_test --src_path ${path_to_source_video} --dst_path ${path_to_dst_video} --prompt ${editing_prompt} 66 | 67 | # Run on example videos 68 | # python -m infer.py --single_test --src_path "./data/src/00433tokenflow_baby_gaze.mp4" --dst_path "./data/edited/00433tokenflow_baby_gaze.mp4" --prompt "A black-haired boy is turning his head" 69 | ``` 70 | 71 | 72 | To evaluate a set of videos: 73 | ``` 74 | python -m infer.py --data_path ${path_to_data_folder} --label_path ${path_to_prompt_txt_file} 75 | ``` 76 | 77 | ## 🙏 Acknowledgements 78 | Part of the code is developed based on [DOVER](https://github.com/VQAssessment/DOVER) and [BLIP](https://github.com/salesforce/BLIP). We would like to thank the authors for their contributions to the community. 79 | 80 | 81 | ## 📭 Contact 82 | If your have any comments or questions, feel free to contact [sunshk@stu.pku.edu.cn](lsunshk@stu.pku.edu.cn). 
83 | 84 | 85 | 86 | ## 📖 BibTex 87 | ```bibtex 88 | @article{sun2024bench, 89 | title={VE-Bench: Subjective-Aligned Benchmark Suite for Text-Driven Video Editing Quality Assessment}, 90 | author={Sun, Shangkun and Liang, Xiaoyu and Fan, Songlin and Gao, Wenxu and Gao, Wei}, 91 | journal={arXiv preprint arXiv:2408.11481}, 92 | year={2024} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /vebench/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | from .models import EvalEditModel 6 | from .preprocess import Processor 7 | import yaml 8 | import argparse 9 | import random 10 | import numpy as np 11 | 12 | 13 | device='cuda' 14 | class VEBenchModel(nn.Module): 15 | def __init__(self, seed=42): 16 | super().__init__() 17 | 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | base_dir = os.path.dirname(os.path.abspath(__file__)) 25 | # 构造配置文件的绝对路径 26 | dover_config = os.path.join(base_dir, 'configs', 'dover.yaml') 27 | doublestream_config = os.path.join(base_dir, 'configs', 'doublestream.yaml') 28 | text_config = os.path.join(base_dir, 'configs', 'text.yaml') 29 | 30 | 31 | 32 | with open(dover_config, "r") as f: 33 | dover_opt = yaml.safe_load(f) 34 | with open(doublestream_config, "r") as f: 35 | doublestream_opt = yaml.safe_load(f) 36 | with open(text_config, "r") as f: 37 | text_opt = yaml.safe_load(f) 38 | self.model = EvalEditModel(dover_opt, doublestream_opt, text_opt).cuda() 39 | self.traditional_processor=Processor(dover_opt['data']['videoQA']['args']) 40 | self.text_pocessor=Processor(text_opt['data']['videoQA']['args']) 41 | self.doublestream_processor=Processor(doublestream_opt['data']['videoQA']['args']) 42 | 43 | 44 | def read_data(self, path): 45 | traditional_data=self.traditional_processor.preprocess(path) 46 | text_data=self.text_pocessor.preprocess(path) 47 | doublestream_data = self.doublestream_processor.preprocess(path) 48 | data={} 49 | for branch_data in[traditional_data,text_data,doublestream_data]: 50 | for key in branch_data.keys(): 51 | data[key]=branch_data[key] 52 | return data 53 | 54 | 55 | @torch.no_grad() 56 | def evaluate(self, prompt, src_path, dst_path): 57 | src_video = self.read_data(src_path) 58 | dst_video = self.read_data(dst_path) 59 | result = self.model(src_video, dst_video, prompt) 60 | return result 61 | 62 | if __name__ == "__main__": 63 | 64 | parser = argparse.ArgumentParser(description='Process video files with VEBenchModel.') 65 | 66 | 67 | parser.add_argument('--single_test', action='store_true', help='Run a single test with specified paths and prompt.') 68 | parser.add_argument('--src_path', type=str, help='Source video path for single test.') 69 | parser.add_argument('--dst_path', type=str, help='Destination video path for single test.') 70 | parser.add_argument('--prompt', type=str, help='Prompt for single test.') 71 | parser.add_argument('--data_path', type=str, help='Data path for batch processing.') 72 | parser.add_argument('--label_path', type=str, help='Label path for batch processing.') 73 | 74 | 75 | args = parser.parse_args() 76 | 77 | 78 | if args.single_test: 79 | if args.src_path and args.dst_path and args.prompt: 80 | src_path = args.src_path 81 | dst_path = args.dst_path 82 | prompt = args.prompt 83 | ebench = VEBenchModel() 84 | result = ebench.evaluate(prompt, src_path, 
dst_path) 85 | print(f"The result is {result}") 86 | else: 87 | print("Error: For single test, --src_path, --dst_path, and --prompt must be provided.") 88 | else: 89 | if args.data_path and args.label_path: 90 | data_path = args.data_path 91 | label_path = args.label_path 92 | src=[] 93 | dst=[] 94 | prompts=[] 95 | with open(label_path,'r') as file: 96 | for line in file: 97 | video_name,_,prompt=line.split('|') 98 | src+=[data_path+"src/"+video_name] 99 | dst += [data_path + "edited/" + video_name] 100 | prompts+=[prompt] 101 | ebench = EBenchModel() 102 | results=[] 103 | for src_path,dst_path,prompt in zip(src,dst,prompts): 104 | result = ebench.evaluate(prompt, src_path, dst_path) 105 | results+=[result] 106 | print(len(results)) 107 | with open("label.txt","w") as file: 108 | for src_path,result in zip(src,results): 109 | file.write(f"{src_path.split('/')[-1]},{result}\n") 110 | else: 111 | print("Error: For batch test, --data_path, --label_path must be provided.") 112 | -------------------------------------------------------------------------------- /vebench/blip_models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | from .med import BertConfig 2 | from .nlvr_encoder import BertModel 3 | from .vit import interpolate_pos_embed 4 | from .blip import create_vit, init_tokenizer, is_url 5 | 6 | from timm.models.hub import download_cached_file 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | class BLIP_NLVR(nn.Module): 15 | def __init__(self, 16 | med_config = 'configs/med_config.json', 17 | image_size = 480, 18 | vit = 'base', 19 | vit_grad_ckpt = False, 20 | vit_ckpt_layer = 0, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | self.cls_head = nn.Sequential( 37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 38 | nn.ReLU(), 39 | nn.Linear(self.text_encoder.config.hidden_size, 2) 40 | ) 41 | 42 | def forward(self, image, text, targets, train=True): 43 | 44 | image_embeds = self.visual_encoder(image) 45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 47 | 48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 49 | text.input_ids[:,0] = self.tokenizer.enc_token_id 50 | 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = [image0_embeds,image1_embeds], 54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 55 | image_atts[image0_embeds.size(0):]], 56 | return_dict = True, 57 | ) 58 | hidden_state = output.last_hidden_state[:,0,:] 59 | prediction = self.cls_head(hidden_state) 60 | 61 | if train: 62 | loss = F.cross_entropy(prediction, targets) 63 | return loss 64 | else: 65 | return prediction 66 | 67 | def 
blip_nlvr(pretrained='',**kwargs): 68 | model = BLIP_NLVR(**kwargs) 69 | if pretrained: 70 | model,msg = load_checkpoint(model,pretrained) 71 | print("missing keys:") 72 | print(msg.missing_keys) 73 | return model 74 | 75 | 76 | def load_checkpoint(model,url_or_filename): 77 | if is_url(url_or_filename): 78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 79 | checkpoint = torch.load(cached_file, map_location='cpu') 80 | elif os.path.isfile(url_or_filename): 81 | checkpoint = torch.load(url_or_filename, map_location='cpu') 82 | else: 83 | raise RuntimeError('checkpoint url or path is invalid') 84 | state_dict = checkpoint['model'] 85 | 86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 87 | 88 | for key in list(state_dict.keys()): 89 | if 'crossattention.self.' in key: 90 | new_key0 = key.replace('self','self0') 91 | new_key1 = key.replace('self','self1') 92 | state_dict[new_key0] = state_dict[key] 93 | state_dict[new_key1] = state_dict[key] 94 | elif 'crossattention.output.dense.' in key: 95 | new_key0 = key.replace('dense','dense0') 96 | new_key1 = key.replace('dense','dense1') 97 | state_dict[new_key0] = state_dict[key] 98 | state_dict[new_key1] = state_dict[key] 99 | 100 | msg = model.load_state_dict(state_dict,strict=False) 101 | # print('load checkpoint from %s'%url_or_filename) 102 | return model,msg 103 | -------------------------------------------------------------------------------- /vebench/models/tradition.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | from functools import partial, reduce 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | from .backbone.conv_backbone import convnext_3d_tiny 10 | from .head import VARHead, VQAHead,VQAHead_cls 11 | from .backbone.swin_backbone import SwinTransformer3D as VideoBackbone 12 | 13 | class DOVER(nn.Module): 14 | def __init__( 15 | self, 16 | backbone_size="divided", 17 | backbone_preserve_keys="technical,aesthetic", 18 | multi=False, 19 | layer=-1, 20 | backbone=dict( 21 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)} 22 | ), 23 | divide_head=True, 24 | vqa_head=dict(in_channels=768), 25 | var=False, 26 | model_path=None, 27 | ): 28 | self.backbone_preserve_keys = backbone_preserve_keys.split(",") 29 | self.multi = multi 30 | self.layer = layer 31 | super().__init__() 32 | for key, hypers in backbone.items(): 33 | if key not in self.backbone_preserve_keys: 34 | continue 35 | if backbone_size == "divided": 36 | t_backbone_size = hypers["type"] 37 | else: 38 | t_backbone_size = backbone_size 39 | if t_backbone_size == "swin_tiny_grpb": 40 | # to reproduce fast-vqa 41 | b = VideoBackbone() 42 | elif t_backbone_size == "conv_tiny": 43 | b = convnext_3d_tiny(pretrained=model_path) 44 | else: 45 | raise NotImplementedError 46 | setattr(self, key + "_backbone", b) 47 | if divide_head: 48 | for key in backbone: 49 | pre_pool = False #if key == "technical" else True 50 | if key not in self.backbone_preserve_keys: 51 | continue 52 | b = VQAHead_cls(pre_pool=pre_pool, **vqa_head) 53 | setattr(self, key + "_head", b) 54 | else: 55 | if var: 56 | self.vqa_head = VARHead(**vqa_head) 57 | else: 58 | self.vqa_head = VQAHead(**vqa_head) 59 | 60 | def forward( 61 | self, 62 | vclips, 63 | inference=True, 64 | return_pooled_feats=False, 65 | return_raw_feats=False, 66 | reduce_scores=False, 67 | pooled=False, 68 | **kwargs 69 | ): 70 | assert 
(return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return" 71 | if inference: 72 | self.eval() 73 | with torch.no_grad(): 74 | scores = [] 75 | feats = {} 76 | for key in self.backbone_preserve_keys: 77 | feat = getattr(self, key.split("_")[0] + "_backbone")( 78 | vclips[key], multi=self.multi, layer=self.layer, **kwargs 79 | ) 80 | if hasattr(self, key.split("_")[0] + "_head"): 81 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)] 82 | else: 83 | scores += [getattr(self, "vqa_head")(feat)] 84 | if return_pooled_feats: 85 | feats[key] = feat 86 | if return_raw_feats: 87 | feats[key] = feat 88 | if reduce_scores: 89 | if len(scores) > 1: 90 | scores = reduce(lambda x, y: x + y, scores) 91 | else: 92 | scores = scores[0] 93 | if pooled: 94 | scores = torch.mean(scores, (1, 2, 3, 4)) 95 | self.train() 96 | if return_pooled_feats or return_raw_feats: 97 | return scores, feats 98 | return scores 99 | else: 100 | self.train() 101 | scores = [] 102 | feats = {} 103 | for key in vclips: 104 | feat = getattr(self, key.split("_")[0] + "_backbone")( 105 | vclips[key], multi=self.multi, layer=self.layer, **kwargs 106 | ) 107 | if hasattr(self, key.split("_")[0] + "_head"): 108 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)] 109 | else: 110 | scores += [getattr(self, "vqa_head")(feat)] 111 | if return_pooled_feats: 112 | feats[key] = feat.mean((-3, -2, -1)) 113 | if reduce_scores: 114 | if len(scores) > 1: 115 | scores = reduce(lambda x, y: x + y, scores) 116 | else: 117 | scores = scores[0] 118 | if pooled: 119 | scores = torch.mean(scores, (1, 2, 3, 4)) 120 | 121 | if return_pooled_feats: 122 | return scores, feats 123 | return scores -------------------------------------------------------------------------------- /vebench/models/text_alignment.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .head import VQAHead_cls,VARHead,VQAHead 4 | from .backbone.blip import MyBLIP as BLIP 5 | 6 | class VideoTextAlignmentModel(nn.Module): 7 | def __init__( 8 | self, 9 | backbone_size="divided", 10 | backbone_preserve_keys="fragments,resize", 11 | multi=False, 12 | layer=-1, 13 | backbone=dict( 14 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)} 15 | ), 16 | divide_head=False, 17 | head_type='VQAhead_cls', 18 | vqa_head=dict(in_channels=768), 19 | var=False, 20 | use_tn=False, 21 | model_path=None, 22 | ): 23 | self.backbone_preserve_keys = backbone_preserve_keys.split(",") 24 | self.multi = multi 25 | self.layer = layer 26 | super().__init__() 27 | 28 | for key, hypers in backbone.items(): 29 | if key not in self.backbone_preserve_keys: 30 | continue 31 | if backbone_size == "divided": 32 | t_backbone_size = hypers["type"] 33 | else: 34 | t_backbone_size = backbone_size 35 | 36 | assert t_backbone_size == "blip" 37 | type = hypers["blip_type"] 38 | b = BLIP(type, model_path) 39 | 40 | setattr(self, key + "_backbone", b) 41 | if divide_head: 42 | for key in backbone: 43 | pre_pool = False # if key == "technical" else True 44 | if key not in self.backbone_preserve_keys: 45 | continue 46 | in_channel = 768 47 | b = VQAHead_cls(pre_pool=pre_pool, in_channels=in_channel, **vqa_head) 48 | setattr(self, key + "_head", b) 49 | else: 50 | if var: 51 | self.vqa_head = VARHead(**vqa_head) 52 | else: 53 | self.vqa_head = VQAHead(**vqa_head) 54 | 55 | def forward( 56 | self, 57 | vclips, 58 | prompts=None, 59 | inference=True, 60 | 
return_pooled_feats=False, 61 | return_raw_feats=False, 62 | reduce_scores=False, 63 | pooled=False, 64 | **kwargs 65 | ): 66 | # import pdb;pdb.set_trace() 67 | assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return" 68 | if inference: 69 | self.eval() 70 | with torch.no_grad(): 71 | scores = [] 72 | feats = {} 73 | for key in self.backbone_preserve_keys: 74 | feat = getattr(self, key.split("_")[0] + "_backbone")( 75 | vclips[key], prompts 76 | ) 77 | if hasattr(self, key.split("_")[0] + "_head"): 78 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)[0]] 79 | else: 80 | scores += [getattr(self, "vqa_head")(feat)] 81 | if return_pooled_feats: 82 | feats[key] = feat 83 | if return_raw_feats: 84 | feats[key] = feat 85 | if reduce_scores: 86 | if len(scores) > 1: 87 | scores = reduce(lambda x, y: x + y, scores) 88 | else: 89 | scores = scores[0] 90 | if pooled: 91 | scores = torch.mean(scores, (1, 2, 3, 4)) 92 | self.train() 93 | if return_pooled_feats or return_raw_feats: 94 | return scores, feats 95 | return scores 96 | else: 97 | self.train() 98 | scores = [] 99 | feats = {} 100 | 101 | for key in vclips: 102 | feat = getattr(self, key.split("_")[0] + "_backbone")( 103 | vclips[key], prompts 104 | ) 105 | if hasattr(self, key.split("_")[0] + "_head"): 106 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)[0]] 107 | else: 108 | scores += [getattr(self, "vqa_head")(feat)] 109 | if return_pooled_feats: 110 | feats[key] = feat.mean((-3, -2, -1)) 111 | if reduce_scores: 112 | if len(scores) > 1: 113 | scores = reduce(lambda x, y: x + y, scores) 114 | else: 115 | scores = scores[0] 116 | if pooled: 117 | # print(scores.shape) 118 | scores = torch.mean(scores, (1, 2, 3, 4)) 119 | # print(scores.shape) 120 | 121 | if return_pooled_feats: 122 | return scores, feats 123 | return scores 124 | -------------------------------------------------------------------------------- /vebench/models/fidelity.py: -------------------------------------------------------------------------------- 1 | from .backbone.uniformer_backbone import uniformerv2_b16 2 | from .head import VQAHead_cls,VARHead,VQAHead 3 | import torch.nn as nn 4 | import torch 5 | 6 | class DoubleStreamModel(nn.Module): 7 | def __init__( 8 | self, 9 | backbone_size="divided", 10 | backbone_preserve_keys="fragments,resize", 11 | multi=False, 12 | layer=-1, 13 | backbone=dict( 14 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)} 15 | ), 16 | divide_head=False, 17 | head_type='VQAhead_cls', 18 | vqa_head=dict(in_channels=768), 19 | var=False, 20 | use_tn=False, 21 | model_path=None, 22 | ): 23 | self.backbone_preserve_keys = backbone_preserve_keys.split(",") 24 | self.multi = multi 25 | self.layer = layer 26 | super().__init__() 27 | 28 | for key, hypers in backbone.items(): 29 | if key not in self.backbone_preserve_keys: 30 | continue 31 | if backbone_size == "divided": 32 | t_backbone_size = hypers["type"] 33 | else: 34 | t_backbone_size = backbone_size 35 | assert t_backbone_size == "uniformerv2_b16" 36 | b = uniformerv2_b16(pretrained=model_path, temporal_downsample=False, no_lmhra=True, t_size=32) 37 | setattr(self, key + "_backbone", b) 38 | if divide_head: 39 | for key in backbone: 40 | pre_pool = False # if key == "technical" else True 41 | if key not in self.backbone_preserve_keys: 42 | continue 43 | in_channel = 1536 44 | b = VQAHead_cls(pre_pool=pre_pool, in_channels=in_channel, **vqa_head) 45 | setattr(self, key + "_head", b) 46 | else: 47 
| if var: 48 | self.vqa_head = VARHead(**vqa_head) 49 | else: 50 | self.vqa_head = VQAHead(**vqa_head) 51 | 52 | def forward( 53 | self, 54 | vclips, 55 | prompts=None, 56 | inference=True, 57 | return_pooled_feats=False, 58 | return_raw_feats=False, 59 | reduce_scores=False, 60 | pooled=False, 61 | **kwargs 62 | ): 63 | # import pdb;pdb.set_trace() 64 | assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return" 65 | if inference: 66 | self.eval() 67 | with torch.no_grad(): 68 | scores = [] 69 | feats = [] 70 | for key in self.backbone_preserve_keys: 71 | if "time" in key: 72 | feat = getattr(self, key.split("_")[0] + "_backbone")( 73 | vclips[key], prompts 74 | ) 75 | else: 76 | feat = getattr(self, key.split("_")[0] + "_backbone")( 77 | vclips[key] 78 | ) 79 | feats += [feat] 80 | 81 | feats = (torch.cat(feats, dim=1)) 82 | if hasattr(self, key.split("_")[0] + "_head"): 83 | scores += [getattr(self, key.split("_")[0] + "_head")(feats)[0]] 84 | else: 85 | scores += [getattr(self, "vqa_head")(feats)] 86 | 87 | if reduce_scores: 88 | if len(scores) > 1: 89 | scores = reduce(lambda x, y: x + y, scores) 90 | else: 91 | scores = scores[0] 92 | if pooled: 93 | scores = torch.mean(scores, (1, 2, 3, 4)) 94 | self.train() 95 | if return_pooled_feats or return_raw_feats: 96 | return scores, feats 97 | return scores 98 | else: 99 | self.train() 100 | scores = [] 101 | feats = [] 102 | for key in vclips: 103 | feat = getattr(self, key.split("_")[0] + "_backbone")( 104 | vclips[key] 105 | ) 106 | feats.append(feat) 107 | feats = (torch.cat(feats, dim=1)) 108 | if hasattr(self, key.split("_")[0] + "_head"): 109 | scores += [getattr(self, key.split("_")[0] + "_head")(feats)[0]] 110 | else: 111 | scores += [getattr(self, "vqa_head")(feats)] 112 | scores += [torch.zeros_like(scores[0])] 113 | if reduce_scores: 114 | if len(scores) > 1: 115 | scores = reduce(lambda x, y: x + y, scores) 116 | else: 117 | scores = scores[0] 118 | if pooled: 119 | scores = torch.mean(scores, (1, 2, 3, 4)) 120 | 121 | if return_pooled_feats: 122 | return scores, feats 123 | return scores -------------------------------------------------------------------------------- /vebench/blip_models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | from .med import BertConfig, BertModel, BertLMHeadModel 2 | from .blip import create_vit, init_tokenizer, load_checkpoint 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import BertTokenizer 8 | import numpy as np 9 | 10 | class BLIP_VQA(nn.Module): 11 | def __init__(self, 12 | med_config = 'BLIP_configs/med_config.json', 13 | image_size = 480, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | ): 18 | """ 19 | Args: 20 | med_config (str): path for the mixture of encoder-decoder model's configuration file 21 | image_size (int): input image size 22 | vit (str): model size of vision transformer 23 | """ 24 | super().__init__() 25 | 26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 27 | self.tokenizer = init_tokenizer() 28 | 29 | encoder_config = BertConfig.from_json_file(med_config) 30 | encoder_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 32 | 33 | decoder_config = BertConfig.from_json_file(med_config) 34 | self.text_decoder = BertLMHeadModel(config=decoder_config) 35 | 36 | 
37 | def forward(self, video, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): 38 | temporal=[] 39 | for i in range(video.shape[2]): 40 | image=video[:,:,i,...] 41 | image_embeds = self.visual_encoder(image) 42 | temporal.append(image_embeds) 43 | temporal=torch.cat(temporal,dim=2) 44 | image_embeds = self.visual_encoder(image) 45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 46 | 47 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 48 | return_tensors="pt").to(image.device) 49 | question.input_ids[:,0] = self.tokenizer.enc_token_id 50 | 51 | if train: 52 | ''' 53 | n: number of answers for each question 54 | weights: weight for each answer 55 | ''' 56 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 57 | answer.input_ids[:,0] = self.tokenizer.bos_token_id 58 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 59 | 60 | question_output = self.text_encoder(question.input_ids, 61 | attention_mask = question.attention_mask, 62 | encoder_hidden_states = image_embeds, 63 | encoder_attention_mask = image_atts, 64 | return_dict = True) 65 | 66 | question_states = [] 67 | question_atts = [] 68 | for b, n in enumerate(n): 69 | question_states += [question_output.last_hidden_state[b]]*n 70 | question_atts += [question.attention_mask[b]]*n 71 | question_states = torch.stack(question_states,0) 72 | question_atts = torch.stack(question_atts,0) 73 | 74 | answer_output = self.text_decoder(answer.input_ids, 75 | attention_mask = answer.attention_mask, 76 | encoder_hidden_states = question_states, 77 | encoder_attention_mask = question_atts, 78 | labels = answer_targets, 79 | return_dict = True, 80 | reduction = 'none', 81 | ) 82 | 83 | loss = weights * answer_output.loss 84 | loss = loss.sum()/image.size(0) 85 | 86 | return loss 87 | 88 | 89 | else: 90 | question_output = self.text_encoder(question.input_ids, 91 | attention_mask = question.attention_mask, 92 | encoder_hidden_states = image_embeds, 93 | encoder_attention_mask = image_atts, 94 | return_dict = True) 95 | 96 | if inference=='generate': 97 | num_beams = 3 98 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) 99 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) 100 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} 101 | 102 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) 103 | 104 | outputs = self.text_decoder.generate(input_ids=bos_ids, 105 | max_length=10, 106 | min_length=1, 107 | num_beams=num_beams, 108 | eos_token_id=self.tokenizer.sep_token_id, 109 | pad_token_id=self.tokenizer.pad_token_id, 110 | **model_kwargs) 111 | 112 | answers = [] 113 | for output in outputs: 114 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 115 | answers.append(answer) 116 | return answers 117 | 118 | elif inference=='rank': 119 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 120 | answer.input_ids, answer.attention_mask, k_test) 121 | return max_ids 122 | 123 | 124 | 125 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 126 | 127 | num_ques = question_states.size(0) 128 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token 129 | 130 | start_output = 
self.text_decoder(start_ids, 131 | encoder_hidden_states = question_states, 132 | encoder_attention_mask = question_atts, 133 | return_dict = True, 134 | reduction = 'none') 135 | logits = start_output.logits[:,0,:] # first token's logit 136 | 137 | # topk_probs: top-k probability 138 | # topk_ids: [num_question, k] 139 | answer_first_token = answer_ids[:,1] 140 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 141 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 142 | 143 | # answer input: [num_question*k, answer_len] 144 | input_ids = [] 145 | input_atts = [] 146 | for b, topk_id in enumerate(topk_ids): 147 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 148 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 149 | input_ids = torch.cat(input_ids,dim=0) 150 | input_atts = torch.cat(input_atts,dim=0) 151 | 152 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 153 | 154 | # repeat encoder's output for top-k answers 155 | question_states = tile(question_states, 0, k) 156 | question_atts = tile(question_atts, 0, k) 157 | 158 | output = self.text_decoder(input_ids, 159 | attention_mask = input_atts, 160 | encoder_hidden_states = question_states, 161 | encoder_attention_mask = question_atts, 162 | labels = targets_ids, 163 | return_dict = True, 164 | reduction = 'none') 165 | 166 | log_probs_sum = -output.loss 167 | log_probs_sum = log_probs_sum.view(num_ques,k) 168 | 169 | max_topk_ids = log_probs_sum.argmax(dim=1) 170 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] 171 | 172 | return max_ids 173 | 174 | 175 | def blip_vqa(pretrained='',**kwargs): 176 | model = BLIP_VQA(**kwargs) 177 | if pretrained: 178 | model,msg = load_checkpoint(model,pretrained) 179 | # assert(len(msg.missing_keys)==0) 180 | return model 181 | 182 | 183 | def tile(x, dim, n_tile): 184 | init_dim = x.size(dim) 185 | repeat_idx = [1] * x.dim() 186 | repeat_idx[dim] = n_tile 187 | x = x.repeat(*(repeat_idx)) 188 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 189 | return torch.index_select(x, dim, order_index.to(x.device)) 190 | 191 | -------------------------------------------------------------------------------- /vebench/blip_models/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | from .vit import VisionTransformer, interpolate_pos_embed 12 | from .med import BertConfig, BertModel, BertLMHeadModel 13 | from transformers import BertTokenizer 14 | 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | import os 20 | from urllib.parse import urlparse 21 | from timm.models.hub import download_cached_file 22 | 23 | class BLIP_Base(nn.Module): 24 | def __init__(self, 25 | med_config = 'BLIP_configs/med_config.json', 26 | image_size = 224, 27 | vit = 'base', 28 | vit_grad_ckpt = False, 29 | vit_ckpt_layer = 0, 30 | ): 31 | """ 32 | Args: 33 | med_config (str): path for the mixture of encoder-decoder model's configuration file 34 | image_size (int): input image size 35 | vit (str): model size of vision transformer 36 | """ 37 | super().__init__() 38 | 39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 40 | self.tokenizer = init_tokenizer() 41 | med_config = BertConfig.from_json_file(med_config) 42 | med_config.encoder_width = vision_width 43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 44 | 45 | 46 | def forward(self, image, caption, mode): 47 | 48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 50 | 51 | if mode=='image': 52 | # return image features 53 | image_embeds = self.visual_encoder(image) 54 | return image_embeds 55 | 56 | elif mode=='text': 57 | # return text features 58 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 59 | return_dict = True, mode = 'text') 60 | return text_output.last_hidden_state 61 | 62 | elif mode=='multimodal': 63 | # return multimodel features 64 | image_embeds = self.visual_encoder(image) 65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 66 | 67 | text.input_ids[:,0] = self.tokenizer.enc_token_id 68 | output = self.text_encoder(text.input_ids, 69 | attention_mask = text.attention_mask, 70 | encoder_hidden_states = image_embeds, 71 | encoder_attention_mask = image_atts, 72 | return_dict = True, 73 | ) 74 | return output.last_hidden_state 75 | 76 | 77 | 78 | class BLIP_Decoder(nn.Module): 79 | def __init__(self, 80 | med_config = 'BLIP_configs/med_config.json', 81 | image_size = 384, 82 | vit = 'base', 83 | vit_grad_ckpt = False, 84 | vit_ckpt_layer = 0, 85 | prompt = 'a picture of ', 86 | ): 87 | """ 88 | Args: 89 | med_config (str): path for the mixture of encoder-decoder model's configuration file 90 | image_size (int): input image size 91 | vit (str): model size of vision transformer 92 | """ 93 | super().__init__() 94 | 95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 96 | self.tokenizer = init_tokenizer() 97 | med_config = BertConfig.from_json_file(med_config) 98 | med_config.encoder_width = vision_width 99 | self.text_decoder = BertLMHeadModel(config=med_config) 100 | 101 | self.prompt = prompt 102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 103 | 104 | 105 | def forward(self, image, caption): 106 | 107 | image_embeds = self.visual_encoder(image) 108 | image_atts = 
torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 109 | 110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 111 | 112 | text.input_ids[:,0] = self.tokenizer.bos_token_id 113 | 114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 115 | decoder_targets[:,:self.prompt_length] = -100 116 | 117 | decoder_output = self.text_decoder(text.input_ids, 118 | attention_mask = text.attention_mask, 119 | encoder_hidden_states = image_embeds, 120 | encoder_attention_mask = image_atts, 121 | labels = decoder_targets, 122 | return_dict = True, 123 | ) 124 | loss_lm = decoder_output.loss 125 | 126 | return loss_lm 127 | 128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 129 | image_embeds = self.visual_encoder(image) 130 | 131 | if not sample: 132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 133 | 134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 136 | 137 | prompt = [self.prompt] * image.size(0) 138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 139 | input_ids[:,0] = self.tokenizer.bos_token_id 140 | input_ids = input_ids[:, :-1] 141 | 142 | if sample: 143 | #nucleus sampling 144 | outputs = self.text_decoder.generate(input_ids=input_ids, 145 | max_length=max_length, 146 | min_length=min_length, 147 | do_sample=True, 148 | top_p=top_p, 149 | num_return_sequences=1, 150 | eos_token_id=self.tokenizer.sep_token_id, 151 | pad_token_id=self.tokenizer.pad_token_id, 152 | repetition_penalty=1.1, 153 | **model_kwargs) 154 | else: 155 | #beam search 156 | outputs = self.text_decoder.generate(input_ids=input_ids, 157 | max_length=max_length, 158 | min_length=min_length, 159 | num_beams=num_beams, 160 | eos_token_id=self.tokenizer.sep_token_id, 161 | pad_token_id=self.tokenizer.pad_token_id, 162 | repetition_penalty=repetition_penalty, 163 | **model_kwargs) 164 | 165 | captions = [] 166 | for output in outputs: 167 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 168 | captions.append(caption[len(self.prompt):]) 169 | return captions 170 | 171 | 172 | def blip_decoder(pretrained='',**kwargs): 173 | model = BLIP_Decoder(**kwargs) 174 | if pretrained: 175 | model,msg = load_checkpoint(model,pretrained) 176 | assert(len(msg.missing_keys)==0) 177 | return model 178 | 179 | def blip_feature_extractor(pretrained='',**kwargs): 180 | model = BLIP_Base(**kwargs) 181 | if pretrained: 182 | model,msg = load_checkpoint(model,pretrained) 183 | assert(len(msg.missing_keys)==0) 184 | return model 185 | 186 | def init_tokenizer(): 187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 191 | return tokenizer 192 | 193 | 194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 195 | 196 | assert vit in ['base', 'large'], "vit parameter must be base or large" 197 | if vit=='base': 198 | vision_width = 768 199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 200 | num_heads=12, 
use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 201 | drop_path_rate=0 or drop_path_rate 202 | ) 203 | elif vit=='large': 204 | vision_width = 1024 205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 207 | drop_path_rate=0.1 or drop_path_rate 208 | ) 209 | return visual_encoder, vision_width 210 | 211 | def is_url(url_or_filename): 212 | parsed = urlparse(url_or_filename) 213 | return parsed.scheme in ("http", "https") 214 | 215 | def load_checkpoint(model,url_or_filename): 216 | if is_url(url_or_filename): 217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 218 | checkpoint = torch.load(cached_file, map_location='cpu') 219 | elif os.path.isfile(url_or_filename): 220 | checkpoint = torch.load(url_or_filename, map_location='cpu') 221 | else: 222 | raise RuntimeError('checkpoint url or path is invalid') 223 | 224 | state_dict = checkpoint['model'] 225 | 226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 229 | model.visual_encoder_m) 230 | for key in model.state_dict().keys(): 231 | if key in state_dict.keys(): 232 | if state_dict[key].shape!=model.state_dict()[key].shape: 233 | del state_dict[key] 234 | 235 | msg = model.load_state_dict(state_dict,strict=False) 236 | # print('load checkpoint from %s'%url_or_filename) 237 | return model,msg 238 | 239 | -------------------------------------------------------------------------------- /vebench/models/head.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from torchvision.ops import roi_align, roi_pool 8 | 9 | 10 | class MultiHeadCrossAttention(nn.Module): 11 | def __init__(self, embed_dim, query_dim, kv_dim, num_heads, output_dim=None): 12 | super(MultiHeadCrossAttention, self).__init__() 13 | # assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads" 14 | 15 | self.embed_dim = embed_dim 16 | self.query_dim = query_dim 17 | self.kv_dim = kv_dim 18 | self.output_dim = output_dim if output_dim else embed_dim 19 | 20 | self.num_heads = num_heads 21 | self.head_dim = embed_dim // num_heads 22 | 23 | self.q_proj = nn.Linear(query_dim, embed_dim) 24 | self.k_proj = nn.Linear(kv_dim, embed_dim) 25 | self.v_proj = nn.Linear(kv_dim, embed_dim) 26 | self.out_proj = nn.Linear(embed_dim, output_dim) 27 | 28 | def forward(self, query, key, value, mask=None, return_attn=False): 29 | batch_size = query.size(0) 30 | 31 | # Linear projections 32 | q = self.q_proj(query) # NLC 33 | k = self.k_proj(key) 34 | v = self.v_proj(value) 35 | 36 | # Reshape and transpose for multi-head attention 37 | q = q.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 38 | k = k.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 39 | v = v.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 40 | 41 | # Scaled dot-product attention 42 | scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5) 43 | if mask is not None: 44 | scores = 
scores.masked_fill(mask == 0, float('-inf')) 45 | attn = F.softmax(scores, dim=-1) 46 | 47 | # Combine heads 48 | context = torch.matmul(attn, v) 49 | context = context.transpose(1, 2).reshape(batch_size, -1, self.embed_dim) 50 | 51 | # Final linear projection 52 | output = self.out_proj(context) 53 | if return_attn: 54 | return output, attn 55 | return output 56 | 57 | 58 | class AttentionPool3d(nn.Module): 59 | def __init__(self, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.cross_attn = MultiHeadCrossAttention( 62 | embed_dim=embed_dim, 63 | query_dim=embed_dim, 64 | kv_dim=embed_dim, 65 | num_heads=num_heads, 66 | output_dim=output_dim 67 | ) 68 | self.num_heads = num_heads 69 | 70 | def forward(self, x, return_attn=False): # x: BCLHW 71 | # import pdb;pdb.set_trace() 72 | x = x.flatten(start_dim=2).permute(2, 0, 1) # BC(LHW) -> (LHW)BC 73 | x_mean = x.mean(dim=0, keepdim=True) # (1)BC 74 | x = torch.cat([x_mean, x], dim=0) # (LHW+1)BC 75 | x = x.permute(1, 0, 2).contiguous() # B(LHW+1)C 76 | x_mean = x_mean.permute(1, 0, 2).contiguous() # B(1)C 77 | 78 | if return_attn: 79 | x, attn = self.cross_attn(query=x_mean, key=x, value=x, return_attn=True) # B(1)C 80 | return x.squeeze(dim=-1), attn 81 | x = self.cross_attn(query=x_mean, key=x, value=x).squeeze(dim=1) # BC 82 | batch, channels = x.shape 83 | x = x.view(batch, channels, 1, 1, 1) 84 | 85 | return x 86 | 87 | 88 | class TextAttentionPool3d(nn.Module): 89 | def __init__(self, embed_dim: int, txt_dim: int, num_heads: int, output_dim: int = None): 90 | super().__init__() 91 | self.cross_attn = MultiHeadCrossAttention( 92 | embed_dim=embed_dim, 93 | query_dim=txt_dim, 94 | kv_dim=embed_dim, 95 | num_heads=num_heads, 96 | output_dim=output_dim 97 | ) 98 | self.num_heads = num_heads 99 | 100 | def forward(self, x, txt_feat): 101 | # import pdb;pdb.set_trace() 102 | # import pdb;pdb.set_trace() 103 | x = x.flatten(start_dim=2).permute(2, 0, 1) # BC(LHW) -> (LHW)BC 104 | x_mean = x.mean(dim=0, keepdim=True) # (1)BC 105 | x = torch.cat([x_mean, x], dim=0) # (LHW+1)BC 106 | x = x.permute(1, 0, 2).contiguous() # B(LHW+1)C 107 | x_mean = x_mean.permute(1, 0, 2).contiguous() # B(1)C 108 | 109 | txt_feat = txt_feat.unsqueeze(dim=1) # BC -> B(1)C 110 | 111 | x = self.cross_attn(query=txt_feat, key=x, value=x) # B(1)C 112 | x = x.squeeze(dim=1) 113 | batch, channels = x.shape 114 | x = x.view(batch, channels, 1, 1, 1) 115 | return x 116 | 117 | 118 | class VQAHead(nn.Module): 119 | """MLP Regression Head for VQA. 
120 | Args: 121 | in_channels: input channels for MLP 122 | hidden_channels: hidden channels for MLP 123 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5) 124 | pre_pool: whether pre-pool the features or not (True for Aesthetic Attributes, False for Technical Attributes) 125 | """ 126 | 127 | def __init__( 128 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, attn_pool3d=False, 129 | text_pool3d=False, **kwargs 130 | ): 131 | super().__init__() 132 | self.dropout_ratio = dropout_ratio 133 | self.in_channels = in_channels 134 | self.hidden_channels = hidden_channels 135 | self.pre_pool = pre_pool 136 | self.attn_pool3d = attn_pool3d 137 | self.text_pool3d = text_pool3d 138 | if self.dropout_ratio != 0: 139 | self.dropout = nn.Dropout(p=self.dropout_ratio) 140 | else: 141 | self.dropout = None 142 | 143 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 144 | if self.attn_pool3d: 145 | self.attn_pool = AttentionPool3d(embed_dim=self.in_channels, num_heads=12, 146 | output_dim=self.in_channels) # 768//64=12 147 | if self.text_pool3d: 148 | self.text_pool = TextAttentionPool3d(embed_dim=self.in_channels, txt_dim=1024, num_heads=12, 149 | output_dim=self.in_channels) 150 | 151 | self.fc_hid = nn.Conv3d(2 * self.in_channels, self.hidden_channels, 152 | (1, 1, 1)) if self.text_pool3d else nn.Conv3d(self.in_channels, self.hidden_channels, 153 | (1, 1, 1)) 154 | self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1)) 155 | self.gelu = nn.GELU() 156 | 157 | def forward(self, x, txt=None, inference=False, rois=None): 158 | # import pdb;pdb.set_trace() 159 | if self.pre_pool: 160 | x = self.avg_pool(x) 161 | if self.attn_pool3d: 162 | x_vis = self.attn_pool(x) 163 | if self.text_pool3d and txt is not None: 164 | x_txt = self.text_pool(x, txt) 165 | if inference and x_txt.size(0) != x_vis.size(0): 166 | x_txt = x_txt.expand(x_vis.size(0), -1, -1, -1, -1) 167 | x = torch.concat([x_vis, x_txt], dim=1) 168 | if self.attn_pool3d and not self.text_pool3d: 169 | x = self.dropout(x_vis) 170 | else: 171 | x = self.dropout(x) 172 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x)))) 173 | return qlt_score 174 | 175 | 176 | def clean(serie): 177 | output = serie[(np.isnan(serie) == False) & (np.isinf(serie) == False)] 178 | return output 179 | 180 | 181 | class VQAHead_cls(nn.Module): 182 | """MLP Regression Head for VQA. 
183 | Args: 184 | in_channels: input channels for MLP 185 | hidden_channels: hidden channels for MLP 186 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5) 187 | pre_pool: whether pre-pool the features or not (True for Aesthetic Attributes, False for Technical Attributes) 188 | """ 189 | 190 | def __init__( 191 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, attn_pool3d=False, 192 | text_pool3d=False, **kwargs 193 | ): 194 | super().__init__() 195 | self.dropout_ratio = dropout_ratio 196 | self.in_channels = in_channels 197 | self.hidden_channels = hidden_channels 198 | self.pre_pool = pre_pool 199 | self.attn_pool3d = attn_pool3d 200 | self.text_pool3d = text_pool3d 201 | if self.dropout_ratio != 0: 202 | self.dropout = nn.Dropout(p=self.dropout_ratio) 203 | else: 204 | self.dropout = None 205 | 206 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 207 | if self.attn_pool3d: 208 | self.attn_pool = AttentionPool3d(embed_dim=self.in_channels, num_heads=16, 209 | output_dim=self.in_channels) # 768//64=12 210 | if self.text_pool3d: 211 | self.text_pool = TextAttentionPool3d(embed_dim=self.in_channels, txt_dim=1024, num_heads=16, 212 | output_dim=self.in_channels) 213 | # self.fc_hid=nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1)) 214 | self.fc_hid = nn.Conv3d(2 * self.in_channels, self.hidden_channels, 215 | (1, 1, 1)) if self.text_pool3d else nn.Conv3d(self.in_channels, self.hidden_channels, 216 | (1, 1, 1)) 217 | self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1)) 218 | self.gelu = nn.GELU() 219 | 220 | self.fc_cls1 = nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1)) 221 | self.fc_cls2 = nn.Conv3d(self.hidden_channels, 10, (1, 1, 1)) 222 | self.gelu_cls = nn.GELU() 223 | 224 | def forward(self, x, txt=None, inference=False, rois=None): 225 | # import pdb;pdb.set_trace() 226 | if self.pre_pool: 227 | x = self.avg_pool(x) 228 | if self.attn_pool3d: 229 | x_vis = self.attn_pool(x) 230 | x_cls = self.fc_cls2(self.dropout(self.gelu_cls(self.fc_cls1(x_vis)))) 231 | if self.text_pool3d and txt is not None: 232 | x_txt = self.text_pool(x, txt) 233 | if inference and x_txt.size(0) != x_vis.size(0): 234 | x_txt = x_txt.expand(x_vis.size(0), -1, -1, -1, -1) 235 | x = torch.concat([x_vis, x_txt], dim=1) 236 | if self.attn_pool3d and not self.text_pool3d: 237 | x = self.dropout(x_vis) 238 | else: 239 | x = self.dropout(x) 240 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x)))) 241 | # print(qlt_score.shape) 242 | return qlt_score#, x_cls 243 | class VARHead(nn.Module): 244 | """MLP Regression Head for Video Action Recognition. 
245 | Args: 246 | in_channels: input channels for MLP 247 | hidden_channels: hidden channels for MLP 248 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5) 249 | """ 250 | 251 | def __init__(self, in_channels=768, out_channels=400, dropout_ratio=0.5, **kwargs): 252 | super().__init__() 253 | self.dropout_ratio = dropout_ratio 254 | self.in_channels = in_channels 255 | self.out_channels = out_channels 256 | if self.dropout_ratio != 0: 257 | self.dropout = nn.Dropout(p=self.dropout_ratio) 258 | else: 259 | self.dropout = None 260 | self.fc = nn.Conv3d(self.in_channels, self.out_channels, (1, 1, 1)) 261 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 262 | 263 | def forward(self, x, rois=None): 264 | x = self.dropout(x) 265 | x = self.avg_pool(x) 266 | out = self.fc(x) 267 | return out 268 | -------------------------------------------------------------------------------- /vebench/models/backbone/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import os 9 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" 10 | import warnings 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | from ...blip_models.vit import VisionTransformer, interpolate_pos_embed 15 | from ...blip_models.med import BertConfig, BertModel, BertLMHeadModel 16 | from transformers import BertTokenizer 17 | from timm.models.vision_transformer import Attention as TemporalAttention 18 | from timm.layers import Mlp, DropPath, to_2tuple 19 | from timm.layers import PatchEmbed, Mlp, DropPath, RmsNorm, PatchDropout, SwiGLUPacked, \ 20 | trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ 21 | get_act_layer, get_norm_layer 22 | 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | import numpy as np 27 | import os 28 | from urllib.parse import urlparse 29 | from timm.models.hub import download_cached_file 30 | 31 | class MyAttention(nn.Module): 32 | 33 | def __init__( 34 | self, 35 | dim: int, 36 | num_heads: int = 8, 37 | qkv_bias: bool = False, 38 | qk_norm: bool = False, 39 | attn_drop: float = 0., 40 | proj_drop: float = 0., 41 | step:int=1, 42 | norm_layer: nn.Module = nn.LayerNorm, 43 | ) -> None: 44 | super().__init__() 45 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 46 | self.num_heads = num_heads 47 | self.head_dim = dim // num_heads 48 | self.scale = self.head_dim ** -0.5 49 | self.fused_attn = use_fused_attn() 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 53 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 54 | self.attn_drop = nn.Dropout(attn_drop) 55 | self.proj = nn.Linear(dim, dim) 56 | self.proj_drop = nn.Dropout(proj_drop) 57 | self.step=step 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | B, T, N, C = x.shape 61 | qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, self.head_dim).permute(3, 1, 0, 4, 2, 5) 62 | q, k, v = qkv.unbind(0) 63 | q, k = self.q_norm(q), self.k_norm(k) 64 | k=torch.cat((k[:self.step,...],k),dim=0)[:int(-1*self.step),...] 65 | v=torch.cat((v[:self.step,...],v),dim=0)[:int(-1*self.step),...] 
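        # The two concatenations above shift the key/value tensors along the
        # temporal axis by `self.step`, so queries from frame t attend to
        # keys/values taken from frame t - step (the first `step` frames reuse
        # their own keys/values as padding). Note that, as written, neither
        # branch below applies the output projection `self.proj`, and the
        # non-fused branch returns the raw attention map rather than attended
        # features; the calls to temporal_attn_1/2 in Block.forward are
        # commented out, so this shifted attention appears unused in the
        # current model.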
66 | if self.fused_attn: 67 | x = F.scaled_dot_product_attention( 68 | q, k, v, 69 | dropout_p=self.attn_drop.p if self.training else 0., 70 | ) 71 | else: 72 | q = q * self.scale 73 | attn = q @ k.transpose(-2, -1) 74 | attn = attn.softmax(dim=-1) 75 | attn = self.attn_drop(attn) 76 | return attn 77 | 78 | 79 | from einops import rearrange 80 | 81 | class Block(nn.Module): 82 | def __init__( 83 | self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., 84 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None,type="A"): 85 | super().__init__() 86 | self.norm1 = norm_layer(dim) 87 | if ws is None: 88 | self.attn = TemporalAttention(dim, num_heads,attn_drop=attn_drop,proj_drop=drop) 89 | # elif ws == 1: 90 | # self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) 91 | # else: 92 | # self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) 93 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 94 | self.norm2 = norm_layer(dim) 95 | mlp_hidden_dim = int(dim * mlp_ratio) 96 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 97 | self.temporal_attn_1=MyAttention(dim, num_heads,attn_drop=attn_drop,proj_drop=drop,step=1) 98 | self.temporal_attn_2 = MyAttention(dim, num_heads, attn_drop=attn_drop, proj_drop=drop,step=2) 99 | self.temporal_conv = nn.Conv1d(dim, dim, kernel_size=3,stride=1, padding=1) 100 | self.type=type 101 | self.gelu=nn.GELU() 102 | 103 | def forward(self, x,B): 104 | # x: (B*T, h*w, C) 105 | x = x + self.drop_path(self.attn(self.norm1(x))) 106 | # spatial 107 | if self.type=="A": 108 | temp = self.mlp(self.norm2(x)) 109 | 110 | temp=rearrange(temp,'(b t) l c -> b t l c', b=B) 111 | 112 | # step_1=self.drop_path(self.temporal_attn_1(temp)) 113 | # step_2=self.drop_path(self.temporal_attn_2(temp)) 114 | # step=torch.cat((step_2,step_1),dim=1) 115 | # temp=torch.cat((step,temp),dim=1) 116 | # temporal 117 | temp = rearrange(temp, 'b t l c -> (b l) c t', b=B) 118 | 119 | temp = self.temporal_conv(temp) 120 | temp = rearrange(temp, '(b l) c t -> (b t) l c', b=B) 121 | 122 | # output 123 | x = x + self.drop_path(temp) 124 | elif self.type=="B": 125 | spatial = self.mlp(self.norm2(x)) 126 | temp=rearrange(spatial,'(b t) l c->(b l) c t',b=B) 127 | temp = self.temporal_conv(temp) 128 | temp = rearrange(temp, '(b l) c t -> (b t) l c', b=B) 129 | x=x+self.gelu(temp)+self.gelu(spatial) 130 | 131 | #x=rearrange(x,'(b t) l c -> b t l c',b=B).mean(1) 132 | return rearrange(x,'(b t) l c -> b t l c',b=B).mean(1),rearrange(x,'(b t) l c -> b t l c',b=B) 133 | 134 | 135 | class My_BLIP_Base(nn.Module): 136 | def __init__(self, 137 | med_config, 138 | image_size=224, 139 | vit='base', 140 | vit_grad_ckpt=False, 141 | vit_ckpt_layer=0, 142 | drop_path=0.2, 143 | in_chans=1024, 144 | embed_dim=1024, 145 | patch_size=2, 146 | ): 147 | """ 148 | Args: 149 | med_config (str): path for the mixture of encoder-decoder model's configuration file 150 | image_size (int): input image size 151 | vit (str): model size of vision transformer 152 | """ 153 | super().__init__() 154 | 155 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 156 | self.tokenizer = init_tokenizer() 157 | med_config = BertConfig.from_json_file(med_config) 158 | med_config.encoder_width = vision_width 159 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 160 | self.temporal_block=Block(dim=1024,num_heads=8,drop_path=0.2) 161 | 
#self.post_block=PostBlock(t_dim=768,v_dim=1024) 162 | self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 163 | 164 | 165 | self.softmax=nn.Softmax(dim=1) 166 | 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 168 | self.norm = nn.LayerNorm(embed_dim) 169 | for name, m in self.named_modules(): 170 | if 'temporal_conv' in name: 171 | nn.init.dirac_(m.weight.data) # initialized to be identity 172 | nn.init.zeros_(m.bias.data) 173 | if 'temporal_fc' in name: 174 | nn.init.constant_(m.weight, 0) 175 | nn.init.constant_(m.bias, 0) 176 | 177 | 178 | 179 | def threeDConv(self,video):#3DConv 180 | temporal=self.visual_encoder(video) 181 | return self.temporal_block(temporal.reshape(-1,temporal.shape[-2],temporal.shape[-1]),B=temporal.shape[0]) 182 | 183 | 184 | 185 | def forward(self, video, caption, mode): 186 | 187 | text = self.tokenizer(caption, return_tensors="pt",padding=True).to(video.device) 188 | 189 | assert mode=="multimodal_text" 190 | image_embeds, frame_embeds = self.threeDConv(video) # 8,197,1024 191 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(video.device) 192 | 193 | text.input_ids[:, 0] = self.tokenizer.enc_token_id 194 | output = self.text_encoder(text.input_ids, 195 | attention_mask=text.attention_mask, 196 | encoder_hidden_states=image_embeds, 197 | encoder_attention_mask=image_atts, 198 | return_dict=True, 199 | ) 200 | return output.last_hidden_state 201 | 202 | 203 | 204 | def blip_feature_extractor(pretrained='', **kwargs): 205 | base_dir = os.path.dirname(os.path.abspath(__file__)) 206 | config = os.path.join(base_dir, 'BLIP_configs', 'med_config.json') 207 | model = My_BLIP_Base(config, vit="large",**kwargs) 208 | if pretrained: 209 | model, msg = load_checkpoint(model, pretrained) 210 | #assert (len(msg.missing_keys) == 0) 211 | return model 212 | 213 | 214 | def init_tokenizer(): 215 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 216 | tokenizer.add_special_tokens({'bos_token': '[DEC]'}) 217 | tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) 218 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 219 | return tokenizer 220 | 221 | 222 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 223 | assert vit in ['base', 'large'], "vit parameter must be base or large" 224 | if vit == 'base': 225 | vision_width = 768 226 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 227 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, 228 | ckpt_layer=ckpt_layer, 229 | drop_path_rate=0 or drop_path_rate 230 | ) 231 | elif vit == 'large': 232 | vision_width = 1024 233 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 234 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, 235 | ckpt_layer=ckpt_layer, 236 | drop_path_rate=0.1 or drop_path_rate 237 | ) 238 | return visual_encoder, vision_width 239 | 240 | 241 | def is_url(url_or_filename): 242 | parsed = urlparse(url_or_filename) 243 | return parsed.scheme in ("http", "https") 244 | 245 | 246 | def load_checkpoint(model, url_or_filename): 247 | if is_url(url_or_filename): 248 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 249 | checkpoint = torch.load(cached_file, map_location='cpu') 250 | elif os.path.isfile(url_or_filename): 251 | checkpoint = 
torch.load(url_or_filename, map_location='cpu') 252 | else: 253 | print(url_or_filename) 254 | raise RuntimeError('checkpoint url or path is invalid') 255 | 256 | state_dict = checkpoint['model'] 257 | 258 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], 259 | model.visual_encoder) 260 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 261 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 262 | model.visual_encoder_m) 263 | for key in model.state_dict().keys(): 264 | if key in state_dict.keys(): 265 | if state_dict[key].shape != model.state_dict()[key].shape: 266 | del state_dict[key] 267 | 268 | msg = model.load_state_dict(state_dict, strict=False) 269 | # print('load checkpoint from %s' % url_or_filename) 270 | return model, msg 271 | 272 | class MyBLIP(nn.Module): 273 | def __init__(self,type="multimodal", model_path=None): 274 | super().__init__() 275 | self.model = blip_feature_extractor(pretrained=os.path.join(model_path, 'model_large.pth')) 276 | self.type=type 277 | 278 | 279 | 280 | def forward(self, x, text): 281 | B, C, T, H, W = x.size() 282 | return self.model(x,text,self.type).permute(0,2,1).unsqueeze(-1).unsqueeze(-1) 283 | 284 | #return self.model(x, text, "image_attn") 285 | 286 | -------------------------------------------------------------------------------- /vebench/blip_models/blip_retrieval.py: -------------------------------------------------------------------------------- 1 | from .med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_Retrieval(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | queue_size = 57600, 19 | momentum = 0.995, 20 | negative_all_rank = False, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | text_width = self.text_encoder.config.hidden_size 37 | 38 | self.vision_proj = nn.Linear(vision_width, embed_dim) 39 | self.text_proj = nn.Linear(text_width, embed_dim) 40 | 41 | self.itm_head = nn.Linear(text_width, 2) 42 | 43 | # create momentum encoders 44 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 45 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 46 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) 47 | self.text_proj_m = nn.Linear(text_width, embed_dim) 48 | 49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 50 | [self.vision_proj,self.vision_proj_m], 51 | [self.text_encoder,self.text_encoder_m], 52 | [self.text_proj,self.text_proj_m], 53 | ] 54 | self.copy_params() 55 | 56 | # create the queue 57 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 58 | 
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 59 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) 60 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) 61 | 62 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 63 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 64 | 65 | self.queue_size = queue_size 66 | self.momentum = momentum 67 | self.temp = nn.Parameter(0.07*torch.ones([])) 68 | 69 | self.negative_all_rank = negative_all_rank 70 | 71 | 72 | def forward(self, image, caption, alpha, idx): 73 | with torch.no_grad(): 74 | self.temp.clamp_(0.001,0.5) 75 | 76 | image_embeds = self.visual_encoder(image) 77 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 78 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 79 | 80 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 81 | return_tensors="pt").to(image.device) 82 | 83 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 84 | return_dict = True, mode = 'text') 85 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 86 | 87 | ###============== Image-text Contrastive Learning ===================### 88 | idx = idx.view(-1,1) 89 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) 90 | pos_idx = torch.eq(idx, idx_all).float() 91 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) 92 | 93 | # get momentum features 94 | with torch.no_grad(): 95 | self._momentum_update() 96 | image_embeds_m = self.visual_encoder_m(image) 97 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 98 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 99 | 100 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 101 | return_dict = True, mode = 'text') 102 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 103 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 104 | 105 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp 106 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp 107 | 108 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 109 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 110 | 111 | sim_i2t = image_feat @ text_feat_m_all / self.temp 112 | sim_t2i = text_feat @ image_feat_m_all / self.temp 113 | 114 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 115 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 116 | 117 | loss_ita = (loss_i2t+loss_t2i)/2 118 | 119 | idxs = concat_all_gather(idx) 120 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) 121 | 122 | ###============== Image-text Matching ===================### 123 | encoder_input_ids = text.input_ids.clone() 124 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 125 | 126 | # forward the positve image-text pair 127 | bs = image.size(0) 128 | output_pos = self.text_encoder(encoder_input_ids, 129 | attention_mask = text.attention_mask, 130 | encoder_hidden_states = image_embeds, 131 | encoder_attention_mask = image_atts, 132 | return_dict = True, 133 | ) 134 | 135 | 136 | if self.negative_all_rank: 137 | # compute sample similarity 138 | with torch.no_grad(): 139 | mask = torch.eq(idx, 
idxs.t()) 140 | 141 | image_feat_world = concat_all_gather(image_feat) 142 | text_feat_world = concat_all_gather(text_feat) 143 | 144 | sim_i2t = image_feat @ text_feat_world.t() / self.temp 145 | sim_t2i = text_feat @ image_feat_world.t() / self.temp 146 | 147 | weights_i2t = F.softmax(sim_i2t,dim=1) 148 | weights_i2t.masked_fill_(mask, 0) 149 | 150 | weights_t2i = F.softmax(sim_t2i,dim=1) 151 | weights_t2i.masked_fill_(mask, 0) 152 | 153 | image_embeds_world = all_gather_with_grad(image_embeds) 154 | 155 | # select a negative image (from all ranks) for each text 156 | image_embeds_neg = [] 157 | for b in range(bs): 158 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 159 | image_embeds_neg.append(image_embeds_world[neg_idx]) 160 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 161 | 162 | # select a negative text (from all ranks) for each image 163 | input_ids_world = concat_all_gather(encoder_input_ids) 164 | att_mask_world = concat_all_gather(text.attention_mask) 165 | 166 | text_ids_neg = [] 167 | text_atts_neg = [] 168 | for b in range(bs): 169 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 170 | text_ids_neg.append(input_ids_world[neg_idx]) 171 | text_atts_neg.append(att_mask_world[neg_idx]) 172 | 173 | else: 174 | with torch.no_grad(): 175 | mask = torch.eq(idx, idx.t()) 176 | 177 | sim_i2t = image_feat @ text_feat.t() / self.temp 178 | sim_t2i = text_feat @ image_feat.t() / self.temp 179 | 180 | weights_i2t = F.softmax(sim_i2t,dim=1) 181 | weights_i2t.masked_fill_(mask, 0) 182 | 183 | weights_t2i = F.softmax(sim_t2i,dim=1) 184 | weights_t2i.masked_fill_(mask, 0) 185 | 186 | # select a negative image (from same rank) for each text 187 | image_embeds_neg = [] 188 | for b in range(bs): 189 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 190 | image_embeds_neg.append(image_embeds[neg_idx]) 191 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 192 | 193 | # select a negative text (from same rank) for each image 194 | text_ids_neg = [] 195 | text_atts_neg = [] 196 | for b in range(bs): 197 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 198 | text_ids_neg.append(encoder_input_ids[neg_idx]) 199 | text_atts_neg.append(text.attention_mask[neg_idx]) 200 | 201 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 202 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 203 | 204 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 205 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 206 | 207 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 208 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 209 | 210 | output_neg = self.text_encoder(text_ids_all, 211 | attention_mask = text_atts_all, 212 | encoder_hidden_states = image_embeds_all, 213 | encoder_attention_mask = image_atts_all, 214 | return_dict = True, 215 | ) 216 | 217 | 218 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 219 | vl_output = self.itm_head(vl_embeddings) 220 | 221 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 222 | dim=0).to(image.device) 223 | loss_itm = F.cross_entropy(vl_output, itm_labels) 224 | 225 | return loss_ita, loss_itm 226 | 227 | 228 | @torch.no_grad() 229 | def copy_params(self): 230 | for model_pair in self.model_pairs: 231 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 232 | param_m.data.copy_(param.data) # initialize 233 | param_m.requires_grad = False # 
not update by gradient 234 | 235 | 236 | @torch.no_grad() 237 | def _momentum_update(self): 238 | for model_pair in self.model_pairs: 239 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 240 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) 241 | 242 | 243 | @torch.no_grad() 244 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): 245 | # gather keys before updating queue 246 | image_feats = concat_all_gather(image_feat) 247 | text_feats = concat_all_gather(text_feat) 248 | 249 | 250 | batch_size = image_feats.shape[0] 251 | 252 | ptr = int(self.ptr_queue) 253 | assert self.queue_size % batch_size == 0 # for simplicity 254 | 255 | # replace the keys at ptr (dequeue and enqueue) 256 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 257 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 258 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 259 | ptr = (ptr + batch_size) % self.queue_size # move pointer 260 | 261 | self.ptr_queue[0] = ptr 262 | 263 | 264 | def blip_retrieval(pretrained='',**kwargs): 265 | model = BLIP_Retrieval(**kwargs) 266 | if pretrained: 267 | model,msg = load_checkpoint(model,pretrained) 268 | print("missing keys:") 269 | print(msg.missing_keys) 270 | return model 271 | 272 | 273 | @torch.no_grad() 274 | def concat_all_gather(tensor): 275 | """ 276 | Performs all_gather operation on the provided tensors. 277 | *** Warning ***: torch.distributed.all_gather has no gradient. 278 | """ 279 | tensors_gather = [torch.ones_like(tensor) 280 | for _ in range(torch.distributed.get_world_size())] 281 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 282 | 283 | output = torch.cat(tensors_gather, dim=0) 284 | return output 285 | 286 | 287 | class GatherLayer(torch.autograd.Function): 288 | """ 289 | Gather tensors from all workers with support for backward propagation: 290 | This implementation does not cut the gradients as torch.distributed.all_gather does. 291 | """ 292 | 293 | @staticmethod 294 | def forward(ctx, x): 295 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 296 | torch.distributed.all_gather(output, x) 297 | return tuple(output) 298 | 299 | @staticmethod 300 | def backward(ctx, *grads): 301 | all_gradients = torch.stack(grads) 302 | torch.distributed.all_reduce(all_gradients) 303 | return all_gradients[torch.distributed.get_rank()] 304 | 305 | 306 | def all_gather_with_grad(tensors): 307 | """ 308 | Performs all_gather operation on the provided tensors. 309 | Graph remains connected for backward grad computation. 
310 | """ 311 | # Queue the gathered tensors 312 | world_size = torch.distributed.get_world_size() 313 | # There is no need for reduction in the single-proc case 314 | if world_size == 1: 315 | return tensors 316 | 317 | tensor_all = GatherLayer.apply(tensors) 318 | 319 | return torch.cat(tensor_all, dim=0) 320 | -------------------------------------------------------------------------------- /vebench/preprocess.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import glob 4 | import os 5 | os.environ["WANDB_MODE"] = "offline" 6 | import os.path as osp 7 | import random 8 | from functools import lru_cache 9 | 10 | import decord 11 | import skvideo.io 12 | import torch 13 | import torchvision 14 | from decord import VideoReader, cpu, gpu 15 | 16 | 17 | decord.bridge.set_bridge("torch") 18 | 19 | 20 | def get_spatial_fragments( 21 | video, 22 | fragments_h=7, 23 | fragments_w=7, 24 | fsize_h=32, 25 | fsize_w=32, 26 | aligned=32, 27 | nfrags=1, 28 | random=False, 29 | random_upsample=False, 30 | fallback_type="upsample", 31 | upsample=-1, 32 | **kwargs, 33 | ): 34 | if upsample > 0: 35 | old_h, old_w = video.shape[-2], video.shape[-1] 36 | if old_h >= old_w: 37 | w = upsample 38 | h = int(upsample * old_h / old_w) 39 | else: 40 | h = upsample 41 | w = int(upsample * old_w / old_h) 42 | 43 | video = get_resized_video(video, h, w) 44 | size_h = fragments_h * fsize_h 45 | size_w = fragments_w * fsize_w 46 | ## video: [C,T,H,W] 47 | ## situation for images 48 | if video.shape[1] == 1: 49 | aligned = 1 50 | 51 | dur_t, res_h, res_w = video.shape[-3:] 52 | ratio = min(res_h / size_h, res_w / size_w) 53 | if fallback_type == "upsample" and ratio < 1: 54 | 55 | ovideo = video 56 | video = torch.nn.functional.interpolate( 57 | video / 255.0, scale_factor=1 / ratio, mode="bilinear" 58 | ) 59 | video = (video * 255.0).type_as(ovideo) 60 | 61 | if random_upsample: 62 | 63 | randratio = random.random() * 0.5 + 1 64 | video = torch.nn.functional.interpolate( 65 | video / 255.0, scale_factor=randratio, mode="bilinear" 66 | ) 67 | video = (video * 255.0).type_as(ovideo) 68 | 69 | assert dur_t % aligned == 0, "Please provide match vclip and align index" 70 | size = size_h, size_w 71 | 72 | ## make sure that sampling will not run out of the picture 73 | hgrids = torch.LongTensor( 74 | [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] 75 | ) 76 | wgrids = torch.LongTensor( 77 | [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] 78 | ) 79 | hlength, wlength = res_h // fragments_h, res_w // fragments_w 80 | 81 | if random: 82 | print("This part is deprecated. 
Please remind that.") 83 | if res_h > fsize_h: 84 | rnd_h = torch.randint( 85 | res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 86 | ) 87 | else: 88 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 89 | if res_w > fsize_w: 90 | rnd_w = torch.randint( 91 | res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 92 | ) 93 | else: 94 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 95 | else: 96 | if hlength > fsize_h: 97 | rnd_h = torch.randint( 98 | hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 99 | ) 100 | else: 101 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 102 | if wlength > fsize_w: 103 | rnd_w = torch.randint( 104 | wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 105 | ) 106 | else: 107 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 108 | 109 | target_video = torch.zeros(video.shape[:-2] + size).to(video.device) 110 | # target_videos = [] 111 | 112 | for i, hs in enumerate(hgrids): 113 | for j, ws in enumerate(wgrids): 114 | for t in range(dur_t // aligned): 115 | t_s, t_e = t * aligned, (t + 1) * aligned 116 | h_s, h_e = i * fsize_h, (i + 1) * fsize_h 117 | w_s, w_e = j * fsize_w, (j + 1) * fsize_w 118 | if random: 119 | h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h 120 | w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w 121 | else: 122 | h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h 123 | w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w 124 | target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ 125 | :, t_s:t_e, h_so:h_eo, w_so:w_eo 126 | ] 127 | # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) 128 | # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) 129 | # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments 130 | return target_video 131 | 132 | 133 | @lru_cache 134 | def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False): 135 | if random_crop: 136 | return torchvision.transforms.RandomResizedCrop( 137 | (size_h, size_w), scale=(0.40, 1.0) 138 | ) 139 | if target_ratio > 1: 140 | size_h = int(target_ratio * size_w) 141 | assert size_h > size_w 142 | elif target_ratio < 1: 143 | size_w = int(size_h / target_ratio) 144 | assert size_w > size_h 145 | return torchvision.transforms.Resize((size_h, size_w)) 146 | 147 | 148 | def get_resized_video( 149 | video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs, 150 | ): 151 | video = video.permute(1, 0, 2, 3) 152 | resize_opt = get_resize_function( 153 | size_h, size_w, video.shape[-2] / video.shape[-1] if arp else 1, random_crop 154 | ) 155 | video = resize_opt(video).permute(1, 0, 2, 3) 156 | return video 157 | 158 | 159 | def get_arp_resized_video( 160 | video, short_edge=224, train=False, **kwargs, 161 | ): 162 | if train: ## if during training, will random crop into square and then resize 163 | res_h, res_w = video.shape[-2:] 164 | ori_short_edge = min(video.shape[-2:]) 165 | if res_h > ori_short_edge: 166 | rnd_h = random.randrange(res_h - ori_short_edge) 167 | video = video[..., rnd_h : rnd_h + ori_short_edge, :] 168 | elif res_w > ori_short_edge: 169 | rnd_w = random.randrange(res_w - ori_short_edge) 170 | video = video[..., :, rnd_h : rnd_h + ori_short_edge] 171 | ori_short_edge = min(video.shape[-2:]) 172 | scale_factor = short_edge / ori_short_edge 173 | ovideo = 
video 174 | video = torch.nn.functional.interpolate( 175 | video / 255.0, scale_factors=scale_factor, mode="bilinear" 176 | ) 177 | video = (video * 255.0).type_as(ovideo) 178 | return video 179 | 180 | 181 | def get_arp_fragment_video( 182 | video, short_fragments=7, fsize=32, train=False, **kwargs, 183 | ): 184 | if ( 185 | train 186 | ): ## if during training, will random crop into square and then get fragments 187 | res_h, res_w = video.shape[-2:] 188 | ori_short_edge = min(video.shape[-2:]) 189 | if res_h > ori_short_edge: 190 | rnd_h = random.randrange(res_h - ori_short_edge) 191 | video = video[..., rnd_h : rnd_h + ori_short_edge, :] 192 | elif res_w > ori_short_edge: 193 | rnd_w = random.randrange(res_w - ori_short_edge) 194 | video = video[..., :, rnd_h : rnd_h + ori_short_edge] 195 | kwargs["fsize_h"], kwargs["fsize_w"] = fsize, fsize 196 | res_h, res_w = video.shape[-2:] 197 | if res_h > res_w: 198 | kwargs["fragments_w"] = short_fragments 199 | kwargs["fragments_h"] = int(short_fragments * res_h / res_w) 200 | else: 201 | kwargs["fragments_h"] = short_fragments 202 | kwargs["fragments_w"] = int(short_fragments * res_w / res_h) 203 | return get_spatial_fragments(video, **kwargs) 204 | 205 | 206 | def get_cropped_video( 207 | video, size_h=224, size_w=224, **kwargs, 208 | ): 209 | kwargs["fragments_h"], kwargs["fragments_w"] = 1, 1 210 | kwargs["fsize_h"], kwargs["fsize_w"] = size_h, size_w 211 | return get_spatial_fragments(video, **kwargs) 212 | 213 | 214 | def get_single_view( 215 | video, sample_type="aesthetic", **kwargs, 216 | ): 217 | if sample_type.startswith("aesthetic"): 218 | video = get_resized_video(video, **kwargs) 219 | elif sample_type.startswith("technical"): 220 | video = get_spatial_fragments(video, **kwargs) 221 | elif sample_type.startswith("clip"): 222 | video = get_resized_video(video, **kwargs) 223 | elif sample_type.startswith("time"): 224 | video = get_resized_video(video, **kwargs) 225 | elif sample_type.startswith("other"): 226 | video = get_spatial_fragments(video, **kwargs) 227 | elif "flow" in sample_type: 228 | video = get_resized_video(video, **kwargs) 229 | elif sample_type == "original": 230 | return video 231 | 232 | return video 233 | 234 | 235 | def spatial_temporal_view_decomposition( 236 | video_path, sample_types, samplers, edit_video_path=None,is_train=False, augment=False, 237 | ): 238 | video = {} 239 | if video_path.endswith(".yuv"): 240 | print("This part will be deprecated due to large memory cost.") 241 | ## This is only an adaptation to LIVE-Qualcomm 242 | ovideo = skvideo.io.vread( 243 | video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} 244 | ) 245 | for stype in samplers: 246 | frame_inds = samplers[stype](ovideo.shape[0], is_train) 247 | imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds] 248 | video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2) 249 | del ovideo 250 | else: 251 | decord.bridge.set_bridge("torch") 252 | vreader = VideoReader(video_path) 253 | ### Avoid duplicated video decoding!!! Important!!!! 254 | all_frame_inds = [] 255 | frame_inds = {} 256 | for stype in samplers: 257 | frame_inds[stype] = samplers[stype](len(vreader), is_train) 258 | all_frame_inds.append(frame_inds[stype]) 259 | 260 | ### Each frame is only decoded one time!!! 
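        # Frame indices requested by all samplers are pooled and deduplicated
        # with np.unique, and each unique frame is decoded exactly once into
        # frame_dict; the per-sampler clips below are then assembled by
        # indexing into this cache, so frames shared between views (e.g. the
        # aesthetic and technical branches) are never decoded twice.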
261 | all_frame_inds = np.concatenate(all_frame_inds, 0) 262 | frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)} 263 | 264 | for stype in samplers: 265 | imgs = [frame_dict[idx] for idx in frame_inds[stype]] 266 | video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2) 267 | 268 | sampled_video = {} 269 | for stype, sopt in sample_types.items(): 270 | sampled_video[stype] = get_single_view(video[stype], stype, **sopt) 271 | return sampled_video, frame_inds 272 | 273 | 274 | 275 | 276 | 277 | import random 278 | 279 | import numpy as np 280 | 281 | 282 | class UnifiedFrameSampler: 283 | def __init__( 284 | self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0, 285 | ): 286 | 287 | self.fragments_t = fragments_t 288 | self.fsize_t = fsize_t 289 | self.size_t = fragments_t * fsize_t 290 | self.frame_interval = frame_interval 291 | self.num_clips = num_clips 292 | self.drop_rate = drop_rate 293 | 294 | def get_frame_indices(self, num_frames, train=False): 295 | 296 | tgrids = np.array( 297 | [num_frames // self.fragments_t * i for i in range(self.fragments_t)], 298 | dtype=np.int32, 299 | ) 300 | tlength = num_frames // self.fragments_t 301 | 302 | if tlength > self.fsize_t * self.frame_interval: 303 | rnd_t = np.random.randint( 304 | 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids) 305 | ) 306 | else: 307 | rnd_t = np.zeros(len(tgrids), dtype=np.int32) 308 | 309 | ranges_t = ( 310 | np.arange(self.fsize_t)[None, :] * self.frame_interval 311 | + rnd_t[:, None] 312 | + tgrids[:, None] 313 | ) 314 | 315 | drop = random.sample( 316 | list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate) 317 | ) 318 | dropped_ranges_t = [] 319 | for i, rt in enumerate(ranges_t): 320 | if i not in drop: 321 | dropped_ranges_t.append(rt) 322 | return np.concatenate(dropped_ranges_t) 323 | 324 | def __call__(self, total_frames, train=False, start_index=0): 325 | frame_inds = [] 326 | 327 | for i in range(self.num_clips): 328 | frame_inds += [self.get_frame_indices(total_frames)] 329 | 330 | frame_inds = np.concatenate(frame_inds) 331 | frame_inds = np.mod(frame_inds + start_index, total_frames) 332 | return frame_inds.astype(np.int32) 333 | 334 | 335 | class Processor(): 336 | def __init__(self, opt,from_src=False): 337 | ## opt is a dictionary that includes options for video sampling 338 | 339 | super().__init__() 340 | 341 | self.sample_types = opt["sample_types"] 342 | self.phase = opt["phase"] 343 | self.crop = opt.get("random_crop", False) 344 | self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) 345 | self.std = torch.FloatTensor([58.395, 57.12, 57.375]) 346 | self.samplers = {} 347 | for stype, sopt in opt["sample_types"].items(): 348 | if "t_frag" not in sopt: 349 | # resized temporal sampling for TQE in DOVER 350 | self.samplers[stype] = UnifiedFrameSampler( 351 | sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"] 352 | ) 353 | else: 354 | # temporal sampling for AQE in DOVER 355 | self.samplers[stype] = UnifiedFrameSampler( 356 | sopt["clip_len"] // sopt["t_frag"], 357 | sopt["t_frag"], 358 | sopt["frame_interval"], 359 | sopt["num_clips"], 360 | ) 361 | 362 | def preprocess(self, filename): 363 | #try: 364 | ## Read Original Frames 365 | ## Process Frames 366 | data, frame_inds = spatial_temporal_view_decomposition( 367 | filename, 368 | self.sample_types, 369 | self.samplers, 370 | self.phase == "test", 371 | (self.phase == "train"), 372 | ) 373 | 374 | for k, v in data.items(): 375 | data[k] = ((v.permute(1, 2, 3, 0) - 
self.mean) / self.std).permute( 376 | 3, 0, 1, 2 377 | ).unsqueeze(0).cuda() 378 | 379 | data["num_clips"] = {} 380 | for stype, sopt in self.sample_types.items(): 381 | data["num_clips"][stype] = sopt["num_clips"] 382 | data["frame_inds"] = frame_inds 383 | # except: 384 | # # exception flow 385 | # return {"name": filename} 386 | # edit_name是technical 387 | return data 388 | -------------------------------------------------------------------------------- /vebench/blip_models/blip_pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from .med import BertConfig, BertModel, BertLMHeadModel 9 | from transformers import BertTokenizer 10 | import transformers 11 | transformers.logging.set_verbosity_error() 12 | 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from models.blip import create_vit, init_tokenizer, load_checkpoint 18 | 19 | class BLIP_Pretrain(nn.Module): 20 | def __init__(self, 21 | med_config = 'configs/bert_config.json', 22 | image_size = 224, 23 | vit = 'base', 24 | vit_grad_ckpt = False, 25 | vit_ckpt_layer = 0, 26 | embed_dim = 256, 27 | queue_size = 57600, 28 | momentum = 0.995, 29 | ): 30 | """ 31 | Args: 32 | med_config (str): path for the mixture of encoder-decoder model's configuration file 33 | image_size (int): input image size 34 | vit (str): model size of vision transformer 35 | """ 36 | super().__init__() 37 | 38 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) 39 | 40 | if vit=='base': 41 | checkpoint = torch.hub.load_state_dict_from_url( 42 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 43 | map_location="cpu", check_hash=True) 44 | state_dict = checkpoint["model"] 45 | msg = self.visual_encoder.load_state_dict(state_dict,strict=False) 46 | elif vit=='large': 47 | from timm.models.helpers import load_custom_pretrained 48 | from timm.models.vision_transformer import default_cfgs 49 | load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k']) 50 | 51 | self.tokenizer = init_tokenizer() 52 | encoder_config = BertConfig.from_json_file(med_config) 53 | encoder_config.encoder_width = vision_width 54 | self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False) 55 | self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 56 | 57 | text_width = self.text_encoder.config.hidden_size 58 | 59 | self.vision_proj = nn.Linear(vision_width, embed_dim) 60 | self.text_proj = nn.Linear(text_width, embed_dim) 61 | 62 | self.itm_head = nn.Linear(text_width, 2) 63 | 64 | # create momentum encoders 65 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 66 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 67 | self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False) 68 | self.text_proj_m = nn.Linear(text_width, embed_dim) 69 | 70 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 71 | [self.vision_proj,self.vision_proj_m], 72 | [self.text_encoder,self.text_encoder_m], 73 | [self.text_proj,self.text_proj_m], 74 | ] 75 | self.copy_params() 76 | 77 | # create the queue 78 | self.register_buffer("image_queue", 
torch.randn(embed_dim, queue_size)) 79 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 80 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 81 | 82 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 83 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 84 | 85 | self.queue_size = queue_size 86 | self.momentum = momentum 87 | self.temp = nn.Parameter(0.07*torch.ones([])) 88 | 89 | # create the decoder 90 | decoder_config = BertConfig.from_json_file(med_config) 91 | decoder_config.encoder_width = vision_width 92 | self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config) 93 | self.text_decoder.resize_token_embeddings(len(self.tokenizer)) 94 | tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention') 95 | 96 | 97 | def forward(self, image, caption, alpha): 98 | with torch.no_grad(): 99 | self.temp.clamp_(0.001,0.5) 100 | 101 | image_embeds = self.visual_encoder(image) 102 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 103 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 104 | 105 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30, 106 | return_tensors="pt").to(image.device) 107 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 108 | return_dict = True, mode = 'text') 109 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 110 | 111 | # get momentum features 112 | with torch.no_grad(): 113 | self._momentum_update() 114 | image_embeds_m = self.visual_encoder_m(image) 115 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 116 | image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 117 | 118 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 119 | return_dict = True, mode = 'text') 120 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 121 | text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 122 | 123 | sim_i2t_m = image_feat_m @ text_feat_all / self.temp 124 | sim_t2i_m = text_feat_m @ image_feat_all / self.temp 125 | 126 | sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) 127 | sim_targets.fill_diagonal_(1) 128 | 129 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 130 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 131 | 132 | sim_i2t = image_feat @ text_feat_all / self.temp 133 | sim_t2i = text_feat @ image_feat_all / self.temp 134 | 135 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 136 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 137 | 138 | loss_ita = (loss_i2t+loss_t2i)/2 139 | 140 | self._dequeue_and_enqueue(image_feat_m, text_feat_m) 141 | 142 | ###============== Image-text Matching ===================### 143 | encoder_input_ids = text.input_ids.clone() 144 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 145 | 146 | # forward the positve image-text pair 147 | bs = image.size(0) 148 | output_pos = self.text_encoder(encoder_input_ids, 149 | attention_mask = text.attention_mask, 150 | encoder_hidden_states = image_embeds, 151 | encoder_attention_mask = image_atts, 152 | return_dict = True, 153 | ) 154 | with torch.no_grad(): 155 | weights_t2i 
= F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 156 | weights_t2i.fill_diagonal_(0) 157 | weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4 158 | weights_i2t.fill_diagonal_(0) 159 | 160 | # select a negative image for each text 161 | image_embeds_neg = [] 162 | for b in range(bs): 163 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 164 | image_embeds_neg.append(image_embeds[neg_idx]) 165 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 166 | 167 | # select a negative text for each image 168 | text_ids_neg = [] 169 | text_atts_neg = [] 170 | for b in range(bs): 171 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 172 | text_ids_neg.append(encoder_input_ids[neg_idx]) 173 | text_atts_neg.append(text.attention_mask[neg_idx]) 174 | 175 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 176 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 177 | 178 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 179 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 180 | 181 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 182 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 183 | 184 | output_neg = self.text_encoder(text_ids_all, 185 | attention_mask = text_atts_all, 186 | encoder_hidden_states = image_embeds_all, 187 | encoder_attention_mask = image_atts_all, 188 | return_dict = True, 189 | ) 190 | 191 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 192 | vl_output = self.itm_head(vl_embeddings) 193 | 194 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 195 | dim=0).to(image.device) 196 | loss_itm = F.cross_entropy(vl_output, itm_labels) 197 | 198 | ##================= LM ========================## 199 | decoder_input_ids = text.input_ids.clone() 200 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 201 | decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100) 202 | 203 | decoder_output = self.text_decoder(decoder_input_ids, 204 | attention_mask = text.attention_mask, 205 | encoder_hidden_states = image_embeds, 206 | encoder_attention_mask = image_atts, 207 | labels = decoder_targets, 208 | return_dict = True, 209 | ) 210 | 211 | loss_lm = decoder_output.loss 212 | return loss_ita, loss_itm, loss_lm 213 | 214 | 215 | 216 | @torch.no_grad() 217 | def copy_params(self): 218 | for model_pair in self.model_pairs: 219 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 220 | param_m.data.copy_(param.data) # initialize 221 | param_m.requires_grad = False # not update by gradient 222 | 223 | 224 | @torch.no_grad() 225 | def _momentum_update(self): 226 | for model_pair in self.model_pairs: 227 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 228 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 229 | 230 | 231 | @torch.no_grad() 232 | def _dequeue_and_enqueue(self, image_feat, text_feat): 233 | # gather keys before updating queue 234 | image_feats = concat_all_gather(image_feat) 235 | text_feats = concat_all_gather(text_feat) 236 | 237 | batch_size = image_feats.shape[0] 238 | 239 | ptr = int(self.queue_ptr) 240 | assert self.queue_size % batch_size == 0 # for simplicity 241 | 242 | # replace the keys at ptr (dequeue and enqueue) 243 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 244 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 245 | ptr = (ptr + batch_size) % self.queue_size # move pointer 246 | 247 | self.queue_ptr[0] = ptr 248 | 249 | 250 | def blip_pretrain(**kwargs): 251 | model = BLIP_Pretrain(**kwargs) 252 | return model 253 | 254 | 255 | @torch.no_grad() 256 | def concat_all_gather(tensor): 257 | """ 258 | Performs all_gather operation on the provided tensors. 259 | *** Warning ***: torch.distributed.all_gather has no gradient. 260 | """ 261 | tensors_gather = [torch.ones_like(tensor) 262 | for _ in range(torch.distributed.get_world_size())] 263 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 264 | 265 | output = torch.cat(tensors_gather, dim=0) 266 | return output 267 | 268 | 269 | from typing import List 270 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str): 271 | uninitialized_encoder_weights: List[str] = [] 272 | if decoder.__class__ != encoder.__class__: 273 | logger.info( 274 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 275 | ) 276 | 277 | def tie_encoder_to_decoder_recursively( 278 | decoder_pointer: nn.Module, 279 | encoder_pointer: nn.Module, 280 | module_name: str, 281 | uninitialized_encoder_weights: List[str], 282 | skip_key: str, 283 | depth=0, 284 | ): 285 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 286 | encoder_pointer, nn.Module 287 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 288 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 289 | assert hasattr(encoder_pointer, "weight") 290 | encoder_pointer.weight = decoder_pointer.weight 291 | if hasattr(decoder_pointer, "bias"): 292 | assert hasattr(encoder_pointer, "bias") 293 | encoder_pointer.bias = decoder_pointer.bias 294 | print(module_name+' is tied') 295 | return 296 | 297 | encoder_modules = encoder_pointer._modules 298 | decoder_modules = decoder_pointer._modules 299 | if len(decoder_modules) > 0: 300 | assert ( 301 | len(encoder_modules) > 0 302 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 303 | 304 | all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) 305 | encoder_layer_pos = 0 306 | for name, module in decoder_modules.items(): 307 | if name.isdigit(): 308 | encoder_name = str(int(name) + encoder_layer_pos) 309 | decoder_name = name 310 | if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( 311 | encoder_modules 312 | ) != len(decoder_modules): 313 | # this can happen if the name corresponds to the position in a list module list of layers 314 | # in this case the decoder has added a cross-attention that the encoder does not have 315 | # thus skip this step and subtract one layer pos from encoder 316 | encoder_layer_pos -= 1 317 | continue 318 | elif name not in 
encoder_modules: 319 | continue 320 | elif depth > 500: 321 | raise ValueError( 322 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 323 | ) 324 | else: 325 | decoder_name = encoder_name = name 326 | tie_encoder_to_decoder_recursively( 327 | decoder_modules[decoder_name], 328 | encoder_modules[encoder_name], 329 | module_name + "/" + name, 330 | uninitialized_encoder_weights, 331 | skip_key, 332 | depth=depth + 1, 333 | ) 334 | all_encoder_weights.remove(module_name + "/" + encoder_name) 335 | 336 | uninitialized_encoder_weights += list(all_encoder_weights) 337 | 338 | # tie weights recursively 339 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key) 340 | -------------------------------------------------------------------------------- /vebench/models/backbone/uniformer_backbone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from collections import OrderedDict 4 | 5 | from timm.models.layers import DropPath 6 | import torch 7 | from torch import nn 8 | from torch.nn import MultiheadAttention 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint as checkpoint 11 | 12 | # import logging as logging 13 | 14 | # logger = logging.get_logger(__name__) 15 | 16 | 17 | 18 | class LayerNorm(nn.LayerNorm): 19 | """Subclass torch's LayerNorm to handle fp16.""" 20 | 21 | def forward(self, x): 22 | orig_type = x.dtype 23 | ret = super().forward(x.type(torch.float32)) 24 | return ret.type(orig_type) 25 | 26 | 27 | class QuickGELU(nn.Module): 28 | def forward(self, x): 29 | return x * torch.sigmoid(1.702 * x) 30 | 31 | 32 | class Local_MHRA(nn.Module): 33 | def __init__(self, d_model, dw_reduction=1.5, pos_kernel_size=3): 34 | super().__init__() 35 | 36 | padding = pos_kernel_size // 2 37 | re_d_model = int(d_model // dw_reduction) 38 | self.pos_embed = nn.Sequential( 39 | nn.BatchNorm3d(d_model), 40 | nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0), 41 | nn.Conv3d(re_d_model, re_d_model, kernel_size=(pos_kernel_size, 1, 1), stride=(1, 1, 1), 42 | padding=(padding, 0, 0), groups=re_d_model), 43 | nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0), 44 | ) 45 | 46 | # init zero 47 | # logger.info('Init zero for Conv in pos_emb') 48 | nn.init.constant_(self.pos_embed[3].weight, 0) 49 | nn.init.constant_(self.pos_embed[3].bias, 0) 50 | 51 | def forward(self, x): 52 | return self.pos_embed(x) 53 | 54 | 55 | class ResidualAttentionBlock(nn.Module): 56 | def __init__( 57 | self, d_model, n_head, attn_mask=None, drop_path=0.0, 58 | dw_reduction=1.5, no_lmhra=False, double_lmhra=True 59 | ): 60 | super().__init__() 61 | 62 | self.n_head = n_head 63 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 64 | # logger.info(f'Drop path rate: {drop_path}') 65 | 66 | self.no_lmhra = no_lmhra 67 | self.double_lmhra = double_lmhra 68 | # logger.info(f'No L_MHRA: {no_lmhra}') 69 | # logger.info(f'Double L_MHRA: {double_lmhra}') 70 | if not no_lmhra: 71 | self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction) 72 | if double_lmhra: 73 | self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction) 74 | 75 | # spatial 76 | self.attn = MultiheadAttention(d_model, n_head) 77 | self.ln_1 = LayerNorm(d_model) 78 | self.mlp = nn.Sequential(OrderedDict([ 79 | ("c_fc", nn.Linear(d_model, d_model * 4)), 80 | ("gelu", QuickGELU()), 81 | ("c_proj", nn.Linear(d_model * 4, d_model)) 82 | ])) 83 | self.ln_2 = LayerNorm(d_model) 84 | self.attn_mask = attn_mask 85 | 86 | def attention(self, x): 87 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 88 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 89 | 90 | def forward(self, x, T=8, use_checkpoint=False): 91 | # x: 1+HW, NT, C 92 | if not self.no_lmhra: 93 | # Local MHRA 94 | tmp_x = x[1:, :, :] 95 | L, NT, C = tmp_x.shape 96 | N = NT // T 97 | H = W = int(L ** 0.5) 98 | tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous() 99 | tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x)) 100 | tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C) 101 | x = torch.cat([x[:1, :, :], tmp_x], dim=0) 102 | # MHSA 103 | if use_checkpoint: 104 | attn_out = checkpoint.checkpoint(self.attention, self.ln_1(x)) 105 | x = x + self.drop_path(attn_out) 106 | else: 107 | x = x + self.drop_path(self.attention(self.ln_1(x))) 108 | # Local MHRA 109 | if not self.no_lmhra and self.double_lmhra: 110 | tmp_x = x[1:, :, :] 111 | tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous() 112 | tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x)) 113 | tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C) 114 | x = torch.cat([x[:1, :, :], tmp_x], dim=0) 115 | # FFN 116 | if use_checkpoint: 117 | mlp_out = checkpoint.checkpoint(self.mlp, self.ln_2(x)) 118 | x = x + self.drop_path(mlp_out) 119 | else: 120 | x = x + self.drop_path(self.mlp(self.ln_2(x))) 121 | return x 122 | 123 | 124 | class Extractor(nn.Module): 125 | def __init__( 126 | self, d_model, n_head, attn_mask=None, 127 | mlp_factor=4.0, dropout=0.0, drop_path=0.0, 128 | ): 129 | super().__init__() 130 | 131 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 132 | # logger.info(f'Drop path rate: {drop_path}') 133 | self.attn = nn.MultiheadAttention(d_model, n_head) 134 | self.ln_1 = nn.LayerNorm(d_model) 135 | d_mlp = round(mlp_factor * d_model) 136 | self.mlp = nn.Sequential(OrderedDict([ 137 | ("c_fc", nn.Linear(d_model, d_mlp)), 138 | ("gelu", QuickGELU()), 139 | ("dropout", nn.Dropout(dropout)), 140 | ("c_proj", nn.Linear(d_mlp, d_model)) 141 | ])) 142 | self.ln_2 = nn.LayerNorm(d_model) 143 | self.ln_3 = nn.LayerNorm(d_model) 144 | self.attn_mask = attn_mask 145 | 146 | # zero init 147 | nn.init.xavier_uniform_(self.attn.in_proj_weight) 148 | nn.init.constant_(self.attn.out_proj.weight, 0.) 149 | nn.init.constant_(self.attn.out_proj.bias, 0.) 150 | nn.init.xavier_uniform_(self.mlp[0].weight) 151 | nn.init.constant_(self.mlp[-1].weight, 0.) 152 | nn.init.constant_(self.mlp[-1].bias, 0.) 
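        # With the attention output projection and the last MLP layer zero-initialised,
        # both residual branches in forward() start out contributing nothing, so a fresh
        # Extractor behaves as an identity mapping on its query tokens and can be attached
        # to a pretrained backbone without disturbing its features.
        #
        # Illustrative usage sketch (tensor layout is (sequence, batch, channel) as in the
        # rest of this file; the concrete sizes are assumptions, not values from the repo):
        #
        #   extractor = Extractor(d_model=768, n_head=12)
        #   cls_tok = torch.zeros(1, 2, 768)                # query token(s)
        #   patches = torch.randn(14 * 14 + 1, 2, 768)      # keys/values: patch tokens
        #   out = extractor(cls_tok, patches)               # (1, 2, 768); equals cls_tok at init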
153 | 154 | def attention(self, x, y): 155 | d_model = self.ln_1.weight.size(0) 156 | q = (x @ self.attn.in_proj_weight[:d_model].T) + self.attn.in_proj_bias[:d_model] 157 | 158 | k = (y @ self.attn.in_proj_weight[d_model:-d_model].T) + self.attn.in_proj_bias[d_model:-d_model] 159 | v = (y @ self.attn.in_proj_weight[-d_model:].T) + self.attn.in_proj_bias[-d_model:] 160 | Tx, Ty, N = q.size(0), k.size(0), q.size(1) 161 | q = q.view(Tx, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3) 162 | k = k.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3) 163 | v = v.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3) 164 | aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim ** 0.5)) 165 | 166 | aff = aff.softmax(dim=-1) 167 | out = aff @ v 168 | out = out.permute(2, 0, 1, 3).flatten(2) 169 | out = self.attn.out_proj(out) 170 | return out 171 | 172 | def forward(self, x, y): 173 | x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y))) 174 | x = x + self.drop_path(self.mlp(self.ln_2(x))) 175 | return x 176 | 177 | 178 | class Transformer(nn.Module): 179 | def __init__( 180 | self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0., 181 | use_checkpoint=False, checkpoint_num=[0], t_size=8, dw_reduction=2, 182 | no_lmhra=False, double_lmhra=True, 183 | return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 184 | n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0., 185 | mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], 186 | cls_dropout=0.5, num_classes=400, 187 | frozen=False, 188 | ): 189 | super().__init__() 190 | self.T = t_size 191 | self.return_list = return_list 192 | # backbone 193 | b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)] 194 | self.resblocks = nn.ModuleList([ 195 | ResidualAttentionBlock( 196 | width, heads, attn_mask, 197 | drop_path=b_dpr[i], 198 | dw_reduction=dw_reduction, 199 | no_lmhra=no_lmhra, 200 | double_lmhra=double_lmhra, 201 | ) for i in range(layers) 202 | ]) 203 | # checkpoint 204 | self.use_checkpoint = use_checkpoint 205 | self.checkpoint_num = checkpoint_num 206 | # logger.info(f'Use checkpoint: {self.use_checkpoint}') 207 | # logger.info(f'Checkpoint number: {self.checkpoint_num}') 208 | 209 | # global block 210 | assert n_layers == len(return_list) 211 | self.frozen = frozen 212 | # self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim)) 213 | self.dpe = nn.ModuleList([ 214 | nn.Conv3d(n_dim, n_dim, kernel_size=3, stride=1, padding=1, bias=True, groups=n_dim) 215 | for i in range(n_layers) 216 | ]) 217 | for m in self.dpe: 218 | nn.init.constant_(m.bias, 0.) 
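        # Each entry of `dpe` is a depthwise 3D convolution that forward() applies to the
        # re-spatialised patch tokens of the corresponding layer in `return_list`, acting
        # as a convolutional (dynamic) position encoding before the per-layer feature
        # volumes are accumulated.
        #
        # Rough shape sketch (sizes are assumptions): with n_dim=768, T=8 and a 14x14 patch
        # grid (L = 14*14 + 1), forward() takes x of shape (L, N*T, 768) and returns a fused
        # feature volume of shape (N, 768, 8, 14, 14); the accumulated features are averaged
        # with a hard-coded `/ 4`, which matches the four return layers used by
        # uniformerv2_b16 later in this file.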
219 | # dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] 220 | # self.dec = nn.ModuleList([ 221 | # Extractor( 222 | # n_dim, n_head, mlp_factor=mlp_factor, 223 | # dropout=mlp_dropout[i], drop_path=dpr[i], 224 | # ) for i in range(n_layers) 225 | # ]) 226 | # projection 227 | # self.proj = nn.Sequential( 228 | # nn.LayerNorm(n_dim), 229 | # nn.Dropout(cls_dropout), 230 | # nn.Linear(n_dim, num_classes), 231 | # ) 232 | # if not self.frozen: 233 | # self.balance = nn.Parameter(torch.zeros((n_dim))) 234 | # self.sigmoid = nn.Sigmoid() 235 | 236 | def forward(self, x): # 577 80 1024 237 | T_down = self.T 238 | L, NT, C = x.shape 239 | N = NT // T_down 240 | H = W = int((L - 1) ** 0.5) 241 | # cls_token = self.temporal_cls_token.repeat(1, N, 1) 242 | feature = torch.zeros(N, C, T_down, H, W).cuda() 243 | j = -1 244 | for i, resblock in enumerate(self.resblocks): 245 | if self.use_checkpoint and i < self.checkpoint_num[0]: 246 | x = resblock(x, self.T, use_checkpoint=True) 247 | else: 248 | x = resblock(x, T_down) 249 | if i in self.return_list: 250 | j += 1 251 | tmp_x = x.clone() 252 | tmp_x = tmp_x.view(L, N, T_down, C) 253 | # dpe 254 | _, tmp_feats = tmp_x[:1], tmp_x[1:] 255 | tmp_feats = tmp_feats.permute(1, 3, 2, 0).reshape(N, C, T_down, H, W) 256 | tmp_feats = self.dpe[j](tmp_feats.clone()).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous() 257 | tmp_x[1:] = tmp_x[1:] + tmp_feats 258 | # global block 259 | # tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C 260 | feature += tmp_x[1:].permute(1, 3, 2, 0).reshape(N, C, T_down, H, W) 261 | return feature / 4 262 | # cls_token = self.dec[j](cls_token, tmp_x) 263 | # 264 | # if self.frozen: 265 | # return self.proj(cls_token[0, :, :]) 266 | # else: 267 | # weight = self.sigmoid(self.balance) 268 | # residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C 269 | # 270 | # return torch.cat([self.proj((1 - weight) * cls_token[0, :, :] + weight * residual).reshape(N,710,1,1,1),torch.zeros(N,58,1,1,1).cuda()],dim=1) 271 | 272 | 273 | class VisionTransformer(nn.Module): 274 | def __init__( 275 | self, 276 | # backbone 277 | input_resolution, patch_size, width, layers, heads, output_dim, backbone_drop_path_rate=0., 278 | use_checkpoint=False, checkpoint_num=[0], t_size=8, kernel_size=3, dw_reduction=1.5, 279 | temporal_downsample=True, 280 | no_lmhra=-False, double_lmhra=True, 281 | # global block 282 | return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 283 | n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0., 284 | mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], 285 | cls_dropout=0.5, num_classes=400, 286 | frozen=False 287 | ): 288 | super().__init__() 289 | self.input_resolution = input_resolution 290 | self.output_dim = output_dim 291 | padding = (kernel_size - 1) // 2 292 | if temporal_downsample: 293 | self.conv1 = nn.Conv3d(3, width, (kernel_size, patch_size, patch_size), (2, patch_size, patch_size), 294 | (padding, 0, 0), bias=False) 295 | t_size = t_size // 2 296 | else: 297 | self.conv1 = nn.Conv3d(3, width, (1, patch_size, patch_size), (1, patch_size, patch_size), (0, 0, 0), 298 | bias=False) 299 | 300 | scale = width ** -0.5 301 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 302 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 303 | self.ln_pre = LayerNorm(width) 304 | 305 | self.transformer = Transformer( 306 | width, layers, heads, dw_reduction=dw_reduction, 307 | 
backbone_drop_path_rate=backbone_drop_path_rate, 308 | use_checkpoint=use_checkpoint, checkpoint_num=checkpoint_num, t_size=t_size, 309 | no_lmhra=no_lmhra, double_lmhra=double_lmhra, 310 | return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head, 311 | mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout, 312 | cls_dropout=cls_dropout, num_classes=num_classes, 313 | frozen=frozen, 314 | ) 315 | 316 | def forward(self, x): # (5,3,64,336,336) 317 | x = self.conv1(x) # shape = [*, width, grid, grid] 318 | N, C, T, H, W = x.shape 319 | x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C) 320 | 321 | x = torch.cat( 322 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 323 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 324 | x = x + self.positional_embedding.to(x.dtype) 325 | x = self.ln_pre(x) 326 | 327 | x = x.permute(1, 0, 2) # NLD -> LND #577,160,1024 328 | out = self.transformer(x) # 10,710 329 | return out 330 | 331 | 332 | def inflate_weight(weight_2d, time_dim, center=True): 333 | # logger.info(f'Init center: {center}') 334 | if center: 335 | weight_3d = torch.zeros(*weight_2d.shape) 336 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 337 | middle_idx = time_dim // 2 338 | weight_3d[:, :, middle_idx, :, :] = weight_2d 339 | else: 340 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 341 | weight_3d = weight_3d / time_dim 342 | return weight_3d 343 | 344 | 345 | def load_state_dict(model, state_dict): 346 | state_dict_3d = model.state_dict() 347 | new_state_dict = OrderedDict() 348 | for k in state_dict.keys(): 349 | if k[9:] not in state_dict_3d: 350 | continue 351 | if state_dict[k].shape != state_dict_3d[k[9:]].shape: 352 | if len(state_dict_3d[k[9:]].shape) <= 2: 353 | new_state_dict[k[9:]] = state_dict[k] 354 | # logger.info(f'Ignore: {k}') 355 | continue 356 | # logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}') 357 | time_dim = state_dict_3d[k[9:]].shape[2] 358 | new_state_dict[k[9:]] = inflate_weight(state_dict[k], time_dim) 359 | else: 360 | new_state_dict[k[9:]] = state_dict[k] 361 | model.load_state_dict(new_state_dict, strict=False) 362 | 363 | 364 | def uniformerv2_b16( 365 | pretrained=True, use_checkpoint=False, checkpoint_num=[0], 366 | t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0., 367 | temporal_downsample=True, 368 | no_lmhra=False, double_lmhra=True, 369 | return_list=[8, 9, 10, 11], 370 | n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0., 371 | mlp_dropout=[0.5, 0.5, 0.5, 0.5], 372 | cls_dropout=0.5, num_classes=400, 373 | frozen=False, 374 | ): 375 | model = VisionTransformer( 376 | input_resolution=224, 377 | patch_size=16, 378 | width=768, 379 | layers=12, 380 | heads=12, 381 | output_dim=512, 382 | use_checkpoint=use_checkpoint, 383 | checkpoint_num=checkpoint_num, 384 | t_size=t_size, 385 | dw_reduction=dw_reduction, 386 | backbone_drop_path_rate=backbone_drop_path_rate, 387 | temporal_downsample=temporal_downsample, 388 | no_lmhra=no_lmhra, 389 | double_lmhra=double_lmhra, 390 | return_list=return_list, 391 | n_layers=n_layers, 392 | n_dim=n_dim, 393 | n_head=n_head, 394 | mlp_factor=mlp_factor, 395 | drop_path_rate=drop_path_rate, 396 | mlp_dropout=mlp_dropout, 397 | cls_dropout=cls_dropout, 398 | num_classes=num_classes, 399 | frozen=frozen, 400 | ) 401 | 402 | if pretrained: 403 | # logger.info('load pretrained weights') 404 | state_dict = 
torch.load(os.path.join(pretrained, 'k400+k710_uniformerv2_b16_8x224.pth'), map_location='cpu') 405 | load_state_dict(model, state_dict) 406 | return model 407 | -------------------------------------------------------------------------------- /vebench/blip_models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | from timm.models.vision_transformer import Attention as TemporalAttention 21 | 22 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 23 | 24 | class BottleNeckAdapter(nn.Module): 25 | def __init__(self,in_feature,downsample_rate): 26 | super().__init__() 27 | hidden_state_1=in_feature//downsample_rate 28 | hidden_state_2=hidden_state_1//downsample_rate 29 | self.downsample_1=nn.Linear(in_feature,hidden_state_1) 30 | self.downsample_2=nn.Linear(hidden_state_1,hidden_state_2) 31 | #self.gelu=nn.functional.gelu() 32 | self.upsample_1 = nn.Linear(hidden_state_2, hidden_state_1) 33 | self.upsample_2=nn.Linear(hidden_state_1,in_feature) 34 | 35 | def forward(self,x): 36 | y=self.downsample_1(x) 37 | y = self.downsample_2(y) 38 | y=nn.functional.gelu(y) 39 | y=self.upsample_1(y) 40 | y = self.upsample_2(y) 41 | return x+y 42 | 43 | from einops import rearrange 44 | 45 | 46 | class TreeDConvAdapter(nn.Module): 47 | def __init__( 48 | self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., 49 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): 50 | super().__init__() 51 | self.norm1 = norm_layer(dim) 52 | if ws is None: 53 | self.attn = TemporalAttention(dim, num_heads,attn_drop=attn_drop,proj_drop=drop) 54 | # elif ws == 1: 55 | # self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) 56 | # else: 57 | # self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) 58 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 59 | self.norm2 = norm_layer(dim) 60 | mlp_hidden_dim = int(dim * mlp_ratio) 61 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 62 | self.temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1,groups=dim) 63 | self.apply(self._init_weights) 64 | 65 | def _init_weights(self, m): 66 | if hasattr(m,"weight")and m.weight is not None: 67 | trunc_normal_(m.weight, mean=0.0, std=0.01) 68 | if hasattr(m,"bias") and m.bias is not None: 69 | nn.init.constant_(m.bias, 0) 70 | 71 | def forward(self, x,B): 72 | # x: (B*T, h*w, C) 73 | origin_x=x 74 | x = x + self.drop_path(self.attn(self.norm1(x))) 75 | # spatial 76 | x = self.mlp(self.norm2(x)) 77 | # 78 | # temporal 79 | x = rearrange(x, '(b t) l c -> (b l) c t', b=B) 80 | x = self.temporal_conv(x) 81 | x = rearrange(x, '(b l) c t -> (b t) l c', b=B) 82 | # 83 | # # output 84 | x = origin_x + self.drop_path(x) 85 | return x 86 | 87 | class Mlp(nn.Module): 88 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 89 | """ 90 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 91 | super().__init__() 92 | out_features = out_features or in_features 93 | hidden_features = hidden_features or in_features 94 | self.fc1 = nn.Linear(in_features, hidden_features) 95 | self.act = act_layer() 96 | self.fc2 = nn.Linear(hidden_features, out_features) 97 | self.drop = nn.Dropout(drop) 98 | 99 | def forward(self, x): 100 | x = self.fc1(x) 101 | x = self.act(x) 102 | x = self.drop(x) 103 | x = self.fc2(x) 104 | x = self.drop(x) 105 | return x 106 | 107 | 108 | class Attention(nn.Module): 109 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 110 | super().__init__() 111 | self.num_heads = num_heads 112 | head_dim = dim // num_heads 113 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 114 | self.scale = qk_scale or head_dim ** -0.5 115 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 116 | self.attn_drop = nn.Dropout(attn_drop) 117 | self.proj = nn.Linear(dim, dim) 118 | self.proj_drop = nn.Dropout(proj_drop) 119 | self.attn_gradients = None 120 | self.attention_map = None 121 | 122 | def save_attn_gradients(self, attn_gradients): 123 | self.attn_gradients = attn_gradients 124 | 125 | def get_attn_gradients(self): 126 | return self.attn_gradients 127 | 128 | def save_attention_map(self, attention_map): 129 | self.attention_map = attention_map 130 | 131 | def get_attention_map(self): 132 | return self.attention_map 133 | 134 | def forward(self, x, register_hook=False): 135 | B, N, C = x.shape 136 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 137 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 138 | 139 | attn = (q @ k.transpose(-2, -1)) * self.scale 140 | attn = attn.softmax(dim=-1) 141 | attn = self.attn_drop(attn) 142 | 143 | if register_hook: 144 | self.save_attention_map(attn) 145 | attn.register_hook(self.save_attn_gradients) 146 | 147 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 148 | x = self.proj(x) 149 | x = self.proj_drop(x) 150 | return x 151 | 152 | 153 | class Block(nn.Module): 154 | 155 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 156 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False,depth=-1): 157 | 
super().__init__() 158 | self.norm1 = norm_layer(dim) 159 | self.attn = Attention( 160 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 161 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 162 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 163 | self.norm2 = norm_layer(dim) 164 | mlp_hidden_dim = int(dim * mlp_ratio) 165 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 166 | self.depth=depth 167 | # if self.depth in [17,19,21,22,23]: 168 | # self.adapter = TreeDConvAdapter(dim=dim,num_heads=8,mlp_ratio=0.5) 169 | 170 | if use_grad_checkpointing: 171 | self.attn = checkpoint_wrapper(self.attn) 172 | self.mlp = checkpoint_wrapper(self.mlp) 173 | 174 | def forward(self, x, number,B=8,register_hook=False): 175 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 176 | x = x + self.drop_path(self.mlp(self.norm2(x))) 177 | # if self.depth in [17,19,21,22,23]: 178 | # x=self.adapter(x,B) 179 | return x 180 | 181 | 182 | class VisionTransformer(nn.Module): 183 | """ Vision Transformer 184 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 185 | https://arxiv.org/abs/2010.11929 186 | """ 187 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 188 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 189 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 190 | use_grad_checkpointing=False, ckpt_layer=0): 191 | """ 192 | Args: 193 | img_size (int, tuple): input image size 194 | patch_size (int, tuple): patch size 195 | in_chans (int): number of input channels 196 | num_classes (int): number of classes for classification head 197 | embed_dim (int): embedding dimension 198 | depth (int): depth of transformer 199 | num_heads (int): number of attention heads 200 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 201 | qkv_bias (bool): enable bias for qkv if True 202 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 203 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 204 | drop_rate (float): dropout rate 205 | attn_drop_rate (float): attention dropout rate 206 | drop_path_rate (float): stochastic depth rate 207 | norm_layer: (nn.Module): normalization layer 208 | """ 209 | super().__init__() 210 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 211 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 212 | 213 | self.patch_embed = PatchEmbed( 214 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 215 | 216 | num_patches = self.patch_embed.num_patches 217 | 218 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 219 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 220 | self.pos_drop = nn.Dropout(p=drop_rate) 221 | 222 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 223 | self.blocks = nn.ModuleList([ 224 | Block( 225 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 226 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 227 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer), 228 | depth=i 
229 | ) 230 | for i in range(depth)]) 231 | self.norm = norm_layer(embed_dim) 232 | 233 | trunc_normal_(self.pos_embed, std=.02) 234 | trunc_normal_(self.cls_token, std=.02) 235 | self.apply(self._init_weights) 236 | 237 | def _init_weights(self, m): 238 | if isinstance(m, nn.Linear): 239 | trunc_normal_(m.weight, std=.02) 240 | if isinstance(m, nn.Linear) and m.bias is not None: 241 | nn.init.constant_(m.bias, 0) 242 | elif isinstance(m, nn.LayerNorm): 243 | nn.init.constant_(m.bias, 0) 244 | nn.init.constant_(m.weight, 1.0) 245 | 246 | @torch.jit.ignore 247 | def no_weight_decay(self): 248 | return {'pos_embed', 'cls_token'} 249 | 250 | def forward(self, video, register_blk=-1):#temporal 251 | B,C,L,W,H = video.shape 252 | temporal = [] 253 | video=self.patch_embed(video.reshape(-1,C,W,H)) 254 | cls_tokens = self.cls_token.expand(B*L, -1, -1) 255 | video=torch.cat((cls_tokens,video), dim=1) 256 | 257 | video = video + self.pos_embed[:, :video.size(1), :] 258 | x=self.pos_drop(video) 259 | #x=video.reshape(B,L,video.shape[-2],video.shape[-1]) 260 | for i,blk in enumerate(self.blocks): 261 | x = blk(x, i,B, register_blk==i) 262 | x = self.norm(x).reshape(B,L,x.shape[-2],x.shape[-1]) 263 | # video_mean=rearrange(x, '(b t) l c -> b t l c', b=B).mean(1) 264 | # frame=rearrange(x, '(b t) l c -> b t l c', b=B) 265 | # for i in range(video.shape[2]): 266 | # image = video[:, :, i, ...] 267 | # image_embeds = self.visual_encoder(image) 268 | # temporal.append(image_embeds.unsqueeze(1)) 269 | # temporal = torch.cat(temporal, dim=1) 270 | 271 | 272 | 273 | # x = self.patch_embed(x) 274 | # 275 | # cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 276 | # x = torch.cat((cls_tokens, x), dim=1) 277 | # 278 | # x = x + self.pos_embed[:,:x.size(1),:] 279 | # x = self.pos_drop(x) 280 | # 281 | # for i,blk in enumerate(self.blocks): 282 | # x = blk(x, register_blk==i) 283 | # x = self.norm(x) 284 | 285 | return x 286 | 287 | @torch.jit.ignore() 288 | def load_pretrained(self, checkpoint_path, prefix=''): 289 | _load_weights(self, checkpoint_path, prefix) 290 | 291 | 292 | @torch.no_grad() 293 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 294 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 295 | """ 296 | import numpy as np 297 | 298 | def _n2p(w, t=True): 299 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 300 | w = w.flatten() 301 | if t: 302 | if w.ndim == 4: 303 | w = w.transpose([3, 2, 0, 1]) 304 | elif w.ndim == 3: 305 | w = w.transpose([2, 0, 1]) 306 | elif w.ndim == 2: 307 | w = w.transpose([1, 0]) 308 | return torch.from_numpy(w) 309 | 310 | w = np.load(checkpoint_path) 311 | if not prefix and 'opt/target/embedding/kernel' in w: 312 | prefix = 'opt/target/' 313 | 314 | if hasattr(model.patch_embed, 'backbone'): 315 | # hybrid 316 | backbone = model.patch_embed.backbone 317 | stem_only = not hasattr(backbone, 'stem') 318 | stem = backbone if stem_only else backbone.stem 319 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 320 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 321 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 322 | if not stem_only: 323 | for i, stage in enumerate(backbone.stages): 324 | for j, block in enumerate(stage.blocks): 325 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 326 | for r in range(3): 327 | getattr(block, f'conv{r + 
1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 328 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 329 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 330 | if block.downsample is not None: 331 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 332 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 333 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 334 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 335 | else: 336 | embed_conv_w = adapt_input_conv( 337 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 338 | model.patch_embed.proj.weight.copy_(embed_conv_w) 339 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 340 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 341 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 342 | if pos_embed_w.shape != model.pos_embed.shape: 343 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 344 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 345 | model.pos_embed.copy_(pos_embed_w) 346 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 347 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 348 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 349 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 350 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 351 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 352 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 353 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 354 | for i, block in enumerate(model.blocks.children()): 355 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 356 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 357 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 358 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 359 | block.attn.qkv.weight.copy_(torch.cat([ 360 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 361 | block.attn.qkv.bias.copy_(torch.cat([ 362 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 363 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 364 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 365 | for r in range(2): 366 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 367 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 368 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 369 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 370 | 371 | 372 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 373 | # interpolate position embedding 374 | embedding_size = pos_embed_checkpoint.shape[-1] 375 | num_patches = visual_encoder.patch_embed.num_patches 376 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 377 | # height (== width) for the checkpoint position embedding 378 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 379 | # height (== width) for 
the new position embedding 380 | new_size = int(num_patches ** 0.5) 381 | 382 | if orig_size!=new_size: 383 | # class_token and dist_token are kept unchanged 384 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 385 | # only the position tokens are interpolated 386 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 387 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 388 | pos_tokens = torch.nn.functional.interpolate( 389 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 390 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 391 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 392 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 393 | 394 | return new_pos_embed 395 | else: 396 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /vebench/models/backbone/conv_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import trunc_normal_, DropPath 5 | import os 6 | 7 | 8 | class GRN(nn.Module): 9 | """ GRN (Global Response Normalization) layer 10 | """ 11 | def __init__(self, dim): 12 | super().__init__() 13 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 14 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 15 | 16 | def forward(self, x): 17 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 18 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 19 | return self.gamma * (x * Nx) + self.beta + x 20 | 21 | class Block(nn.Module): 22 | r""" ConvNeXt Block. There are two equivalent implementations: 23 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 24 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 25 | We use (2) as we find it slightly faster in PyTorch 26 | 27 | Args: 28 | dim (int): Number of input channels. 29 | drop_path (float): Stochastic depth rate. Default: 0.0 30 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 31 | """ 32 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 33 | super().__init__() 34 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 35 | self.norm = LayerNorm(dim, eps=1e-6) 36 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 37 | self.act = nn.GELU() 38 | self.pwconv2 = nn.Linear(4 * dim, dim) 39 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 40 | requires_grad=True) if layer_scale_init_value > 0 else None 41 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 42 | 43 | def forward(self, x): 44 | input = x 45 | x = self.dwconv(x) 46 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 47 | x = self.norm(x) 48 | x = self.pwconv1(x) 49 | x = self.act(x) 50 | x = self.pwconv2(x) 51 | if self.gamma is not None: 52 | x = self.gamma * x 53 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 54 | 55 | x = input + self.drop_path(x) 56 | return x 57 | 58 | class ConvNeXt(nn.Module): 59 | r""" ConvNeXt 60 | A PyTorch impl of : `A ConvNet for the 2020s` - 61 | https://arxiv.org/pdf/2201.03545.pdf 62 | Args: 63 | in_chans (int): Number of input image channels. 
Default: 3 64 | num_classes (int): Number of classes for classification head. Default: 1000 65 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 66 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 67 | drop_path_rate (float): Stochastic depth rate. Default: 0. 68 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 69 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 70 | """ 71 | def __init__(self, in_chans=3, num_classes=1000, 72 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 73 | layer_scale_init_value=1e-6, head_init_scale=1., 74 | ): 75 | super().__init__() 76 | 77 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 78 | stem = nn.Sequential( 79 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 80 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 81 | ) 82 | self.downsample_layers.append(stem) 83 | for i in range(3): 84 | downsample_layer = nn.Sequential( 85 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 86 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 87 | ) 88 | self.downsample_layers.append(downsample_layer) 89 | 90 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 91 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 92 | cur = 0 93 | for i in range(4): 94 | stage = nn.Sequential( 95 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 96 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 97 | ) 98 | self.stages.append(stage) 99 | cur += depths[i] 100 | 101 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 102 | self.head = nn.Linear(dims[-1], num_classes) 103 | 104 | self.apply(self._init_weights) 105 | self.head.weight.data.mul_(head_init_scale) 106 | self.head.bias.data.mul_(head_init_scale) 107 | 108 | def _init_weights(self, m): 109 | if isinstance(m, (nn.Conv2d, nn.Linear)): 110 | trunc_normal_(m.weight, std=.02) 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def forward_features(self, x): 114 | for i in range(4): 115 | x = self.downsample_layers[i](x) 116 | x = self.stages[i](x) 117 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 118 | 119 | def forward(self, x): 120 | x = self.forward_features(x) 121 | x = self.head(x) 122 | return x 123 | 124 | class LayerNorm(nn.Module): 125 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 126 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 127 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 128 | with shape (batch_size, channels, height, width). 
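    Example (illustrative only; shapes are arbitrary):
        ln = LayerNorm(64, data_format="channels_first")
        y = ln(torch.randn(2, 64, 56, 56))      # 4-D (N, C, H, W) input
        z = ln(torch.randn(2, 64, 8, 56, 56))   # 5-D (N, C, T, H, W) input from the 3-D stems below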
129 | """ 130 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 131 | super().__init__() 132 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 133 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 134 | self.eps = eps 135 | self.data_format = data_format 136 | if self.data_format not in ["channels_last", "channels_first"]: 137 | raise NotImplementedError 138 | self.normalized_shape = (normalized_shape, ) 139 | 140 | def forward(self, x): 141 | if self.data_format == "channels_last": 142 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 143 | elif self.data_format == "channels_first": 144 | u = x.mean(1, keepdim=True) 145 | s = (x - u).pow(2).mean(1, keepdim=True) 146 | x = (x - u) / torch.sqrt(s + self.eps) 147 | if len(x.shape) == 4: 148 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 149 | elif len(x.shape) == 5: 150 | x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] 151 | return x 152 | 153 | 154 | class Block3D(nn.Module): 155 | r""" ConvNeXt Block. There are two equivalent implementations: 156 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 157 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 158 | We use (2) as we find it slightly faster in PyTorch 159 | 160 | Args: 161 | dim (int): Number of input channels. 162 | drop_path (float): Stochastic depth rate. Default: 0.0 163 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 164 | """ 165 | def __init__(self, dim, drop_path=0., inflate_len=3, layer_scale_init_value=1e-6): 166 | super().__init__() 167 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv 168 | self.norm = LayerNorm(dim, eps=1e-6) 169 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 170 | self.act = nn.GELU() 171 | self.pwconv2 = nn.Linear(4 * dim, dim) 172 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 173 | requires_grad=True) if layer_scale_init_value > 0 else None 174 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 175 | 176 | def forward(self, x): 177 | input = x 178 | x = self.dwconv(x) 179 | x = x.permute(0, 2, 3, 4, 1) # (N, C, H, W) -> (N, H, W, C) 180 | x = self.norm(x) 181 | x = self.pwconv1(x) 182 | x = self.act(x) 183 | x = self.pwconv2(x) 184 | if self.gamma is not None: 185 | x = self.gamma * x 186 | x = x.permute(0, 4, 1, 2, 3) # (N, H, W, C) -> (N, C, H, W) 187 | 188 | x = input + self.drop_path(x) 189 | return x 190 | 191 | class BlockV2(nn.Module): 192 | """ ConvNeXtV2 Block. 193 | 194 | Args: 195 | dim (int): Number of input channels. 196 | drop_path (float): Stochastic depth rate. Default: 0.0 197 | """ 198 | def __init__(self, dim, drop_path=0.): 199 | super().__init__() 200 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 201 | self.norm = LayerNorm(dim, eps=1e-6) 202 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 203 | self.act = nn.GELU() 204 | self.grn = GRN(4 * dim) 205 | self.pwconv2 = nn.Linear(4 * dim, dim) 206 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 207 | 208 | def forward(self, x): 209 | input = x 210 | x = self.dwconv(x) 211 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 212 | x = self.norm(x) 213 | x = self.pwconv1(x) 214 | x = self.act(x) 215 | x = self.grn(x) 216 | x = self.pwconv2(x) 217 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 218 | 219 | x = input + self.drop_path(x) 220 | return x 221 | 222 | class BlockV23D(nn.Module): 223 | """ ConvNeXtV2 Block. 224 | 225 | Args: 226 | dim (int): Number of input channels. 227 | drop_path (float): Stochastic depth rate. Default: 0.0 228 | """ 229 | def __init__(self, dim, drop_path=0., inflate_len=3,): 230 | super().__init__() 231 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv 232 | self.norm = LayerNorm(dim, eps=1e-6) 233 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 234 | self.act = nn.GELU() 235 | self.grn = GRN(4 * dim) 236 | self.pwconv2 = nn.Linear(4 * dim, dim) 237 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 238 | 239 | def forward(self, x): 240 | input = x 241 | x = self.dwconv(x) 242 | x = x.permute(0, 2, 3, 4, 1) # (N, C, H, W) -> (N, H, W, C) 243 | x = self.norm(x) 244 | x = self.pwconv1(x) 245 | x = self.act(x) 246 | x = self.grn(x) 247 | x = self.pwconv2(x) 248 | x = x.permute(0, 4, 1, 2, 3) # (N, H, W, C) -> (N, C, H, W) 249 | 250 | x = input + self.drop_path(x) 251 | return x 252 | 253 | class ConvNeXtV2(nn.Module): 254 | """ ConvNeXt V2 255 | 256 | Args: 257 | in_chans (int): Number of input image channels. Default: 3 258 | num_classes (int): Number of classes for classification head. Default: 1000 259 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 260 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 261 | drop_path_rate (float): Stochastic depth rate. Default: 0. 262 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 263 | """ 264 | def __init__(self, in_chans=3, num_classes=1000, 265 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 266 | drop_path_rate=0., head_init_scale=1. 
267 | ): 268 | super().__init__() 269 | self.depths = depths 270 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 271 | stem = nn.Sequential( 272 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 273 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 274 | ) 275 | self.downsample_layers.append(stem) 276 | for i in range(3): 277 | downsample_layer = nn.Sequential( 278 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 279 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 280 | ) 281 | self.downsample_layers.append(downsample_layer) 282 | 283 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 284 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 285 | cur = 0 286 | for i in range(4): 287 | stage = nn.Sequential( 288 | *[BlockV2(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 289 | ) 290 | self.stages.append(stage) 291 | cur += depths[i] 292 | 293 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 294 | self.head = nn.Linear(dims[-1], num_classes) 295 | 296 | self.apply(self._init_weights) 297 | self.head.weight.data.mul_(head_init_scale) 298 | self.head.bias.data.mul_(head_init_scale) 299 | 300 | def _init_weights(self, m): 301 | if isinstance(m, (nn.Conv2d, nn.Linear)): 302 | trunc_normal_(m.weight, std=.02) 303 | nn.init.constant_(m.bias, 0) 304 | 305 | def forward_features(self, x): 306 | for i in range(4): 307 | x = self.downsample_layers[i](x) 308 | x = self.stages[i](x) 309 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 310 | 311 | def forward(self, x): 312 | x = self.forward_features(x) 313 | x = self.head(x) 314 | return x 315 | 316 | def convnextv2_atto(**kwargs): 317 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 318 | return model 319 | 320 | def convnextv2_femto(**kwargs): 321 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 322 | return model 323 | 324 | def convnext_pico(**kwargs): 325 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 326 | return model 327 | 328 | def convnextv2_nano(**kwargs): 329 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) 330 | return model 331 | 332 | def convnextv2_tiny(**kwargs): 333 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 334 | return model 335 | 336 | def convnextv2_base(**kwargs): 337 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 338 | return model 339 | 340 | def convnextv2_large(**kwargs): 341 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 342 | return model 343 | 344 | def convnextv2_huge(**kwargs): 345 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) 346 | return model 347 | 348 | class ConvNeXt3D(nn.Module): 349 | r""" ConvNeXt 350 | A PyTorch impl of : `A ConvNet for the 2020s` - 351 | https://arxiv.org/pdf/2201.03545.pdf 352 | Args: 353 | in_chans (int): Number of input image channels. Default: 3 354 | num_classes (int): Number of classes for classification head. Default: 1000 355 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 356 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 357 | drop_path_rate (float): Stochastic depth rate. Default: 0. 
358 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 359 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 360 | """ 361 | def __init__(self, in_chans=3, num_classes=1000, 362 | inflate_strategy='131', 363 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 364 | layer_scale_init_value=1e-6, head_init_scale=1., 365 | ): 366 | super().__init__() 367 | 368 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 369 | stem = nn.Sequential( 370 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)), 371 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 372 | ) 373 | self.downsample_layers.append(stem) 374 | for i in range(3): 375 | downsample_layer = nn.Sequential( 376 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 377 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)), 378 | ) 379 | self.downsample_layers.append(downsample_layer) 380 | 381 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 382 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 383 | cur = 0 384 | for i in range(4): 385 | stage = nn.Sequential( 386 | *[Block3D(dim=dims[i], inflate_len=int(inflate_strategy[j%len(inflate_strategy)]), 387 | drop_path=dp_rates[cur + j], 388 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 389 | ) 390 | self.stages.append(stage) 391 | cur += depths[i] 392 | 393 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 394 | 395 | self.apply(self._init_weights) 396 | 397 | def inflate_weights(self, s_state_dict): 398 | t_state_dict = self.state_dict() 399 | from collections import OrderedDict 400 | for key in t_state_dict.keys(): 401 | if key not in s_state_dict: 402 | # print(key) 403 | continue 404 | if t_state_dict[key].shape != s_state_dict[key].shape: 405 | t = t_state_dict[key].shape[2] 406 | s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1,1,t,1,1) / t 407 | self.load_state_dict(s_state_dict, strict=False) 408 | 409 | def _init_weights(self, m): 410 | if isinstance(m, (nn.Conv3d, nn.Linear)): 411 | trunc_normal_(m.weight, std=.02) 412 | nn.init.constant_(m.bias, 0) 413 | 414 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1): 415 | if multi: 416 | xs = [] 417 | for i in range(4): 418 | x = self.downsample_layers[i](x) 419 | x = self.stages[i](x) 420 | if multi: 421 | xs.append(x) 422 | if return_spatial: 423 | if multi: 424 | shape = xs[-1].shape[2:] 425 | return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1) 426 | elif layer > -1: 427 | return xs[layer] 428 | else: 429 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) 430 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) 431 | 432 | def forward(self, x, multi=False, layer=-1): 433 | x = self.forward_features(x, True, multi=multi, layer=layer) 434 | return x 435 | 436 | 437 | class ConvNeXtV23D(nn.Module): 438 | """ ConvNeXt V2 439 | 440 | Args: 441 | in_chans (int): Number of input image channels. Default: 3 442 | num_classes (int): Number of classes for classification head. Default: 1000 443 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 444 | dims (int): Feature dimension at each stage. 
Default: [96, 192, 384, 768] 445 | drop_path_rate (float): Stochastic depth rate. Default: 0. 446 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 447 | """ 448 | def __init__(self, in_chans=3, num_classes=1000, 449 | inflate_strategy='131', 450 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 451 | drop_path_rate=0., head_init_scale=1. 452 | ): 453 | super().__init__() 454 | self.depths = depths 455 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 456 | stem = nn.Sequential( 457 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)), 458 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 459 | ) 460 | self.downsample_layers.append(stem) 461 | for i in range(3): 462 | downsample_layer = nn.Sequential( 463 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 464 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)), 465 | ) 466 | self.downsample_layers.append(downsample_layer) 467 | 468 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 469 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 470 | cur = 0 471 | for i in range(4): 472 | stage = nn.Sequential( 473 | *[BlockV23D(dim=dims[i], drop_path=dp_rates[cur + j], 474 | inflate_len=int(inflate_strategy[j%len(inflate_strategy)]), 475 | ) for j in range(depths[i])] 476 | ) 477 | self.stages.append(stage) 478 | cur += depths[i] 479 | 480 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 481 | self.head = nn.Linear(dims[-1], num_classes) 482 | 483 | self.apply(self._init_weights) 484 | self.head.weight.data.mul_(head_init_scale) 485 | self.head.bias.data.mul_(head_init_scale) 486 | 487 | def inflate_weights(self, pretrained_path): 488 | t_state_dict = self.state_dict() 489 | s_state_dict = torch.load(pretrained_path)["model"] 490 | from collections import OrderedDict 491 | for key in t_state_dict.keys(): 492 | if key not in s_state_dict: 493 | # print(key) 494 | continue 495 | if t_state_dict[key].shape != s_state_dict[key].shape: 496 | # print(t_state_dict[key].shape, s_state_dict[key].shape) 497 | t = t_state_dict[key].shape[2] 498 | s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1,1,t,1,1) / t 499 | self.load_state_dict(s_state_dict, strict=False) 500 | 501 | def _init_weights(self, m): 502 | if isinstance(m, (nn.Conv3d, nn.Linear)): 503 | trunc_normal_(m.weight, std=.02) 504 | nn.init.constant_(m.bias, 0) 505 | 506 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1): 507 | if multi: 508 | xs = [] 509 | for i in range(4): 510 | x = self.downsample_layers[i](x) 511 | x = self.stages[i](x) 512 | if multi: 513 | xs.append(x) 514 | if return_spatial: 515 | if multi: 516 | shape = xs[-1].shape[2:] 517 | return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1) 518 | elif layer > -1: 519 | return xs[layer] 520 | else: 521 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) 522 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) 523 | 524 | def forward(self, x, multi=False, layer=-1): 525 | x = self.forward_features(x, True, multi=multi, layer=layer) 526 | return x 527 | 528 | 529 | 530 | def convnext_3d_tiny(pretrained, in_22k=False, **kwargs): 531 | # print("Using Imagenet 22K pretrain", in_22k) 532 | model = 
ConvNeXt3D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 533 | checkpoint = torch.load(os.path.join(pretrained, 'convnext_tiny_1k_224_ema.pth'), map_location="cpu") 534 | model.inflate_weights(checkpoint["model"]) 535 | 536 | return model 537 | 538 | --------------------------------------------------------------------------------
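A minimal usage sketch of the 3D backbone defined above in vebench/models/backbone/conv_backbone.py, for illustration only (not part of the repository). It assumes the package and its requirements are installed; the checkpoint directory in the last line is hypothetical, and convnext_3d_tiny expects it to contain convnext_tiny_1k_224_ema.pth, as in the loader above.

import torch
from vebench.models.backbone.conv_backbone import ConvNeXt3D, convnext_3d_tiny

# Tiny configuration, randomly initialized; inputs are clips shaped (N, C, T, H, W).
model = ConvNeXt3D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
clip = torch.randn(1, 3, 16, 224, 224)

with torch.no_grad():
    feats = model(clip)  # forward() returns spatial feature maps, not pooled logits

# The stem downsamples T by 2 and H, W by 4; each later stage halves H and W only.
print(feats.shape)  # torch.Size([1, 768, 8, 7, 7])

# To inflate ImageNet-pretrained 2D weights instead (hypothetical local directory):
# model = convnext_3d_tiny(pretrained='ckpts')  # looks for ckpts/convnext_tiny_1k_224_ema.pth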