├── vebench
│   ├── blip_models
│   │   ├── __init__.py
│   │   ├── blip_itm.py
│   │   ├── blip_nlvr.py
│   │   ├── blip_vqa.py
│   │   ├── blip.py
│   │   ├── blip_retrieval.py
│   │   ├── blip_pretrain.py
│   │   └── vit.py
│   ├── models
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── BLIP_configs
│   │   │   │   ├── retrieval_msrvtt.yaml
│   │   │   │   ├── nocaps.yaml
│   │   │   │   ├── nlvr.yaml
│   │   │   │   ├── med_config.json
│   │   │   │   ├── bert_config.json
│   │   │   │   ├── pretrain.yaml
│   │   │   │   ├── vqa.yaml
│   │   │   │   ├── caption_coco.yaml
│   │   │   │   ├── retrieval_coco.yaml
│   │   │   │   └── retrieval_flickr.yaml
│   │   │   ├── blip.py
│   │   │   ├── uniformer_backbone.py
│   │   │   └── conv_backbone.py
│   │   ├── __init__.py
│   │   ├── network.py
│   │   ├── tradition.py
│   │   ├── text_alignment.py
│   │   ├── fidelity.py
│   │   └── head.py
│   ├── __init__.py
│   ├── README.md
│   ├── configs
│   │   ├── text.yaml
│   │   ├── doublestream.yaml
│   │   └── dover.yaml
│   ├── infer.py
│   ├── evaluator.py
│   └── preprocess.py
├── MANIFEST.in
├── assets
│   ├── .DS_Store
│   ├── dst.mp4
│   ├── dst2.mp4
│   ├── src.mp4
│   ├── scores.jpg
│   └── overview.jpg
├── requirements.txt
├── .gitignore
├── test.py
├── setup.py
├── LICENSE
└── README.md
/vebench/blip_models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vebench/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vebench/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluator import VEBenchModel
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include vebench/configs/*.yaml
2 | include vebench/models/backbone/BLIP_configs/*
3 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/dst.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/dst.mp4
--------------------------------------------------------------------------------
/assets/dst2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/dst2.mp4
--------------------------------------------------------------------------------
/assets/src.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/src.mp4
--------------------------------------------------------------------------------
/assets/scores.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/scores.jpg
--------------------------------------------------------------------------------
/assets/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlespray/VE-Bench/HEAD/assets/overview.jpg
--------------------------------------------------------------------------------
/vebench/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .network import EvalEditModel
2 |
3 | __all__=['EvalEditModel']
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | decord
2 | einops
3 | fairscale
4 | numpy
5 | timm
6 | transformers
7 | scikit-video
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpts
2 | __pycache__
3 | data
4 | */__pycache__
5 | */*/__pycache__
6 | */*/*/__pycache__
7 | build
8 | dist
9 | *egg-info
10 | .DS_Store
11 | */.DS_Store
12 | */*/.DS_Store
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from vebench import VEBenchModel
2 |
3 | evaluator = VEBenchModel()
4 |
5 | score1 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst.mp4')
6 | score2 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst2.mp4')
7 | print(score1, score2)
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/retrieval_msrvtt.yaml:
--------------------------------------------------------------------------------
1 | video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
2 | ann_root: 'annotation'
3 |
4 | # set pretrained as a file path or an url
5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
6 |
7 | # size of vit model; base or large
8 | vit: 'base'
9 | batch_size: 64
10 | k_test: 128
11 | image_size: 384
12 | num_frm_test: 8
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/nocaps.yaml:
--------------------------------------------------------------------------------
1 | image_root: '/export/share/datasets/vision/nocaps/'
2 | ann_root: 'annotation'
3 |
4 | # set pretrained as a file path or an url
5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
6 |
7 | vit: 'base'
8 | batch_size: 32
9 |
10 | image_size: 384
11 |
12 | max_length: 20
13 | min_length: 5
14 | num_beams: 3
15 | prompt: 'a picture of '
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/nlvr.yaml:
--------------------------------------------------------------------------------
1 | image_root: '/export/share/datasets/vision/NLVR2/'
2 | ann_root: 'annotation'
3 |
4 | # set pretrained as a file path or an url
5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
6 |
7 | #size of vit model; base or large
8 | vit: 'base'
9 | batch_size_train: 16
10 | batch_size_test: 64
11 | vit_grad_ckpt: False
12 | vit_ckpt_layer: 0
13 | max_epoch: 15
14 |
15 | image_size: 384
16 |
17 | # optimizer
18 | weight_decay: 0.05
19 | init_lr: 3e-5
20 | min_lr: 0
21 |
22 |
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/med_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30524,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/pretrain.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
2 | '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
3 | ]
4 | laion_path: ''
5 |
6 | # size of vit model; base or large
7 | vit: 'base'
8 | vit_grad_ckpt: False
9 | vit_ckpt_layer: 0
10 |
11 | image_size: 224
12 | batch_size: 75
13 |
14 | queue_size: 57600
15 | alpha: 0.4
16 |
17 | # optimizer
18 | weight_decay: 0.05
19 | init_lr: 3e-4
20 | min_lr: 1e-6
21 | warmup_lr: 1e-6
22 | lr_decay_rate: 0.9
23 | max_epoch: 20
24 | warmup_steps: 3000
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/vebench/README.md:
--------------------------------------------------------------------------------
1 | ## Easy Use
2 | VE-Bench can be installed with a single ``pip`` command:
3 | ```
4 | pip install vebench
5 | ```
6 | Since the model employs normalization during training, its output does not represent absolute scores. We **recommend performing comparisons between video pairs**, as demonstrated below:
7 | ```
8 | from vebench import VEBenchModel
9 |
10 | evaluator = VEBenchModel()
11 |
12 | score1 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst.mp4')
13 | score2 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst2.mp4')
14 | print(score1, score2) # Score1: 1.3563, Score2: 0.66194
15 | ```
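As a usage note, here is a minimal sketch of ranking several edited candidates of the same source video by their VE-Bench score; the candidate list is illustrative and reuses the assets shipped with the repository:
```
from vebench import VEBenchModel

evaluator = VEBenchModel()

prompt = 'A black-haired boy is turning his head'
candidates = ['assets/dst.mp4', 'assets/dst2.mp4']  # edited versions of assets/src.mp4

# A higher score means the edit is judged closer to human preference for this prompt.
scores = [evaluator.evaluate(prompt, 'assets/src.mp4', c) for c in candidates]
best_idx = max(range(len(candidates)), key=lambda i: scores[i])
print('best candidate:', candidates[best_idx], 'scores:', scores)
```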
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/vqa.yaml:
--------------------------------------------------------------------------------
1 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
2 | vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
3 | train_files: ['vqa_train','vqa_val','vg_qa']
4 | ann_root: 'annotation'
5 |
6 | # set pretrained as a file path or an url
7 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
8 |
9 | # size of vit model; base or large
10 | vit: 'base'
11 | batch_size_train: 16
12 | batch_size_test: 32
13 | vit_grad_ckpt: False
14 | vit_ckpt_layer: 0
15 | init_lr: 2e-5
16 |
17 | image_size: 480
18 |
19 | k_test: 128
20 | inference: 'rank'
21 |
22 | # optimizer
23 | weight_decay: 0.05
24 | min_lr: 0
25 | max_epoch: 10
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | name = 'vebench'
4 |
5 | long_description='Please refer to https://github.com/littlespray/VE-Bench'
6 |
7 | setup(
8 |     name=name,  # the package name matches the project name, so imports map directly onto it
9 | version='1.0.0',
10 | author="Shangkun Sun",
11 |     license="MIT License",
12 | author_email='sunshk@stu.pku.edu.cn',
13 | description="Evaluator for Text-driven Video Editing",
14 | packages=find_packages(),
15 | python_requires='>=3',
16 | long_description=long_description,
17 |     # dependencies
18 |     install_requires=['torch', 'decord', 'einops', 'fairscale', 'numpy', 'timm', 'transformers', 'sk-video'],
19 |     include_package_data=True,  # include additional non-Python files
20 |     package_data={
21 |         '': ['configs/*.yaml', 'models/backbone/BLIP_configs/*'],  # match all files under these directories
22 | },
23 | )
24 |
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/caption_coco.yaml:
--------------------------------------------------------------------------------
1 | image_root: '/export/share/datasets/vision/coco/images/'
2 | ann_root: 'annotation'
3 | coco_gt_root: 'annotation/coco_gt'
4 |
5 | # set pretrained as a file path or an url
6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
7 |
8 | # size of vit model; base or large
9 | vit: 'base'
10 | vit_grad_ckpt: False
11 | vit_ckpt_layer: 0
12 | batch_size: 32
13 | init_lr: 1e-5
14 |
15 | # vit: 'large'
16 | # vit_grad_ckpt: True
17 | # vit_ckpt_layer: 5
18 | # batch_size: 16
19 | # init_lr: 2e-6
20 |
21 | image_size: 384
22 |
23 | # generation configs
24 | max_length: 20
25 | min_length: 5
26 | num_beams: 3
27 | prompt: 'a picture of '
28 |
29 | # optimizer
30 | weight_decay: 0.05
31 | min_lr: 0
32 | max_epoch: 5
33 |
34 |
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/retrieval_coco.yaml:
--------------------------------------------------------------------------------
1 | image_root: '/export/share/datasets/vision/coco/images/'
2 | ann_root: 'annotation'
3 | dataset: 'coco'
4 |
5 | # set pretrained as a file path or an url
6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
7 |
8 | # size of vit model; base or large
9 |
10 | vit: 'base'
11 | batch_size_train: 32
12 | batch_size_test: 64
13 | vit_grad_ckpt: True
14 | vit_ckpt_layer: 4
15 | init_lr: 1e-5
16 |
17 | # vit: 'large'
18 | # batch_size_train: 16
19 | # batch_size_test: 32
20 | # vit_grad_ckpt: True
21 | # vit_ckpt_layer: 12
22 | # init_lr: 5e-6
23 |
24 | image_size: 384
25 | queue_size: 57600
26 | alpha: 0.4
27 | k_test: 256
28 | negative_all_rank: True
29 |
30 | # optimizer
31 | weight_decay: 0.05
32 | min_lr: 0
33 | max_epoch: 6
34 |
35 |
--------------------------------------------------------------------------------
/vebench/models/backbone/BLIP_configs/retrieval_flickr.yaml:
--------------------------------------------------------------------------------
1 | image_root: '/export/share/datasets/vision/flickr30k/'
2 | ann_root: 'annotation'
3 | dataset: 'flickr'
4 |
5 | # set pretrained as a file path or an url
6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
7 |
8 | # size of vit model; base or large
9 |
10 | vit: 'base'
11 | batch_size_train: 32
12 | batch_size_test: 64
13 | vit_grad_ckpt: True
14 | vit_ckpt_layer: 4
15 | init_lr: 1e-5
16 |
17 | # vit: 'large'
18 | # batch_size_train: 16
19 | # batch_size_test: 32
20 | # vit_grad_ckpt: True
21 | # vit_ckpt_layer: 10
22 | # init_lr: 5e-6
23 |
24 | image_size: 384
25 | queue_size: 57600
26 | alpha: 0.4
27 | k_test: 128
28 | negative_all_rank: False
29 |
30 | # optimizer
31 | weight_decay: 0.05
32 | min_lr: 0
33 | max_epoch: 6
34 |
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 sunshk1227
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/vebench/configs/text.yaml:
--------------------------------------------------------------------------------
1 | name: e-bench-blip-test
2 | num_epochs: 0
3 | l_num_epochs: 20
4 | warmup_epochs: 2.5
5 | ema: true
6 | save_model: true
7 | batch_size: 8
8 | num_workers: 6
9 | split_seed: 42 #useless
10 |
11 | wandb:
12 | project_name: e-bench-blip-test
13 |
14 | data:
15 | videoQA:
16 | type: ViewDecompositionDataset
17 | args:
18 | weight: 0.443
19 | phase: train
20 | anno_file: ../e-bench-db/label.txt
21 | data_prefix: ../e-bench-db/edited
22 | sample_types:
23 | time:
24 | size_h: 224
25 | size_w: 224
26 | clip_len: 16
27 | frame_interval: 1
28 | t_frag: 16
29 | num_clips: 1
30 |
31 | model:
32 | type: VideoTextAlignmentModel
33 | args:
34 | backbone:
35 | time:
36 | #in_channels: 768
37 | type: blip
38 | pretrained: true
39 | checkpoint: true
40 | blip_type: multimodal_text
41 |
42 | backbone_preserve_keys: time
43 | divide_head: true
44 | use_tn: true
45 | vqa_head:
46 | #in_channels: 768
47 | hidden_channels: 64
48 |     attn_pool3d: true # defaults to false in the code
49 | text_pool3d: false
50 |
51 | optimizer:
52 | lr: !!float 6.25e-4
53 | backbone_lr_mult: !!float 1e-1
54 | wd: 0.05
55 |
56 | test_load_path: [./ckpts/e-bench-blip_head_videoQA_9_eval_s_finetuned.pth]
--------------------------------------------------------------------------------
/vebench/configs/doublestream.yaml:
--------------------------------------------------------------------------------
1 | name: e-bench-uniformer-src-edit-test
2 | num_epochs: 10
3 | l_num_epochs: 20
4 | warmup_epochs: 2.5
5 | ema: true
6 | save_model: true
7 | batch_size: 8
8 | num_workers: 6
9 | split_seed: 42 #useless
10 |
11 | wandb:
12 | project_name: e-bench-uniformer-src-edit-test
13 |
14 | data:
15 | videoQA:
16 | type: ViewDecompositionDataset
17 | args:
18 | weight: 0.443
19 | phase: train
20 | anno_file: ../e-bench-db/label.txt
21 | data_prefix: ../e-bench-db/src/
22 | sample_types:
23 | technical:
24 | fragments_h: 7
25 | fragments_w: 7
26 | fsize_h: 32
27 | fsize_w: 32
28 | aligned: 32
29 | clip_len: 32
30 | frame_interval: 1
31 | num_clips: 1
32 | aesthetic:
33 | size_h: 224
34 | size_w: 224
35 | clip_len: 32
36 | frame_interval: 1
37 | t_frag: 32
38 | num_clips: 1
39 |
40 | model:
41 | type: DoubleStreamModel
42 | args:
43 | backbone:
44 | technical:
45 | type: uniformerv2_b16
46 | checkpoint: true
47 | pretrained:
48 | aesthetic:
49 | type: uniformerv2_b16
50 |         pretrained: true # defaults to True in the code
51 |         in22k: false # defaults to False in the code
52 |
53 | backbone_preserve_keys: technical,aesthetic
54 | divide_head: true
55 | use_tn: true
56 | vqa_head:
57 | #in_channels: 768
58 | hidden_channels: 64
59 |     attn_pool3d: true # defaults to false in the code
60 | text_pool3d: false
61 |
62 | optimizer:
63 | lr: !!float 6.25e-4
64 | backbone_lr_mult: !!float 1e-1
65 | wd: 0.05
66 |
67 | test_load_path: [./ckpts/e-bench-uniformer-src-edit_head_videoQA_3_eval_s_finetuned.pth]
--------------------------------------------------------------------------------
/vebench/models/network.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | import torch
4 | import torch.nn as nn
5 | import os
6 |
7 | from .tradition import DOVER
8 | from .fidelity import DoubleStreamModel
9 | from .text_alignment import VideoTextAlignmentModel
10 |
11 | from huggingface_hub import snapshot_download
12 |
13 | class EvalEditModel(nn.Module):
14 | def __init__(self, dover_opt, doublestream_opt, text_opt, model_path='ckpts'):
15 | super().__init__()
16 |
17 | if not os.path.isdir(model_path):
18 | model_path = snapshot_download('sunshk/vebench')
19 |
20 | # build model
21 | self.traditional_branch = DOVER(**dover_opt['model']['args'],model_path=model_path).eval()
22 | self.fidelity_branch = DoubleStreamModel(**doublestream_opt['model']['args'], model_path=model_path).eval()
23 | self.text_branch = VideoTextAlignmentModel(**text_opt['model']['args'], model_path=model_path).eval()
24 |
25 | # load_weight
26 | self.load_ckpt(model_path)
27 |
28 |
29 | def load_ckpt(self, model_path):
30 | # print('111')
31 | self.traditional_branch.load_state_dict(torch.load(os.path.join(model_path, 'e-bench-dover_head_videoQA_0_eval_n_finetuned.pth'), map_location='cpu')['state_dict'])
32 | self.fidelity_branch.load_state_dict(torch.load(os.path.join(model_path, 'e-bench-uniformer-src-edit_head_videoQA_3_eval_s_finetuned.pth'),map_location='cpu')['state_dict'],strict=False)
33 | self.text_branch.load_state_dict(torch.load(os.path.join(model_path, 'e-bench-blip_head_videoQA_9_eval_s_finetuned.pth'), map_location='cpu')['state_dict'],strict=False)
34 |
35 | def forward(self, src_video, edit_video, prompt):
36 | traditional_score = self.traditional_branch(edit_video,reduce_scores=True)
37 | fidelity_score = self.fidelity_branch(src_video, edit_video)
38 | text_score = self.text_branch(edit_video,prompts=prompt)
39 | # the weight of each score is pre-computed within each branch
40 | return (traditional_score + fidelity_score[0] + text_score[0]).item()
41 |
42 |
43 |
44 | if __name__ == "__main__":
45 | eval_model=EvalEditModel()
46 |
--------------------------------------------------------------------------------
/vebench/configs/dover.yaml:
--------------------------------------------------------------------------------
1 | name: e-bench-dover-test
2 | num_epochs: 20
3 | l_num_epochs: 40
4 | warmup_epochs: 2.5
5 | ema: true
6 | save_model: true
7 | batch_size: 8
8 | num_workers: 6
9 | split_seed: 42 #useless
10 |
11 | wandb:
12 | project_name: e-bench-dover-test
13 |
14 | data:
15 | videoQA:
16 | type: ViewDecompositionDataset
17 | args:
18 | weight: 0.443
19 | phase: train
20 | anno_file: ../e-bench-db/label.txt
21 | data_prefix: ../e-bench-db/edited
22 | sample_types:
23 | technical:
24 | fragments_h: 7
25 | fragments_w: 7
26 | fsize_h: 32
27 | fsize_w: 32
28 | aligned: 32
29 | clip_len: 32
30 | frame_interval: 1
31 | num_clips: 1
32 | aesthetic:
33 | size_h: 224
34 | size_w: 224
35 | clip_len: 32
36 | frame_interval: 1
37 | t_frag: 32
38 | num_clips: 1
39 | # time:
40 | # size_h: 224
41 | # size_w: 224
42 | # clip_len: 16
43 | # frame_interval: 1
44 | # t_frag: 16
45 | # num_clips: 1
46 |
47 | model:
48 | type: DOVER
49 | args:
50 | backbone:
51 | technical:
52 | type: swin_tiny_grpb
53 | checkpoint: true
54 | pretrained:
55 | aesthetic:
56 | type: conv_tiny
57 |         pretrained: true # defaults to True in the code
58 |         in22k: false # defaults to False in the code
59 |
60 | backbone_preserve_keys: technical,aesthetic
61 | divide_head: true
62 | vqa_head:
63 | #in_channels: 768
64 | hidden_channels: 64
65 |     attn_pool3d: true # defaults to false in the code
66 | text_pool3d: false
67 |
68 | optimizer:
69 | lr: !!float 6.25e-4
70 | backbone_lr_mult: !!float 1e-1
71 | wd: 0.05
72 |
73 | test_load_path: [./ckpts/e-bench-dover_head_videoQA_0_eval_n_finetuned.pth]
--------------------------------------------------------------------------------
/vebench/blip_models/blip_itm.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig, BertModel
2 | from transformers import BertTokenizer
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from .blip import create_vit, init_tokenizer, load_checkpoint
9 |
10 | class BLIP_ITM(nn.Module):
11 | def __init__(self,
12 | med_config = 'configs/med_config.json',
13 | image_size = 384,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | embed_dim = 256,
18 | ):
19 | """
20 | Args:
21 | med_config (str): path for the mixture of encoder-decoder model's configuration file
22 | image_size (int): input image size
23 | vit (str): model size of vision transformer
24 | """
25 | super().__init__()
26 |
27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
28 | self.tokenizer = init_tokenizer()
29 | med_config = BertConfig.from_json_file(med_config)
30 | med_config.encoder_width = vision_width
31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
32 |
33 | text_width = self.text_encoder.config.hidden_size
34 |
35 | self.vision_proj = nn.Linear(vision_width, embed_dim)
36 | self.text_proj = nn.Linear(text_width, embed_dim)
37 |
38 | self.itm_head = nn.Linear(text_width, 2)
39 |
40 |
41 | def forward(self, image, caption, match_head='itm'):
42 |
43 | image_embeds = self.visual_encoder(image)
44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
45 |
46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
47 | return_tensors="pt").to(image.device)
48 |
49 |
50 | if match_head=='itm':
51 | output = self.text_encoder(text.input_ids,
52 | attention_mask = text.attention_mask,
53 | encoder_hidden_states = image_embeds,
54 | encoder_attention_mask = image_atts,
55 | return_dict = True,
56 | )
57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:])
58 | return itm_output
59 |
60 | elif match_head=='itc':
61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
62 | return_dict = True, mode = 'text')
63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
65 |
66 | sim = image_feat @ text_feat.t()
67 | return sim
68 |
69 |
70 | def blip_itm(pretrained='',**kwargs):
71 | model = BLIP_ITM(**kwargs)
72 | if pretrained:
73 | model,msg = load_checkpoint(model,pretrained)
74 | assert(len(msg.missing_keys)==0)
75 | return model
76 |
--------------------------------------------------------------------------------
/vebench/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models import EvalEditModel
5 | from preprocess import Processor
6 | import yaml
7 | import argparse
8 |
9 | #fixed seed
10 | seed_n = 42
11 | print('seed is ' + str(seed_n))
12 | torch.manual_seed(seed_n)
13 |
14 | device='cuda'
15 | class EBenchModel(nn.Module):
16 | def __init__(self):
17 | super().__init__()
18 | dover_config = 'configs/dover.yaml'
19 | doublestream_config = 'configs/doublestream.yaml'
20 | text_config = 'configs/text.yaml'
21 |
22 | with open(dover_config, "r") as f:
23 | dover_opt = yaml.safe_load(f)
24 | with open(doublestream_config, "r") as f:
25 | doublestream_opt = yaml.safe_load(f)
26 | with open(text_config, "r") as f:
27 | text_opt = yaml.safe_load(f)
28 |         self.model = EvalEditModel(dover_opt, doublestream_opt, text_opt).cuda()
29 | self.traditional_processor=Processor(dover_opt['data']['videoQA']['args'])
30 | self.text_pocessor=Processor(text_opt['data']['videoQA']['args'])
31 | self.doublestream_processor=Processor(doublestream_opt['data']['videoQA']['args'])
32 |
33 |
34 | def read_data(self, path):
35 | traditional_data=self.traditional_processor.preprocess(path)
36 | text_data=self.text_pocessor.preprocess(path)
37 | doublestream_data = self.doublestream_processor.preprocess(path)
38 | data={}
39 | for branch_data in[traditional_data,text_data,doublestream_data]:
40 | for key in branch_data.keys():
41 | data[key]=branch_data[key]
42 | return data
43 |
44 |
45 | @torch.no_grad()
46 | def evaluate(self, prompt, src_path, dst_path):
47 | src_video = self.read_data(src_path)
48 | dst_video = self.read_data(dst_path)
49 | result = self.model(src_video, dst_video, prompt)
50 | return result
51 |
52 | if __name__ == "__main__":
53 |
54 | parser = argparse.ArgumentParser(description='Process video files with EBenchModel.')
55 |
56 |
57 | parser.add_argument('--single_test', action='store_true', help='Run a single test with specified paths and prompt.')
58 | parser.add_argument('--src_path', type=str, help='Source video path for single test.')
59 | parser.add_argument('--dst_path', type=str, help='Destination video path for single test.')
60 | parser.add_argument('--prompt', type=str, help='Prompt for single test.')
61 | parser.add_argument('--data_path', type=str, help='Data path for batch processing.')
62 | parser.add_argument('--label_path', type=str, help='Label path for batch processing.')
63 |
64 |
65 | args = parser.parse_args()
66 |
67 |
68 | if args.single_test:
69 | if args.src_path and args.dst_path and args.prompt:
70 | src_path = args.src_path
71 | dst_path = args.dst_path
72 | prompt = args.prompt
73 | ebench = EBenchModel()
74 | result = ebench.evaluate(prompt, src_path, dst_path)
75 | print(f"The result is {result}")
76 | else:
77 | print("Error: For single test, --src_path, --dst_path, and --prompt must be provided.")
78 | else:
79 | if args.data_path and args.label_path:
80 | data_path = args.data_path
81 | label_path = args.label_path
82 | src=[]
83 | dst=[]
84 | prompts=[]
85 | with open(label_path,'r') as file:
86 | for line in file:
87 | video_name,_,prompt=line.split('|')
88 | src+=[data_path+"src/"+video_name]
89 | dst += [data_path + "edited/" + video_name]
90 | prompts+=[prompt]
91 | ebench = EBenchModel()
92 | results=[]
93 | for src_path,dst_path,prompt in zip(src,dst,prompts):
94 | result = ebench.evaluate(prompt, src_path, dst_path)
95 | results+=[result]
96 | print(len(results))
97 | with open("label.txt","w") as file:
98 | for src_path,result in zip(src,results):
99 | file.write(f"{src_path.split('/')[-1]},{result}\n")
100 | else:
101 | print("Error: For batch test, --data_path, --label_path must be provided.")
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [\[AAAI 25\] VE-Bench: Subjective-Aligned Benchmark Suite for Text-Driven Video Editing Quality Assessment](https://arxiv.org/abs/2408.11481)
2 |
3 |
4 | Shangkun Sun, Xiaoyu Liang, Songlin Fan, Wenxu Gao, Wei Gao*
5 |
6 | (* Corresponding author)
7 |
8 | from MMCAL, Peking University
9 |
10 |
11 |
14 |
15 | ## 🎦 Introduction
16 | TL;DR: VE-Bench is an evaluation suite for text-driven video editing, consisting of a quality assessment model to provide a human-aligned metric for edited videos, and a database containing rich video-prompt pairs and the corresponding human scores.
17 |
18 |
19 |

20 | ![Overview of the VE-Bench Suite](assets/overview.jpg)
21 | Overview of the VE-Bench Suite
22 |
23 |
24 | VE-Bench DB contains a rich collection of source videos, including real-world, AIGC, and CG videos, covering people, objects, animals, and landscapes. It also spans a variety of editing instructions across different categories: semantic edits such as addition, removal, and replacement; structural changes such as size and shape; and stylizations such as color and texture. In addition, it includes editing results produced by different video editing models. We conducted a subjective experiment with 24 participants from diverse backgrounds, resulting in 28,080 score samples, and further trained the VE-Bench QA model on this data. The left image below shows the box plot of average scores obtained by each model in the subjective experiment, while the right image illustrates each model's scores across different types of prompts.
25 |
26 |
27 |

28 | ![Score distributions and per-prompt performance](assets/scores.jpg)
29 | Left: Average score distributions of 8 editing methods. Right: Performance on different types of prompts from previous video-editing methods.
30 |
31 |
32 | ## Easy Use
33 | VE-Bench can be installed with a single ``pip`` command.
34 | ```
35 | pip install vebench
36 | ```
37 | To compare videos, you can run ``python test.py``, which does the following:
38 | ```
39 | from vebench import VEBenchModel
40 |
41 | evaluator = VEBenchModel()
42 |
43 | score1 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst.mp4')
44 | score2 = evaluator.evaluate('A black-haired boy is turning his head', 'assets/src.mp4', 'assets/dst2.mp4')
45 | print(score1, score2) # Score1: 1.3563, Score2: 0.66194
46 | ```
47 | Since the model employs normalization during training, its output is not an absolute score on a 1 \~ 10 scale; we recommend comparing scores between video pairs, as demonstrated above.
48 |
49 | ## Database
50 | VE-Bench DB is available here: [baidu netdisk](https://pan.baidu.com/s/1D5y6ADXgz8PPHGCxROlNIQ?pwd=sggc) | [google drive](https://drive.google.com/file/d/1SBmXK6XKuyGTaV9LUQXfy5w82bsA3Nve/view?usp=sharing)
51 |
52 |
53 | ## Local Inference
54 |
55 | ### 💼 Preparation
56 | ```
57 | cd vebench
58 | ```
59 |
60 | If no local ``ckpts`` folder is found, the checkpoints are fetched automatically from the Hugging Face Hub (``sunshk/vebench``; see ``models/network.py``). You can also download all checkpoints from [google drive](https://drive.google.com/drive/folders/1kD82Ex90VP9A_AqjYV1J5DYvBQW-hkXa?usp=sharing) and put them into ``ckpts``.
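For reference, the sketch below fetches the released checkpoints programmatically with ``huggingface_hub``; it mirrors the fallback used in ``vebench/models/network.py`` when no local ``ckpts`` folder is present (the printout is only illustrative):
```
from huggingface_hub import snapshot_download

# Download the VE-Bench checkpoints from the Hugging Face Hub into the local cache
# and return the folder path (network.py does the same when './ckpts' is missing).
ckpt_dir = snapshot_download('sunshk/vebench')
print('checkpoints available at:', ckpt_dir)
```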
61 |
62 | ### ✨ Usage
63 | To evaluate a single video:
64 | ```
65 | python infer.py --single_test --src_path ${path_to_source_video} --dst_path ${path_to_dst_video} --prompt ${editing_prompt}
66 |
67 | # Run on example videos
68 | # python infer.py --single_test --src_path "./data/src/00433tokenflow_baby_gaze.mp4" --dst_path "./data/edited/00433tokenflow_baby_gaze.mp4" --prompt "A black-haired boy is turning his head"
69 | ```
70 |
71 |
72 | To evaluate a set of videos:
73 | ```
74 | python infer.py --data_path ${path_to_data_folder} --label_path ${path_to_prompt_txt_file}
75 | ```
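The batch mode expects a specific layout, inferred from the argument parsing in ``vebench/evaluator.py``: ``--data_path`` should end with a trailing slash and contain ``src/`` and ``edited/`` subfolders with identically named videos, and ``--label_path`` points to a text file with three ``|``-separated fields per line, where the middle field is unused. Scores are written to ``label.txt`` in the working directory as ``video_name,score`` lines. A sketch with illustrative names:
```
data/
├── src/
│   └── 00433tokenflow_baby_gaze.mp4
└── edited/
    └── 00433tokenflow_baby_gaze.mp4

# label file, one entry per video (middle field is ignored):
00433tokenflow_baby_gaze.mp4|0|A black-haired boy is turning his head
```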
76 |
77 | ## 🙏 Acknowledgements
78 | Part of the code is developed based on [DOVER](https://github.com/VQAssessment/DOVER) and [BLIP](https://github.com/salesforce/BLIP). We would like to thank the authors for their contributions to the community.
79 |
80 |
81 | ## 📭 Contact
82 | If you have any comments or questions, feel free to contact [sunshk@stu.pku.edu.cn](mailto:sunshk@stu.pku.edu.cn).
83 |
84 |
85 |
86 | ## 📖 BibTex
87 | ```bibtex
88 | @article{sun2024bench,
89 | title={VE-Bench: Subjective-Aligned Benchmark Suite for Text-Driven Video Editing Quality Assessment},
90 | author={Sun, Shangkun and Liang, Xiaoyu and Fan, Songlin and Gao, Wenxu and Gao, Wei},
91 | journal={arXiv preprint arXiv:2408.11481},
92 | year={2024}
93 | }
94 | ```
95 |
--------------------------------------------------------------------------------
/vebench/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 |
5 | from .models import EvalEditModel
6 | from .preprocess import Processor
7 | import yaml
8 | import argparse
9 | import random
10 | import numpy as np
11 |
12 |
13 | device='cuda'
14 | class VEBenchModel(nn.Module):
15 | def __init__(self, seed=42):
16 | super().__init__()
17 |
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | np.random.seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 |
24 | base_dir = os.path.dirname(os.path.abspath(__file__))
25 |         # build absolute paths to the config files
26 | dover_config = os.path.join(base_dir, 'configs', 'dover.yaml')
27 | doublestream_config = os.path.join(base_dir, 'configs', 'doublestream.yaml')
28 | text_config = os.path.join(base_dir, 'configs', 'text.yaml')
29 |
30 |
31 |
32 | with open(dover_config, "r") as f:
33 | dover_opt = yaml.safe_load(f)
34 | with open(doublestream_config, "r") as f:
35 | doublestream_opt = yaml.safe_load(f)
36 | with open(text_config, "r") as f:
37 | text_opt = yaml.safe_load(f)
38 | self.model = EvalEditModel(dover_opt, doublestream_opt, text_opt).cuda()
39 | self.traditional_processor=Processor(dover_opt['data']['videoQA']['args'])
40 | self.text_pocessor=Processor(text_opt['data']['videoQA']['args'])
41 | self.doublestream_processor=Processor(doublestream_opt['data']['videoQA']['args'])
42 |
43 |
44 | def read_data(self, path):
45 | traditional_data=self.traditional_processor.preprocess(path)
46 | text_data=self.text_pocessor.preprocess(path)
47 | doublestream_data = self.doublestream_processor.preprocess(path)
48 | data={}
49 | for branch_data in[traditional_data,text_data,doublestream_data]:
50 | for key in branch_data.keys():
51 | data[key]=branch_data[key]
52 | return data
53 |
54 |
55 | @torch.no_grad()
56 | def evaluate(self, prompt, src_path, dst_path):
57 | src_video = self.read_data(src_path)
58 | dst_video = self.read_data(dst_path)
59 | result = self.model(src_video, dst_video, prompt)
60 | return result
61 |
62 | if __name__ == "__main__":
63 |
64 | parser = argparse.ArgumentParser(description='Process video files with VEBenchModel.')
65 |
66 |
67 | parser.add_argument('--single_test', action='store_true', help='Run a single test with specified paths and prompt.')
68 | parser.add_argument('--src_path', type=str, help='Source video path for single test.')
69 | parser.add_argument('--dst_path', type=str, help='Destination video path for single test.')
70 | parser.add_argument('--prompt', type=str, help='Prompt for single test.')
71 | parser.add_argument('--data_path', type=str, help='Data path for batch processing.')
72 | parser.add_argument('--label_path', type=str, help='Label path for batch processing.')
73 |
74 |
75 | args = parser.parse_args()
76 |
77 |
78 | if args.single_test:
79 | if args.src_path and args.dst_path and args.prompt:
80 | src_path = args.src_path
81 | dst_path = args.dst_path
82 | prompt = args.prompt
83 | ebench = VEBenchModel()
84 | result = ebench.evaluate(prompt, src_path, dst_path)
85 | print(f"The result is {result}")
86 | else:
87 | print("Error: For single test, --src_path, --dst_path, and --prompt must be provided.")
88 | else:
89 | if args.data_path and args.label_path:
90 | data_path = args.data_path
91 | label_path = args.label_path
92 | src=[]
93 | dst=[]
94 | prompts=[]
95 | with open(label_path,'r') as file:
96 | for line in file:
97 | video_name,_,prompt=line.split('|')
98 | src+=[data_path+"src/"+video_name]
99 | dst += [data_path + "edited/" + video_name]
100 | prompts+=[prompt]
101 |             ebench = VEBenchModel()
102 | results=[]
103 | for src_path,dst_path,prompt in zip(src,dst,prompts):
104 | result = ebench.evaluate(prompt, src_path, dst_path)
105 | results+=[result]
106 | print(len(results))
107 | with open("label.txt","w") as file:
108 | for src_path,result in zip(src,results):
109 | file.write(f"{src_path.split('/')[-1]},{result}\n")
110 | else:
111 | print("Error: For batch test, --data_path, --label_path must be provided.")
112 |
--------------------------------------------------------------------------------
/vebench/blip_models/blip_nlvr.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig
2 | from .nlvr_encoder import BertModel
3 | from .vit import interpolate_pos_embed
4 | from .blip import create_vit, init_tokenizer, is_url
5 |
6 | from timm.models.hub import download_cached_file
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | from transformers import BertTokenizer
12 | import numpy as np
13 | import os
14 | class BLIP_NLVR(nn.Module):
15 | def __init__(self,
16 | med_config = 'configs/med_config.json',
17 | image_size = 480,
18 | vit = 'base',
19 | vit_grad_ckpt = False,
20 | vit_ckpt_layer = 0,
21 | ):
22 | """
23 | Args:
24 | med_config (str): path for the mixture of encoder-decoder model's configuration file
25 | image_size (int): input image size
26 | vit (str): model size of vision transformer
27 | """
28 | super().__init__()
29 |
30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
31 | self.tokenizer = init_tokenizer()
32 | med_config = BertConfig.from_json_file(med_config)
33 | med_config.encoder_width = vision_width
34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
35 |
36 | self.cls_head = nn.Sequential(
37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
38 | nn.ReLU(),
39 | nn.Linear(self.text_encoder.config.hidden_size, 2)
40 | )
41 |
42 | def forward(self, image, text, targets, train=True):
43 |
44 | image_embeds = self.visual_encoder(image)
45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
47 |
48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
49 | text.input_ids[:,0] = self.tokenizer.enc_token_id
50 |
51 | output = self.text_encoder(text.input_ids,
52 | attention_mask = text.attention_mask,
53 | encoder_hidden_states = [image0_embeds,image1_embeds],
54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
55 | image_atts[image0_embeds.size(0):]],
56 | return_dict = True,
57 | )
58 | hidden_state = output.last_hidden_state[:,0,:]
59 | prediction = self.cls_head(hidden_state)
60 |
61 | if train:
62 | loss = F.cross_entropy(prediction, targets)
63 | return loss
64 | else:
65 | return prediction
66 |
67 | def blip_nlvr(pretrained='',**kwargs):
68 | model = BLIP_NLVR(**kwargs)
69 | if pretrained:
70 | model,msg = load_checkpoint(model,pretrained)
71 | print("missing keys:")
72 | print(msg.missing_keys)
73 | return model
74 |
75 |
76 | def load_checkpoint(model,url_or_filename):
77 | if is_url(url_or_filename):
78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
79 | checkpoint = torch.load(cached_file, map_location='cpu')
80 | elif os.path.isfile(url_or_filename):
81 | checkpoint = torch.load(url_or_filename, map_location='cpu')
82 | else:
83 | raise RuntimeError('checkpoint url or path is invalid')
84 | state_dict = checkpoint['model']
85 |
86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
87 |
88 | for key in list(state_dict.keys()):
89 | if 'crossattention.self.' in key:
90 | new_key0 = key.replace('self','self0')
91 | new_key1 = key.replace('self','self1')
92 | state_dict[new_key0] = state_dict[key]
93 | state_dict[new_key1] = state_dict[key]
94 | elif 'crossattention.output.dense.' in key:
95 | new_key0 = key.replace('dense','dense0')
96 | new_key1 = key.replace('dense','dense1')
97 | state_dict[new_key0] = state_dict[key]
98 | state_dict[new_key1] = state_dict[key]
99 |
100 | msg = model.load_state_dict(state_dict,strict=False)
101 | # print('load checkpoint from %s'%url_or_filename)
102 | return model,msg
103 |
--------------------------------------------------------------------------------
/vebench/models/tradition.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | from functools import partial, reduce
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | from .backbone.conv_backbone import convnext_3d_tiny
10 | from .head import VARHead, VQAHead,VQAHead_cls
11 | from .backbone.swin_backbone import SwinTransformer3D as VideoBackbone
12 |
13 | class DOVER(nn.Module):
14 | def __init__(
15 | self,
16 | backbone_size="divided",
17 | backbone_preserve_keys="technical,aesthetic",
18 | multi=False,
19 | layer=-1,
20 | backbone=dict(
21 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)}
22 | ),
23 | divide_head=True,
24 | vqa_head=dict(in_channels=768),
25 | var=False,
26 | model_path=None,
27 | ):
28 | self.backbone_preserve_keys = backbone_preserve_keys.split(",")
29 | self.multi = multi
30 | self.layer = layer
31 | super().__init__()
32 | for key, hypers in backbone.items():
33 | if key not in self.backbone_preserve_keys:
34 | continue
35 | if backbone_size == "divided":
36 | t_backbone_size = hypers["type"]
37 | else:
38 | t_backbone_size = backbone_size
39 | if t_backbone_size == "swin_tiny_grpb":
40 | # to reproduce fast-vqa
41 | b = VideoBackbone()
42 | elif t_backbone_size == "conv_tiny":
43 | b = convnext_3d_tiny(pretrained=model_path)
44 | else:
45 | raise NotImplementedError
46 | setattr(self, key + "_backbone", b)
47 | if divide_head:
48 | for key in backbone:
49 | pre_pool = False #if key == "technical" else True
50 | if key not in self.backbone_preserve_keys:
51 | continue
52 | b = VQAHead_cls(pre_pool=pre_pool, **vqa_head)
53 | setattr(self, key + "_head", b)
54 | else:
55 | if var:
56 | self.vqa_head = VARHead(**vqa_head)
57 | else:
58 | self.vqa_head = VQAHead(**vqa_head)
59 |
60 | def forward(
61 | self,
62 | vclips,
63 | inference=True,
64 | return_pooled_feats=False,
65 | return_raw_feats=False,
66 | reduce_scores=False,
67 | pooled=False,
68 | **kwargs
69 | ):
70 | assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return"
71 | if inference:
72 | self.eval()
73 | with torch.no_grad():
74 | scores = []
75 | feats = {}
76 | for key in self.backbone_preserve_keys:
77 | feat = getattr(self, key.split("_")[0] + "_backbone")(
78 | vclips[key], multi=self.multi, layer=self.layer, **kwargs
79 | )
80 | if hasattr(self, key.split("_")[0] + "_head"):
81 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
82 | else:
83 | scores += [getattr(self, "vqa_head")(feat)]
84 | if return_pooled_feats:
85 | feats[key] = feat
86 | if return_raw_feats:
87 | feats[key] = feat
88 | if reduce_scores:
89 | if len(scores) > 1:
90 | scores = reduce(lambda x, y: x + y, scores)
91 | else:
92 | scores = scores[0]
93 | if pooled:
94 | scores = torch.mean(scores, (1, 2, 3, 4))
95 | self.train()
96 | if return_pooled_feats or return_raw_feats:
97 | return scores, feats
98 | return scores
99 | else:
100 | self.train()
101 | scores = []
102 | feats = {}
103 | for key in vclips:
104 | feat = getattr(self, key.split("_")[0] + "_backbone")(
105 | vclips[key], multi=self.multi, layer=self.layer, **kwargs
106 | )
107 | if hasattr(self, key.split("_")[0] + "_head"):
108 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
109 | else:
110 | scores += [getattr(self, "vqa_head")(feat)]
111 | if return_pooled_feats:
112 | feats[key] = feat.mean((-3, -2, -1))
113 | if reduce_scores:
114 | if len(scores) > 1:
115 | scores = reduce(lambda x, y: x + y, scores)
116 | else:
117 | scores = scores[0]
118 | if pooled:
119 | scores = torch.mean(scores, (1, 2, 3, 4))
120 |
121 | if return_pooled_feats:
122 | return scores, feats
123 | return scores
--------------------------------------------------------------------------------
/vebench/models/text_alignment.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .head import VQAHead_cls,VARHead,VQAHead
4 | from .backbone.blip import MyBLIP as BLIP
5 | from functools import reduce
6 | class VideoTextAlignmentModel(nn.Module):
7 | def __init__(
8 | self,
9 | backbone_size="divided",
10 | backbone_preserve_keys="fragments,resize",
11 | multi=False,
12 | layer=-1,
13 | backbone=dict(
14 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)}
15 | ),
16 | divide_head=False,
17 | head_type='VQAhead_cls',
18 | vqa_head=dict(in_channels=768),
19 | var=False,
20 | use_tn=False,
21 | model_path=None,
22 | ):
23 | self.backbone_preserve_keys = backbone_preserve_keys.split(",")
24 | self.multi = multi
25 | self.layer = layer
26 | super().__init__()
27 |
28 | for key, hypers in backbone.items():
29 | if key not in self.backbone_preserve_keys:
30 | continue
31 | if backbone_size == "divided":
32 | t_backbone_size = hypers["type"]
33 | else:
34 | t_backbone_size = backbone_size
35 |
36 | assert t_backbone_size == "blip"
37 | type = hypers["blip_type"]
38 | b = BLIP(type, model_path)
39 |
40 | setattr(self, key + "_backbone", b)
41 | if divide_head:
42 | for key in backbone:
43 | pre_pool = False # if key == "technical" else True
44 | if key not in self.backbone_preserve_keys:
45 | continue
46 | in_channel = 768
47 | b = VQAHead_cls(pre_pool=pre_pool, in_channels=in_channel, **vqa_head)
48 | setattr(self, key + "_head", b)
49 | else:
50 | if var:
51 | self.vqa_head = VARHead(**vqa_head)
52 | else:
53 | self.vqa_head = VQAHead(**vqa_head)
54 |
55 | def forward(
56 | self,
57 | vclips,
58 | prompts=None,
59 | inference=True,
60 | return_pooled_feats=False,
61 | return_raw_feats=False,
62 | reduce_scores=False,
63 | pooled=False,
64 | **kwargs
65 | ):
66 | # import pdb;pdb.set_trace()
67 | assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return"
68 | if inference:
69 | self.eval()
70 | with torch.no_grad():
71 | scores = []
72 | feats = {}
73 | for key in self.backbone_preserve_keys:
74 | feat = getattr(self, key.split("_")[0] + "_backbone")(
75 | vclips[key], prompts
76 | )
77 | if hasattr(self, key.split("_")[0] + "_head"):
78 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)[0]]
79 | else:
80 | scores += [getattr(self, "vqa_head")(feat)]
81 | if return_pooled_feats:
82 | feats[key] = feat
83 | if return_raw_feats:
84 | feats[key] = feat
85 | if reduce_scores:
86 | if len(scores) > 1:
87 | scores = reduce(lambda x, y: x + y, scores)
88 | else:
89 | scores = scores[0]
90 | if pooled:
91 | scores = torch.mean(scores, (1, 2, 3, 4))
92 | self.train()
93 | if return_pooled_feats or return_raw_feats:
94 | return scores, feats
95 | return scores
96 | else:
97 | self.train()
98 | scores = []
99 | feats = {}
100 |
101 | for key in vclips:
102 | feat = getattr(self, key.split("_")[0] + "_backbone")(
103 | vclips[key], prompts
104 | )
105 | if hasattr(self, key.split("_")[0] + "_head"):
106 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)[0]]
107 | else:
108 | scores += [getattr(self, "vqa_head")(feat)]
109 | if return_pooled_feats:
110 | feats[key] = feat.mean((-3, -2, -1))
111 | if reduce_scores:
112 | if len(scores) > 1:
113 | scores = reduce(lambda x, y: x + y, scores)
114 | else:
115 | scores = scores[0]
116 | if pooled:
117 | # print(scores.shape)
118 | scores = torch.mean(scores, (1, 2, 3, 4))
119 | # print(scores.shape)
120 |
121 | if return_pooled_feats:
122 | return scores, feats
123 | return scores
124 |
--------------------------------------------------------------------------------
/vebench/models/fidelity.py:
--------------------------------------------------------------------------------
1 | from .backbone.uniformer_backbone import uniformerv2_b16
2 | from .head import VQAHead_cls,VARHead,VQAHead
3 | import torch.nn as nn
4 | import torch
5 | from functools import reduce
6 | class DoubleStreamModel(nn.Module):
7 | def __init__(
8 | self,
9 | backbone_size="divided",
10 | backbone_preserve_keys="fragments,resize",
11 | multi=False,
12 | layer=-1,
13 | backbone=dict(
14 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)}
15 | ),
16 | divide_head=False,
17 | head_type='VQAhead_cls',
18 | vqa_head=dict(in_channels=768),
19 | var=False,
20 | use_tn=False,
21 | model_path=None,
22 | ):
23 | self.backbone_preserve_keys = backbone_preserve_keys.split(",")
24 | self.multi = multi
25 | self.layer = layer
26 | super().__init__()
27 |
28 | for key, hypers in backbone.items():
29 | if key not in self.backbone_preserve_keys:
30 | continue
31 | if backbone_size == "divided":
32 | t_backbone_size = hypers["type"]
33 | else:
34 | t_backbone_size = backbone_size
35 | assert t_backbone_size == "uniformerv2_b16"
36 | b = uniformerv2_b16(pretrained=model_path, temporal_downsample=False, no_lmhra=True, t_size=32)
37 | setattr(self, key + "_backbone", b)
38 | if divide_head:
39 | for key in backbone:
40 | pre_pool = False # if key == "technical" else True
41 | if key not in self.backbone_preserve_keys:
42 | continue
43 | in_channel = 1536
44 | b = VQAHead_cls(pre_pool=pre_pool, in_channels=in_channel, **vqa_head)
45 | setattr(self, key + "_head", b)
46 | else:
47 | if var:
48 | self.vqa_head = VARHead(**vqa_head)
49 | else:
50 | self.vqa_head = VQAHead(**vqa_head)
51 |
52 | def forward(
53 | self,
54 | vclips,
55 | prompts=None,
56 | inference=True,
57 | return_pooled_feats=False,
58 | return_raw_feats=False,
59 | reduce_scores=False,
60 | pooled=False,
61 | **kwargs
62 | ):
63 | # import pdb;pdb.set_trace()
64 | assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return"
65 | if inference:
66 | self.eval()
67 | with torch.no_grad():
68 | scores = []
69 | feats = []
70 | for key in self.backbone_preserve_keys:
71 | if "time" in key:
72 | feat = getattr(self, key.split("_")[0] + "_backbone")(
73 | vclips[key], prompts
74 | )
75 | else:
76 | feat = getattr(self, key.split("_")[0] + "_backbone")(
77 | vclips[key]
78 | )
79 | feats += [feat]
80 |
81 | feats = (torch.cat(feats, dim=1))
82 | if hasattr(self, key.split("_")[0] + "_head"):
83 | scores += [getattr(self, key.split("_")[0] + "_head")(feats)[0]]
84 | else:
85 | scores += [getattr(self, "vqa_head")(feats)]
86 |
87 | if reduce_scores:
88 | if len(scores) > 1:
89 | scores = reduce(lambda x, y: x + y, scores)
90 | else:
91 | scores = scores[0]
92 | if pooled:
93 | scores = torch.mean(scores, (1, 2, 3, 4))
94 | self.train()
95 | if return_pooled_feats or return_raw_feats:
96 | return scores, feats
97 | return scores
98 | else:
99 | self.train()
100 | scores = []
101 | feats = []
102 | for key in vclips:
103 | feat = getattr(self, key.split("_")[0] + "_backbone")(
104 | vclips[key]
105 | )
106 | feats.append(feat)
107 | feats = (torch.cat(feats, dim=1))
108 | if hasattr(self, key.split("_")[0] + "_head"):
109 | scores += [getattr(self, key.split("_")[0] + "_head")(feats)[0]]
110 | else:
111 | scores += [getattr(self, "vqa_head")(feats)]
112 | scores += [torch.zeros_like(scores[0])]
113 | if reduce_scores:
114 | if len(scores) > 1:
115 | scores = reduce(lambda x, y: x + y, scores)
116 | else:
117 | scores = scores[0]
118 | if pooled:
119 | scores = torch.mean(scores, (1, 2, 3, 4))
120 |
121 | if return_pooled_feats:
122 | return scores, feats
123 | return scores
--------------------------------------------------------------------------------
/vebench/blip_models/blip_vqa.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig, BertModel, BertLMHeadModel
2 | from .blip import create_vit, init_tokenizer, load_checkpoint
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from transformers import BertTokenizer
8 | import numpy as np
9 |
10 | class BLIP_VQA(nn.Module):
11 | def __init__(self,
12 | med_config = 'BLIP_configs/med_config.json',
13 | image_size = 480,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | ):
18 | """
19 | Args:
20 | med_config (str): path for the mixture of encoder-decoder model's configuration file
21 | image_size (int): input image size
22 | vit (str): model size of vision transformer
23 | """
24 | super().__init__()
25 |
26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
27 | self.tokenizer = init_tokenizer()
28 |
29 | encoder_config = BertConfig.from_json_file(med_config)
30 | encoder_config.encoder_width = vision_width
31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
32 |
33 | decoder_config = BertConfig.from_json_file(med_config)
34 | self.text_decoder = BertLMHeadModel(config=decoder_config)
35 |
36 |
37 | def forward(self, video, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
38 | temporal=[]
39 | for i in range(video.shape[2]):
40 | image=video[:,:,i,...]
41 | image_embeds = self.visual_encoder(image)
42 | temporal.append(image_embeds)
43 | temporal=torch.cat(temporal,dim=2)
44 | image_embeds = self.visual_encoder(image)
45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
46 |
47 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
48 | return_tensors="pt").to(image.device)
49 | question.input_ids[:,0] = self.tokenizer.enc_token_id
50 |
51 | if train:
52 | '''
53 | n: number of answers for each question
54 | weights: weight for each answer
55 | '''
56 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
57 | answer.input_ids[:,0] = self.tokenizer.bos_token_id
58 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
59 |
60 | question_output = self.text_encoder(question.input_ids,
61 | attention_mask = question.attention_mask,
62 | encoder_hidden_states = image_embeds,
63 | encoder_attention_mask = image_atts,
64 | return_dict = True)
65 |
66 | question_states = []
67 | question_atts = []
68 | for b, n in enumerate(n):
69 | question_states += [question_output.last_hidden_state[b]]*n
70 | question_atts += [question.attention_mask[b]]*n
71 | question_states = torch.stack(question_states,0)
72 | question_atts = torch.stack(question_atts,0)
73 |
74 | answer_output = self.text_decoder(answer.input_ids,
75 | attention_mask = answer.attention_mask,
76 | encoder_hidden_states = question_states,
77 | encoder_attention_mask = question_atts,
78 | labels = answer_targets,
79 | return_dict = True,
80 | reduction = 'none',
81 | )
82 |
83 | loss = weights * answer_output.loss
84 | loss = loss.sum()/image.size(0)
85 |
86 | return loss
87 |
88 |
89 | else:
90 | question_output = self.text_encoder(question.input_ids,
91 | attention_mask = question.attention_mask,
92 | encoder_hidden_states = image_embeds,
93 | encoder_attention_mask = image_atts,
94 | return_dict = True)
95 |
96 | if inference=='generate':
97 | num_beams = 3
98 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
99 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
100 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
101 |
102 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
103 |
104 | outputs = self.text_decoder.generate(input_ids=bos_ids,
105 | max_length=10,
106 | min_length=1,
107 | num_beams=num_beams,
108 | eos_token_id=self.tokenizer.sep_token_id,
109 | pad_token_id=self.tokenizer.pad_token_id,
110 | **model_kwargs)
111 |
112 | answers = []
113 | for output in outputs:
114 | answer = self.tokenizer.decode(output, skip_special_tokens=True)
115 | answers.append(answer)
116 | return answers
117 |
118 | elif inference=='rank':
119 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
120 | answer.input_ids, answer.attention_mask, k_test)
121 | return max_ids
122 |
123 |
124 |
125 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
126 |
127 | num_ques = question_states.size(0)
128 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
129 |
130 | start_output = self.text_decoder(start_ids,
131 | encoder_hidden_states = question_states,
132 | encoder_attention_mask = question_atts,
133 | return_dict = True,
134 | reduction = 'none')
135 | logits = start_output.logits[:,0,:] # first token's logit
136 |
137 | # topk_probs: top-k probability
138 | # topk_ids: [num_question, k]
139 | answer_first_token = answer_ids[:,1]
140 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
141 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
142 |
143 | # answer input: [num_question*k, answer_len]
144 | input_ids = []
145 | input_atts = []
146 | for b, topk_id in enumerate(topk_ids):
147 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
148 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
149 | input_ids = torch.cat(input_ids,dim=0)
150 | input_atts = torch.cat(input_atts,dim=0)
151 |
152 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
153 |
154 | # repeat encoder's output for top-k answers
155 | question_states = tile(question_states, 0, k)
156 | question_atts = tile(question_atts, 0, k)
157 |
158 | output = self.text_decoder(input_ids,
159 | attention_mask = input_atts,
160 | encoder_hidden_states = question_states,
161 | encoder_attention_mask = question_atts,
162 | labels = targets_ids,
163 | return_dict = True,
164 | reduction = 'none')
165 |
166 | log_probs_sum = -output.loss
167 | log_probs_sum = log_probs_sum.view(num_ques,k)
168 |
169 | max_topk_ids = log_probs_sum.argmax(dim=1)
170 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
171 |
172 | return max_ids
173 |
174 |
175 | def blip_vqa(pretrained='',**kwargs):
176 | model = BLIP_VQA(**kwargs)
177 | if pretrained:
178 | model,msg = load_checkpoint(model,pretrained)
179 | # assert(len(msg.missing_keys)==0)
180 | return model
181 |
182 |
183 | def tile(x, dim, n_tile):
184 | init_dim = x.size(dim)
185 | repeat_idx = [1] * x.dim()
186 | repeat_idx[dim] = n_tile
187 | x = x.repeat(*(repeat_idx))
188 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
189 | return torch.index_select(x, dim, order_index.to(x.device))
190 |
191 |
--------------------------------------------------------------------------------
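A small illustrative check (not part of the repository) of the tile() helper above: it behaves like torch.repeat_interleave, duplicating every slice along the chosen dimension n_tile times consecutively so that each question's encoder states can be paired with its top-k candidate answers in rank_answer(). The toy tensor below is an assumption.

import numpy as np
import torch

def tile(x, dim, n_tile):
    # copy of the helper above, included so the snippet runs standalone
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))

x = torch.arange(6).reshape(3, 2)   # three "questions", two features each
assert torch.equal(tile(x, 0, 2), torch.repeat_interleave(x, 2, dim=0))
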
/vebench/blip_models/blip.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | import warnings
9 | warnings.filterwarnings("ignore")
10 |
11 | from .vit import VisionTransformer, interpolate_pos_embed
12 | from .med import BertConfig, BertModel, BertLMHeadModel
13 | from transformers import BertTokenizer
14 |
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 |
19 | import os
20 | from urllib.parse import urlparse
21 | from timm.models.hub import download_cached_file
22 |
23 | class BLIP_Base(nn.Module):
24 | def __init__(self,
25 | med_config = 'BLIP_configs/med_config.json',
26 | image_size = 224,
27 | vit = 'base',
28 | vit_grad_ckpt = False,
29 | vit_ckpt_layer = 0,
30 | ):
31 | """
32 | Args:
33 | med_config (str): path for the mixture of encoder-decoder model's configuration file
34 | image_size (int): input image size
35 | vit (str): model size of vision transformer
36 | """
37 | super().__init__()
38 |
39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
40 | self.tokenizer = init_tokenizer()
41 | med_config = BertConfig.from_json_file(med_config)
42 | med_config.encoder_width = vision_width
43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
44 |
45 |
46 | def forward(self, image, caption, mode):
47 |
48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device)
50 |
51 | if mode=='image':
52 | # return image features
53 | image_embeds = self.visual_encoder(image)
54 | return image_embeds
55 |
56 | elif mode=='text':
57 | # return text features
58 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
59 | return_dict = True, mode = 'text')
60 | return text_output.last_hidden_state
61 |
62 | elif mode=='multimodal':
63 |             # return multimodal features
64 | image_embeds = self.visual_encoder(image)
65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
66 |
67 | text.input_ids[:,0] = self.tokenizer.enc_token_id
68 | output = self.text_encoder(text.input_ids,
69 | attention_mask = text.attention_mask,
70 | encoder_hidden_states = image_embeds,
71 | encoder_attention_mask = image_atts,
72 | return_dict = True,
73 | )
74 | return output.last_hidden_state
75 |
76 |
77 |
78 | class BLIP_Decoder(nn.Module):
79 | def __init__(self,
80 | med_config = 'BLIP_configs/med_config.json',
81 | image_size = 384,
82 | vit = 'base',
83 | vit_grad_ckpt = False,
84 | vit_ckpt_layer = 0,
85 | prompt = 'a picture of ',
86 | ):
87 | """
88 | Args:
89 | med_config (str): path for the mixture of encoder-decoder model's configuration file
90 | image_size (int): input image size
91 | vit (str): model size of vision transformer
92 | """
93 | super().__init__()
94 |
95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
96 | self.tokenizer = init_tokenizer()
97 | med_config = BertConfig.from_json_file(med_config)
98 | med_config.encoder_width = vision_width
99 | self.text_decoder = BertLMHeadModel(config=med_config)
100 |
101 | self.prompt = prompt
102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
103 |
104 |
105 | def forward(self, image, caption):
106 |
107 | image_embeds = self.visual_encoder(image)
108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
109 |
110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
111 |
112 | text.input_ids[:,0] = self.tokenizer.bos_token_id
113 |
114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
115 | decoder_targets[:,:self.prompt_length] = -100
116 |
117 | decoder_output = self.text_decoder(text.input_ids,
118 | attention_mask = text.attention_mask,
119 | encoder_hidden_states = image_embeds,
120 | encoder_attention_mask = image_atts,
121 | labels = decoder_targets,
122 | return_dict = True,
123 | )
124 | loss_lm = decoder_output.loss
125 |
126 | return loss_lm
127 |
128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
129 | image_embeds = self.visual_encoder(image)
130 |
131 | if not sample:
132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
133 |
134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
136 |
137 | prompt = [self.prompt] * image.size(0)
138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
139 | input_ids[:,0] = self.tokenizer.bos_token_id
140 | input_ids = input_ids[:, :-1]
141 |
142 | if sample:
143 | #nucleus sampling
144 | outputs = self.text_decoder.generate(input_ids=input_ids,
145 | max_length=max_length,
146 | min_length=min_length,
147 | do_sample=True,
148 | top_p=top_p,
149 | num_return_sequences=1,
150 | eos_token_id=self.tokenizer.sep_token_id,
151 | pad_token_id=self.tokenizer.pad_token_id,
152 | repetition_penalty=1.1,
153 | **model_kwargs)
154 | else:
155 | #beam search
156 | outputs = self.text_decoder.generate(input_ids=input_ids,
157 | max_length=max_length,
158 | min_length=min_length,
159 | num_beams=num_beams,
160 | eos_token_id=self.tokenizer.sep_token_id,
161 | pad_token_id=self.tokenizer.pad_token_id,
162 | repetition_penalty=repetition_penalty,
163 | **model_kwargs)
164 |
165 | captions = []
166 | for output in outputs:
167 | caption = self.tokenizer.decode(output, skip_special_tokens=True)
168 | captions.append(caption[len(self.prompt):])
169 | return captions
170 |
171 |
172 | def blip_decoder(pretrained='',**kwargs):
173 | model = BLIP_Decoder(**kwargs)
174 | if pretrained:
175 | model,msg = load_checkpoint(model,pretrained)
176 | assert(len(msg.missing_keys)==0)
177 | return model
178 |
179 | def blip_feature_extractor(pretrained='',**kwargs):
180 | model = BLIP_Base(**kwargs)
181 | if pretrained:
182 | model,msg = load_checkpoint(model,pretrained)
183 | assert(len(msg.missing_keys)==0)
184 | return model
185 |
186 | def init_tokenizer():
187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'})
189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
191 | return tokenizer
192 |
193 |
194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
195 |
196 | assert vit in ['base', 'large'], "vit parameter must be base or large"
197 | if vit=='base':
198 | vision_width = 768
199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
201 |                                            drop_path_rate=0 or drop_path_rate  # 0 is falsy, so the passed drop_path_rate takes effect
202 | )
203 | elif vit=='large':
204 | vision_width = 1024
205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
207 |                                            drop_path_rate=0.1 or drop_path_rate  # note: 0.1 is truthy, so the passed drop_path_rate is ignored for ViT-Large
208 | )
209 | return visual_encoder, vision_width
210 |
211 | def is_url(url_or_filename):
212 | parsed = urlparse(url_or_filename)
213 | return parsed.scheme in ("http", "https")
214 |
215 | def load_checkpoint(model,url_or_filename):
216 | if is_url(url_or_filename):
217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
218 | checkpoint = torch.load(cached_file, map_location='cpu')
219 | elif os.path.isfile(url_or_filename):
220 | checkpoint = torch.load(url_or_filename, map_location='cpu')
221 | else:
222 | raise RuntimeError('checkpoint url or path is invalid')
223 |
224 | state_dict = checkpoint['model']
225 |
226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
229 | model.visual_encoder_m)
230 | for key in model.state_dict().keys():
231 | if key in state_dict.keys():
232 | if state_dict[key].shape!=model.state_dict()[key].shape:
233 | del state_dict[key]
234 |
235 | msg = model.load_state_dict(state_dict,strict=False)
236 | # print('load checkpoint from %s'%url_or_filename)
237 | return model,msg
238 |
239 |
--------------------------------------------------------------------------------
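A self-contained sketch (not the repository's API) of the shape-filtering idea in load_checkpoint() above: parameters whose checkpoint shape no longer matches the model are dropped, and the rest are loaded non-strictly so mismatches surface as missing keys instead of errors. The tiny model and state dict are assumptions for illustration.

import torch
from torch import nn

model = nn.Linear(4, 2)
state_dict = {"weight": torch.zeros(3, 4), "bias": torch.zeros(2)}   # "weight" has a stale shape

model_sd = model.state_dict()
for key in list(state_dict):
    if key in model_sd and state_dict[key].shape != model_sd[key].shape:
        del state_dict[key]            # drop shape-mismatched parameters, as load_checkpoint() does

msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys)                # ['weight'] is reported as missing rather than raising
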
/vebench/models/head.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from torchvision.ops import roi_align, roi_pool
8 |
9 |
10 | class MultiHeadCrossAttention(nn.Module):
11 | def __init__(self, embed_dim, query_dim, kv_dim, num_heads, output_dim=None):
12 | super(MultiHeadCrossAttention, self).__init__()
13 | # assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
14 |
15 | self.embed_dim = embed_dim
16 | self.query_dim = query_dim
17 | self.kv_dim = kv_dim
18 | self.output_dim = output_dim if output_dim else embed_dim
19 |
20 | self.num_heads = num_heads
21 | self.head_dim = embed_dim // num_heads
22 |
23 | self.q_proj = nn.Linear(query_dim, embed_dim)
24 | self.k_proj = nn.Linear(kv_dim, embed_dim)
25 | self.v_proj = nn.Linear(kv_dim, embed_dim)
26 |         self.out_proj = nn.Linear(embed_dim, self.output_dim)  # fall back to embed_dim when output_dim is None
27 |
28 | def forward(self, query, key, value, mask=None, return_attn=False):
29 | batch_size = query.size(0)
30 |
31 | # Linear projections
32 | q = self.q_proj(query) # NLC
33 | k = self.k_proj(key)
34 | v = self.v_proj(value)
35 |
36 | # Reshape and transpose for multi-head attention
37 | q = q.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
38 | k = k.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
39 | v = v.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
40 |
41 | # Scaled dot-product attention
42 | scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
43 | if mask is not None:
44 | scores = scores.masked_fill(mask == 0, float('-inf'))
45 | attn = F.softmax(scores, dim=-1)
46 |
47 | # Combine heads
48 | context = torch.matmul(attn, v)
49 | context = context.transpose(1, 2).reshape(batch_size, -1, self.embed_dim)
50 |
51 | # Final linear projection
52 | output = self.out_proj(context)
53 | if return_attn:
54 | return output, attn
55 | return output
56 |
57 |
58 | class AttentionPool3d(nn.Module):
59 | def __init__(self, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.cross_attn = MultiHeadCrossAttention(
62 | embed_dim=embed_dim,
63 | query_dim=embed_dim,
64 | kv_dim=embed_dim,
65 | num_heads=num_heads,
66 | output_dim=output_dim
67 | )
68 | self.num_heads = num_heads
69 |
70 | def forward(self, x, return_attn=False): # x: BCLHW
71 | # import pdb;pdb.set_trace()
72 | x = x.flatten(start_dim=2).permute(2, 0, 1) # BC(LHW) -> (LHW)BC
73 | x_mean = x.mean(dim=0, keepdim=True) # (1)BC
74 | x = torch.cat([x_mean, x], dim=0) # (LHW+1)BC
75 | x = x.permute(1, 0, 2).contiguous() # B(LHW+1)C
76 | x_mean = x_mean.permute(1, 0, 2).contiguous() # B(1)C
77 |
78 | if return_attn:
79 | x, attn = self.cross_attn(query=x_mean, key=x, value=x, return_attn=True) # B(1)C
80 | return x.squeeze(dim=-1), attn
81 | x = self.cross_attn(query=x_mean, key=x, value=x).squeeze(dim=1) # BC
82 | batch, channels = x.shape
83 | x = x.view(batch, channels, 1, 1, 1)
84 |
85 | return x
86 |
87 |
88 | class TextAttentionPool3d(nn.Module):
89 | def __init__(self, embed_dim: int, txt_dim: int, num_heads: int, output_dim: int = None):
90 | super().__init__()
91 | self.cross_attn = MultiHeadCrossAttention(
92 | embed_dim=embed_dim,
93 | query_dim=txt_dim,
94 | kv_dim=embed_dim,
95 | num_heads=num_heads,
96 | output_dim=output_dim
97 | )
98 | self.num_heads = num_heads
99 |
100 | def forward(self, x, txt_feat):
101 | # import pdb;pdb.set_trace()
102 | # import pdb;pdb.set_trace()
103 | x = x.flatten(start_dim=2).permute(2, 0, 1) # BC(LHW) -> (LHW)BC
104 | x_mean = x.mean(dim=0, keepdim=True) # (1)BC
105 | x = torch.cat([x_mean, x], dim=0) # (LHW+1)BC
106 | x = x.permute(1, 0, 2).contiguous() # B(LHW+1)C
107 | x_mean = x_mean.permute(1, 0, 2).contiguous() # B(1)C
108 |
109 | txt_feat = txt_feat.unsqueeze(dim=1) # BC -> B(1)C
110 |
111 | x = self.cross_attn(query=txt_feat, key=x, value=x) # B(1)C
112 | x = x.squeeze(dim=1)
113 | batch, channels = x.shape
114 | x = x.view(batch, channels, 1, 1, 1)
115 | return x
116 |
117 |
118 | class VQAHead(nn.Module):
119 | """MLP Regression Head for VQA.
120 | Args:
121 | in_channels: input channels for MLP
122 | hidden_channels: hidden channels for MLP
123 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5)
124 |         pre_pool: whether to pre-pool the features (True for Aesthetic Attributes, False for Technical Attributes)
125 | """
126 |
127 | def __init__(
128 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, attn_pool3d=False,
129 | text_pool3d=False, **kwargs
130 | ):
131 | super().__init__()
132 | self.dropout_ratio = dropout_ratio
133 | self.in_channels = in_channels
134 | self.hidden_channels = hidden_channels
135 | self.pre_pool = pre_pool
136 | self.attn_pool3d = attn_pool3d
137 | self.text_pool3d = text_pool3d
138 | if self.dropout_ratio != 0:
139 | self.dropout = nn.Dropout(p=self.dropout_ratio)
140 | else:
141 | self.dropout = None
142 |
143 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
144 | if self.attn_pool3d:
145 | self.attn_pool = AttentionPool3d(embed_dim=self.in_channels, num_heads=12,
146 | output_dim=self.in_channels) # 768//64=12
147 | if self.text_pool3d:
148 | self.text_pool = TextAttentionPool3d(embed_dim=self.in_channels, txt_dim=1024, num_heads=12,
149 | output_dim=self.in_channels)
150 |
151 | self.fc_hid = nn.Conv3d(2 * self.in_channels, self.hidden_channels,
152 | (1, 1, 1)) if self.text_pool3d else nn.Conv3d(self.in_channels, self.hidden_channels,
153 | (1, 1, 1))
154 | self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1))
155 | self.gelu = nn.GELU()
156 |
157 | def forward(self, x, txt=None, inference=False, rois=None):
158 | # import pdb;pdb.set_trace()
159 | if self.pre_pool:
160 | x = self.avg_pool(x)
161 | if self.attn_pool3d:
162 | x_vis = self.attn_pool(x)
163 | if self.text_pool3d and txt is not None:
164 | x_txt = self.text_pool(x, txt)
165 | if inference and x_txt.size(0) != x_vis.size(0):
166 | x_txt = x_txt.expand(x_vis.size(0), -1, -1, -1, -1)
167 | x = torch.concat([x_vis, x_txt], dim=1)
168 | if self.attn_pool3d and not self.text_pool3d:
169 | x = self.dropout(x_vis)
170 | else:
171 | x = self.dropout(x)
172 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x))))
173 | return qlt_score
174 |
175 |
176 | def clean(serie):
177 | output = serie[(np.isnan(serie) == False) & (np.isinf(serie) == False)]
178 | return output
179 |
180 |
181 | class VQAHead_cls(nn.Module):
182 | """MLP Regression Head for VQA.
183 | Args:
184 | in_channels: input channels for MLP
185 | hidden_channels: hidden channels for MLP
186 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5)
187 |         pre_pool: whether to pre-pool the features (True for Aesthetic Attributes, False for Technical Attributes)
188 | """
189 |
190 | def __init__(
191 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, attn_pool3d=False,
192 | text_pool3d=False, **kwargs
193 | ):
194 | super().__init__()
195 | self.dropout_ratio = dropout_ratio
196 | self.in_channels = in_channels
197 | self.hidden_channels = hidden_channels
198 | self.pre_pool = pre_pool
199 | self.attn_pool3d = attn_pool3d
200 | self.text_pool3d = text_pool3d
201 | if self.dropout_ratio != 0:
202 | self.dropout = nn.Dropout(p=self.dropout_ratio)
203 | else:
204 | self.dropout = None
205 |
206 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
207 | if self.attn_pool3d:
208 | self.attn_pool = AttentionPool3d(embed_dim=self.in_channels, num_heads=16,
209 |                                              output_dim=self.in_channels)  # note: 16 heads here, unlike VQAHead's 768//64=12
210 | if self.text_pool3d:
211 | self.text_pool = TextAttentionPool3d(embed_dim=self.in_channels, txt_dim=1024, num_heads=16,
212 | output_dim=self.in_channels)
213 | # self.fc_hid=nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1))
214 | self.fc_hid = nn.Conv3d(2 * self.in_channels, self.hidden_channels,
215 | (1, 1, 1)) if self.text_pool3d else nn.Conv3d(self.in_channels, self.hidden_channels,
216 | (1, 1, 1))
217 | self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1))
218 | self.gelu = nn.GELU()
219 |
220 | self.fc_cls1 = nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1))
221 | self.fc_cls2 = nn.Conv3d(self.hidden_channels, 10, (1, 1, 1))
222 | self.gelu_cls = nn.GELU()
223 |
224 | def forward(self, x, txt=None, inference=False, rois=None):
225 | # import pdb;pdb.set_trace()
226 | if self.pre_pool:
227 | x = self.avg_pool(x)
228 | if self.attn_pool3d:
229 | x_vis = self.attn_pool(x)
230 | x_cls = self.fc_cls2(self.dropout(self.gelu_cls(self.fc_cls1(x_vis))))
231 | if self.text_pool3d and txt is not None:
232 | x_txt = self.text_pool(x, txt)
233 | if inference and x_txt.size(0) != x_vis.size(0):
234 | x_txt = x_txt.expand(x_vis.size(0), -1, -1, -1, -1)
235 | x = torch.concat([x_vis, x_txt], dim=1)
236 | if self.attn_pool3d and not self.text_pool3d:
237 | x = self.dropout(x_vis)
238 | else:
239 | x = self.dropout(x)
240 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x))))
241 | # print(qlt_score.shape)
242 |         return qlt_score  # (the x_cls logits computed above are not returned)
243 | class VARHead(nn.Module):
244 | """MLP Regression Head for Video Action Recognition.
245 | Args:
246 | in_channels: input channels for MLP
247 | hidden_channels: hidden channels for MLP
248 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5)
249 | """
250 |
251 | def __init__(self, in_channels=768, out_channels=400, dropout_ratio=0.5, **kwargs):
252 | super().__init__()
253 | self.dropout_ratio = dropout_ratio
254 | self.in_channels = in_channels
255 | self.out_channels = out_channels
256 | if self.dropout_ratio != 0:
257 | self.dropout = nn.Dropout(p=self.dropout_ratio)
258 | else:
259 | self.dropout = None
260 | self.fc = nn.Conv3d(self.in_channels, self.out_channels, (1, 1, 1))
261 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
262 |
263 | def forward(self, x, rois=None):
264 | x = self.dropout(x)
265 | x = self.avg_pool(x)
266 | out = self.fc(x)
267 | return out
268 |
--------------------------------------------------------------------------------
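A hedged usage sketch of the heads above (the import path assumes the vebench package and its dependencies are installed; the feature and text tensors are made-up shapes): VQAHead maps a 5-D feature volume (N, C, T, H, W) to a 5-D score map, attending over a BLIP text feature when text_pool3d=True.

import torch
from vebench.models.head import VQAHead   # assumed installed package layout

feats = torch.randn(2, 768, 4, 7, 7)   # hypothetical backbone features: (N, C, T, H, W)
txt = torch.randn(2, 1024)             # hypothetical BLIP text feature, used only when text_pool3d=True

head = VQAHead(in_channels=768, attn_pool3d=True, text_pool3d=True)
score_map = head(feats, txt=txt)
print(score_map.shape)                 # torch.Size([2, 1, 1, 1, 1]): one score per video before pooling
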
/vebench/models/backbone/blip.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | import os
9 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
10 | import warnings
11 |
12 | warnings.filterwarnings("ignore")
13 |
14 | from ...blip_models.vit import VisionTransformer, interpolate_pos_embed
15 | from ...blip_models.med import BertConfig, BertModel, BertLMHeadModel
16 | from transformers import BertTokenizer
17 | from timm.models.vision_transformer import Attention as TemporalAttention
18 | from timm.layers import Mlp, DropPath, to_2tuple
19 | from timm.layers import PatchEmbed, Mlp, DropPath, RmsNorm, PatchDropout, SwiGLUPacked, \
20 | trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
21 | get_act_layer, get_norm_layer
22 |
23 | import torch
24 | from torch import nn
25 | import torch.nn.functional as F
26 | import numpy as np
27 | import os
28 | from urllib.parse import urlparse
29 | from timm.models.hub import download_cached_file
30 |
31 | class MyAttention(nn.Module):
32 |
33 | def __init__(
34 | self,
35 | dim: int,
36 | num_heads: int = 8,
37 | qkv_bias: bool = False,
38 | qk_norm: bool = False,
39 | attn_drop: float = 0.,
40 | proj_drop: float = 0.,
41 | step:int=1,
42 | norm_layer: nn.Module = nn.LayerNorm,
43 | ) -> None:
44 | super().__init__()
45 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
46 | self.num_heads = num_heads
47 | self.head_dim = dim // num_heads
48 | self.scale = self.head_dim ** -0.5
49 | self.fused_attn = use_fused_attn()
50 |
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
53 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
54 | self.attn_drop = nn.Dropout(attn_drop)
55 | self.proj = nn.Linear(dim, dim)
56 | self.proj_drop = nn.Dropout(proj_drop)
57 | self.step=step
58 |
59 | def forward(self, x: torch.Tensor) -> torch.Tensor:
60 | B, T, N, C = x.shape
61 | qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, self.head_dim).permute(3, 1, 0, 4, 2, 5)
62 | q, k, v = qkv.unbind(0)
63 | q, k = self.q_norm(q), self.k_norm(k)
64 |         k=torch.cat((k[:self.step,...],k),dim=0)[:int(-1*self.step),...]  # shift keys back by self.step frames along the time axis (earliest frames duplicated)
65 |         v=torch.cat((v[:self.step,...],v),dim=0)[:int(-1*self.step),...]  # same temporal shift for the values
66 | if self.fused_attn:
67 | x = F.scaled_dot_product_attention(
68 | q, k, v,
69 | dropout_p=self.attn_drop.p if self.training else 0.,
70 | )
71 | else:
72 | q = q * self.scale
73 | attn = q @ k.transpose(-2, -1)
74 | attn = attn.softmax(dim=-1)
75 | attn = self.attn_drop(attn)
76 |         return attn  # note: returns raw attention weights; in the fused-attention branch above, attn is never defined (the calls to this module in Block are commented out)
77 |
78 |
79 | from einops import rearrange
80 |
81 | class Block(nn.Module):
82 | def __init__(
83 | self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
84 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None,type="A"):
85 | super().__init__()
86 | self.norm1 = norm_layer(dim)
87 | if ws is None:
88 | self.attn = TemporalAttention(dim, num_heads,attn_drop=attn_drop,proj_drop=drop)
89 | # elif ws == 1:
90 | # self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio)
91 | # else:
92 | # self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws)
93 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
94 | self.norm2 = norm_layer(dim)
95 | mlp_hidden_dim = int(dim * mlp_ratio)
96 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
97 | self.temporal_attn_1=MyAttention(dim, num_heads,attn_drop=attn_drop,proj_drop=drop,step=1)
98 | self.temporal_attn_2 = MyAttention(dim, num_heads, attn_drop=attn_drop, proj_drop=drop,step=2)
99 | self.temporal_conv = nn.Conv1d(dim, dim, kernel_size=3,stride=1, padding=1)
100 | self.type=type
101 | self.gelu=nn.GELU()
102 |
103 | def forward(self, x,B):
104 | # x: (B*T, h*w, C)
105 | x = x + self.drop_path(self.attn(self.norm1(x)))
106 | # spatial
107 | if self.type=="A":
108 | temp = self.mlp(self.norm2(x))
109 |
110 | temp=rearrange(temp,'(b t) l c -> b t l c', b=B)
111 |
112 | # step_1=self.drop_path(self.temporal_attn_1(temp))
113 | # step_2=self.drop_path(self.temporal_attn_2(temp))
114 | # step=torch.cat((step_2,step_1),dim=1)
115 | # temp=torch.cat((step,temp),dim=1)
116 | # temporal
117 | temp = rearrange(temp, 'b t l c -> (b l) c t', b=B)
118 |
119 | temp = self.temporal_conv(temp)
120 | temp = rearrange(temp, '(b l) c t -> (b t) l c', b=B)
121 |
122 | # output
123 | x = x + self.drop_path(temp)
124 | elif self.type=="B":
125 | spatial = self.mlp(self.norm2(x))
126 | temp=rearrange(spatial,'(b t) l c->(b l) c t',b=B)
127 | temp = self.temporal_conv(temp)
128 | temp = rearrange(temp, '(b l) c t -> (b t) l c', b=B)
129 | x=x+self.gelu(temp)+self.gelu(spatial)
130 |
131 | #x=rearrange(x,'(b t) l c -> b t l c',b=B).mean(1)
132 | return rearrange(x,'(b t) l c -> b t l c',b=B).mean(1),rearrange(x,'(b t) l c -> b t l c',b=B)
133 |
134 |
135 | class My_BLIP_Base(nn.Module):
136 | def __init__(self,
137 | med_config,
138 | image_size=224,
139 | vit='base',
140 | vit_grad_ckpt=False,
141 | vit_ckpt_layer=0,
142 | drop_path=0.2,
143 | in_chans=1024,
144 | embed_dim=1024,
145 | patch_size=2,
146 | ):
147 | """
148 | Args:
149 | med_config (str): path for the mixture of encoder-decoder model's configuration file
150 | image_size (int): input image size
151 | vit (str): model size of vision transformer
152 | """
153 | super().__init__()
154 |
155 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
156 | self.tokenizer = init_tokenizer()
157 | med_config = BertConfig.from_json_file(med_config)
158 | med_config.encoder_width = vision_width
159 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
160 | self.temporal_block=Block(dim=1024,num_heads=8,drop_path=0.2)
161 | #self.post_block=PostBlock(t_dim=768,v_dim=1024)
162 | self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163 |
164 |
165 | self.softmax=nn.Softmax(dim=1)
166 |
167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
168 | self.norm = nn.LayerNorm(embed_dim)
169 | for name, m in self.named_modules():
170 | if 'temporal_conv' in name:
171 | nn.init.dirac_(m.weight.data) # initialized to be identity
172 | nn.init.zeros_(m.bias.data)
173 | if 'temporal_fc' in name:
174 | nn.init.constant_(m.weight, 0)
175 | nn.init.constant_(m.bias, 0)
176 |
177 |
178 |
179 | def threeDConv(self,video):#3DConv
180 | temporal=self.visual_encoder(video)
181 | return self.temporal_block(temporal.reshape(-1,temporal.shape[-2],temporal.shape[-1]),B=temporal.shape[0])
182 |
183 |
184 |
185 | def forward(self, video, caption, mode):
186 |
187 | text = self.tokenizer(caption, return_tensors="pt",padding=True).to(video.device)
188 |
189 | assert mode=="multimodal_text"
190 | image_embeds, frame_embeds = self.threeDConv(video) # 8,197,1024
191 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(video.device)
192 |
193 | text.input_ids[:, 0] = self.tokenizer.enc_token_id
194 | output = self.text_encoder(text.input_ids,
195 | attention_mask=text.attention_mask,
196 | encoder_hidden_states=image_embeds,
197 | encoder_attention_mask=image_atts,
198 | return_dict=True,
199 | )
200 | return output.last_hidden_state
201 |
202 |
203 |
204 | def blip_feature_extractor(pretrained='', **kwargs):
205 | base_dir = os.path.dirname(os.path.abspath(__file__))
206 | config = os.path.join(base_dir, 'BLIP_configs', 'med_config.json')
207 | model = My_BLIP_Base(config, vit="large",**kwargs)
208 | if pretrained:
209 | model, msg = load_checkpoint(model, pretrained)
210 | #assert (len(msg.missing_keys) == 0)
211 | return model
212 |
213 |
214 | def init_tokenizer():
215 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
216 | tokenizer.add_special_tokens({'bos_token': '[DEC]'})
217 | tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
218 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
219 | return tokenizer
220 |
221 |
222 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
223 | assert vit in ['base', 'large'], "vit parameter must be base or large"
224 | if vit == 'base':
225 | vision_width = 768
226 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
227 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
228 | ckpt_layer=ckpt_layer,
229 | drop_path_rate=0 or drop_path_rate
230 | )
231 | elif vit == 'large':
232 | vision_width = 1024
233 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
234 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
235 | ckpt_layer=ckpt_layer,
236 |                                                drop_path_rate=0.1 or drop_path_rate  # note: 0.1 is truthy, so the passed drop_path_rate is ignored for ViT-Large
237 | )
238 | return visual_encoder, vision_width
239 |
240 |
241 | def is_url(url_or_filename):
242 | parsed = urlparse(url_or_filename)
243 | return parsed.scheme in ("http", "https")
244 |
245 |
246 | def load_checkpoint(model, url_or_filename):
247 | if is_url(url_or_filename):
248 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
249 | checkpoint = torch.load(cached_file, map_location='cpu')
250 | elif os.path.isfile(url_or_filename):
251 | checkpoint = torch.load(url_or_filename, map_location='cpu')
252 | else:
253 | print(url_or_filename)
254 | raise RuntimeError('checkpoint url or path is invalid')
255 |
256 | state_dict = checkpoint['model']
257 |
258 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
259 | model.visual_encoder)
260 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
261 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
262 | model.visual_encoder_m)
263 | for key in model.state_dict().keys():
264 | if key in state_dict.keys():
265 | if state_dict[key].shape != model.state_dict()[key].shape:
266 | del state_dict[key]
267 |
268 | msg = model.load_state_dict(state_dict, strict=False)
269 | # print('load checkpoint from %s' % url_or_filename)
270 | return model, msg
271 |
272 | class MyBLIP(nn.Module):
273 | def __init__(self,type="multimodal", model_path=None):
274 | super().__init__()
275 | self.model = blip_feature_extractor(pretrained=os.path.join(model_path, 'model_large.pth'))
276 | self.type=type
277 |
278 |
279 |
280 | def forward(self, x, text):
281 | B, C, T, H, W = x.size()
282 | return self.model(x,text,self.type).permute(0,2,1).unsqueeze(-1).unsqueeze(-1)
283 |
284 | #return self.model(x, text, "image_attn")
285 |
286 |
--------------------------------------------------------------------------------
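A self-contained sketch (hypothetical shapes) of the temporal-mixing trick used in Block above: patch tokens are regrouped so that a 1-D convolution runs along the time axis for every spatial location, then reshaped back to the (batch*time, tokens, channels) layout expected by the ViT. With the dirac/zero initialisation used in My_BLIP_Base.__init__, the operation starts out as an identity.

import torch
from torch import nn
from einops import rearrange

B, T, L, C = 2, 8, 197, 1024                 # batch, frames, tokens per frame, channels
x = torch.randn(B * T, L, C)                 # ViT output, frames folded into the batch dimension

temporal_conv = nn.Conv1d(C, C, kernel_size=3, stride=1, padding=1)
nn.init.dirac_(temporal_conv.weight)         # identity init, as in My_BLIP_Base.__init__
nn.init.zeros_(temporal_conv.bias)

t = rearrange(x, '(b t) l c -> (b l) c t', b=B)
t = temporal_conv(t)                          # mixes information across the T frames
t = rearrange(t, '(b l) c t -> (b t) l c', b=B)
assert torch.allclose(t, x, atol=1e-5)        # identity init leaves the features unchanged
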
/vebench/blip_models/blip_retrieval.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig, BertModel
2 | from transformers import BertTokenizer
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from .blip import create_vit, init_tokenizer, load_checkpoint
9 |
10 | class BLIP_Retrieval(nn.Module):
11 | def __init__(self,
12 |                  med_config = 'BLIP_configs/med_config.json',
13 | image_size = 384,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | embed_dim = 256,
18 | queue_size = 57600,
19 | momentum = 0.995,
20 | negative_all_rank = False,
21 | ):
22 | """
23 | Args:
24 | med_config (str): path for the mixture of encoder-decoder model's configuration file
25 | image_size (int): input image size
26 | vit (str): model size of vision transformer
27 | """
28 | super().__init__()
29 |
30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
31 | self.tokenizer = init_tokenizer()
32 | med_config = BertConfig.from_json_file(med_config)
33 | med_config.encoder_width = vision_width
34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
35 |
36 | text_width = self.text_encoder.config.hidden_size
37 |
38 | self.vision_proj = nn.Linear(vision_width, embed_dim)
39 | self.text_proj = nn.Linear(text_width, embed_dim)
40 |
41 | self.itm_head = nn.Linear(text_width, 2)
42 |
43 | # create momentum encoders
44 | self.visual_encoder_m, vision_width = create_vit(vit,image_size)
45 | self.vision_proj_m = nn.Linear(vision_width, embed_dim)
46 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
47 | self.text_proj_m = nn.Linear(text_width, embed_dim)
48 |
49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
50 | [self.vision_proj,self.vision_proj_m],
51 | [self.text_encoder,self.text_encoder_m],
52 | [self.text_proj,self.text_proj_m],
53 | ]
54 | self.copy_params()
55 |
56 | # create the queue
57 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
58 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
59 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
60 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
61 |
62 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
63 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
64 |
65 | self.queue_size = queue_size
66 | self.momentum = momentum
67 | self.temp = nn.Parameter(0.07*torch.ones([]))
68 |
69 | self.negative_all_rank = negative_all_rank
70 |
71 |
72 | def forward(self, image, caption, alpha, idx):
73 | with torch.no_grad():
74 | self.temp.clamp_(0.001,0.5)
75 |
76 | image_embeds = self.visual_encoder(image)
77 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
78 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
79 |
80 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
81 | return_tensors="pt").to(image.device)
82 |
83 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
84 | return_dict = True, mode = 'text')
85 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
86 |
87 | ###============== Image-text Contrastive Learning ===================###
88 | idx = idx.view(-1,1)
89 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
90 | pos_idx = torch.eq(idx, idx_all).float()
91 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
92 |
93 | # get momentum features
94 | with torch.no_grad():
95 | self._momentum_update()
96 | image_embeds_m = self.visual_encoder_m(image)
97 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
98 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
99 |
100 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
101 | return_dict = True, mode = 'text')
102 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
103 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
104 |
105 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
106 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
107 |
108 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
109 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
110 |
111 | sim_i2t = image_feat @ text_feat_m_all / self.temp
112 | sim_t2i = text_feat @ image_feat_m_all / self.temp
113 |
114 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
115 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
116 |
117 | loss_ita = (loss_i2t+loss_t2i)/2
118 |
119 | idxs = concat_all_gather(idx)
120 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
121 |
122 | ###============== Image-text Matching ===================###
123 | encoder_input_ids = text.input_ids.clone()
124 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id
125 |
126 |         # forward the positive image-text pair
127 | bs = image.size(0)
128 | output_pos = self.text_encoder(encoder_input_ids,
129 | attention_mask = text.attention_mask,
130 | encoder_hidden_states = image_embeds,
131 | encoder_attention_mask = image_atts,
132 | return_dict = True,
133 | )
134 |
135 |
136 | if self.negative_all_rank:
137 | # compute sample similarity
138 | with torch.no_grad():
139 | mask = torch.eq(idx, idxs.t())
140 |
141 | image_feat_world = concat_all_gather(image_feat)
142 | text_feat_world = concat_all_gather(text_feat)
143 |
144 | sim_i2t = image_feat @ text_feat_world.t() / self.temp
145 | sim_t2i = text_feat @ image_feat_world.t() / self.temp
146 |
147 | weights_i2t = F.softmax(sim_i2t,dim=1)
148 | weights_i2t.masked_fill_(mask, 0)
149 |
150 | weights_t2i = F.softmax(sim_t2i,dim=1)
151 | weights_t2i.masked_fill_(mask, 0)
152 |
153 | image_embeds_world = all_gather_with_grad(image_embeds)
154 |
155 | # select a negative image (from all ranks) for each text
156 | image_embeds_neg = []
157 | for b in range(bs):
158 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
159 | image_embeds_neg.append(image_embeds_world[neg_idx])
160 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
161 |
162 | # select a negative text (from all ranks) for each image
163 | input_ids_world = concat_all_gather(encoder_input_ids)
164 | att_mask_world = concat_all_gather(text.attention_mask)
165 |
166 | text_ids_neg = []
167 | text_atts_neg = []
168 | for b in range(bs):
169 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
170 | text_ids_neg.append(input_ids_world[neg_idx])
171 | text_atts_neg.append(att_mask_world[neg_idx])
172 |
173 | else:
174 | with torch.no_grad():
175 | mask = torch.eq(idx, idx.t())
176 |
177 | sim_i2t = image_feat @ text_feat.t() / self.temp
178 | sim_t2i = text_feat @ image_feat.t() / self.temp
179 |
180 | weights_i2t = F.softmax(sim_i2t,dim=1)
181 | weights_i2t.masked_fill_(mask, 0)
182 |
183 | weights_t2i = F.softmax(sim_t2i,dim=1)
184 | weights_t2i.masked_fill_(mask, 0)
185 |
186 | # select a negative image (from same rank) for each text
187 | image_embeds_neg = []
188 | for b in range(bs):
189 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
190 | image_embeds_neg.append(image_embeds[neg_idx])
191 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
192 |
193 | # select a negative text (from same rank) for each image
194 | text_ids_neg = []
195 | text_atts_neg = []
196 | for b in range(bs):
197 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
198 | text_ids_neg.append(encoder_input_ids[neg_idx])
199 | text_atts_neg.append(text.attention_mask[neg_idx])
200 |
201 | text_ids_neg = torch.stack(text_ids_neg,dim=0)
202 | text_atts_neg = torch.stack(text_atts_neg,dim=0)
203 |
204 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
205 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
206 |
207 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
208 | image_atts_all = torch.cat([image_atts,image_atts],dim=0)
209 |
210 | output_neg = self.text_encoder(text_ids_all,
211 | attention_mask = text_atts_all,
212 | encoder_hidden_states = image_embeds_all,
213 | encoder_attention_mask = image_atts_all,
214 | return_dict = True,
215 | )
216 |
217 |
218 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
219 | vl_output = self.itm_head(vl_embeddings)
220 |
221 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
222 | dim=0).to(image.device)
223 | loss_itm = F.cross_entropy(vl_output, itm_labels)
224 |
225 | return loss_ita, loss_itm
226 |
227 |
228 | @torch.no_grad()
229 | def copy_params(self):
230 | for model_pair in self.model_pairs:
231 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
232 | param_m.data.copy_(param.data) # initialize
233 | param_m.requires_grad = False # not update by gradient
234 |
235 |
236 | @torch.no_grad()
237 | def _momentum_update(self):
238 | for model_pair in self.model_pairs:
239 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
240 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
241 |
242 |
243 | @torch.no_grad()
244 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
245 | # gather keys before updating queue
246 | image_feats = concat_all_gather(image_feat)
247 | text_feats = concat_all_gather(text_feat)
248 |
249 |
250 | batch_size = image_feats.shape[0]
251 |
252 | ptr = int(self.ptr_queue)
253 | assert self.queue_size % batch_size == 0 # for simplicity
254 |
255 | # replace the keys at ptr (dequeue and enqueue)
256 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
257 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
258 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
259 | ptr = (ptr + batch_size) % self.queue_size # move pointer
260 |
261 | self.ptr_queue[0] = ptr
262 |
263 |
264 | def blip_retrieval(pretrained='',**kwargs):
265 | model = BLIP_Retrieval(**kwargs)
266 | if pretrained:
267 | model,msg = load_checkpoint(model,pretrained)
268 | print("missing keys:")
269 | print(msg.missing_keys)
270 | return model
271 |
272 |
273 | @torch.no_grad()
274 | def concat_all_gather(tensor):
275 | """
276 | Performs all_gather operation on the provided tensors.
277 | *** Warning ***: torch.distributed.all_gather has no gradient.
278 | """
279 | tensors_gather = [torch.ones_like(tensor)
280 | for _ in range(torch.distributed.get_world_size())]
281 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
282 |
283 | output = torch.cat(tensors_gather, dim=0)
284 | return output
285 |
286 |
287 | class GatherLayer(torch.autograd.Function):
288 | """
289 | Gather tensors from all workers with support for backward propagation:
290 | This implementation does not cut the gradients as torch.distributed.all_gather does.
291 | """
292 |
293 | @staticmethod
294 | def forward(ctx, x):
295 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
296 | torch.distributed.all_gather(output, x)
297 | return tuple(output)
298 |
299 | @staticmethod
300 | def backward(ctx, *grads):
301 | all_gradients = torch.stack(grads)
302 | torch.distributed.all_reduce(all_gradients)
303 | return all_gradients[torch.distributed.get_rank()]
304 |
305 |
306 | def all_gather_with_grad(tensors):
307 | """
308 | Performs all_gather operation on the provided tensors.
309 | Graph remains connected for backward grad computation.
310 | """
311 | # Queue the gathered tensors
312 | world_size = torch.distributed.get_world_size()
313 | # There is no need for reduction in the single-proc case
314 | if world_size == 1:
315 | return tensors
316 |
317 | tensor_all = GatherLayer.apply(tensors)
318 |
319 | return torch.cat(tensor_all, dim=0)
320 |
--------------------------------------------------------------------------------
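A self-contained sketch (toy modules, not the repository's API) of the momentum (EMA) update used by copy_params() and _momentum_update() above: the momentum encoder is first copied from the online encoder and frozen, then tracks it with factor momentum at every step.

import torch
from torch import nn

online = nn.Linear(4, 4)
momentum_enc = nn.Linear(4, 4)
momentum = 0.995

with torch.no_grad():                         # copy_params(): initialise and freeze
    for p, p_m in zip(online.parameters(), momentum_enc.parameters()):
        p_m.data.copy_(p.data)
        p_m.requires_grad = False

with torch.no_grad():                         # one _momentum_update() step
    for p, p_m in zip(online.parameters(), momentum_enc.parameters()):
        p_m.data = p_m.data * momentum + p.data * (1.0 - momentum)
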
/vebench/preprocess.py:
--------------------------------------------------------------------------------
1 |
2 | import copy
3 | import glob
4 | import os
5 | os.environ["WANDB_MODE"] = "offline"
6 | import os.path as osp
7 | import random
8 | from functools import lru_cache
9 |
10 | import decord
11 | import skvideo.io
12 | import torch
13 | import torchvision
14 | from decord import VideoReader, cpu, gpu
15 |
16 |
17 | decord.bridge.set_bridge("torch")
18 |
19 |
20 | def get_spatial_fragments(
21 | video,
22 | fragments_h=7,
23 | fragments_w=7,
24 | fsize_h=32,
25 | fsize_w=32,
26 | aligned=32,
27 | nfrags=1,
28 | random=False,
29 | random_upsample=False,
30 | fallback_type="upsample",
31 | upsample=-1,
32 | **kwargs,
33 | ):
34 | if upsample > 0:
35 | old_h, old_w = video.shape[-2], video.shape[-1]
36 | if old_h >= old_w:
37 | w = upsample
38 | h = int(upsample * old_h / old_w)
39 | else:
40 | h = upsample
41 | w = int(upsample * old_w / old_h)
42 |
43 | video = get_resized_video(video, h, w)
44 | size_h = fragments_h * fsize_h
45 | size_w = fragments_w * fsize_w
46 | ## video: [C,T,H,W]
47 | ## situation for images
48 | if video.shape[1] == 1:
49 | aligned = 1
50 |
51 | dur_t, res_h, res_w = video.shape[-3:]
52 | ratio = min(res_h / size_h, res_w / size_w)
53 | if fallback_type == "upsample" and ratio < 1:
54 |
55 | ovideo = video
56 | video = torch.nn.functional.interpolate(
57 | video / 255.0, scale_factor=1 / ratio, mode="bilinear"
58 | )
59 | video = (video * 255.0).type_as(ovideo)
60 |
61 | if random_upsample:
62 |
63 |             randratio = torch.rand(1).item() * 0.5 + 1  # torch RNG used here because the bool parameter 'random' shadows the 'random' module
64 | video = torch.nn.functional.interpolate(
65 | video / 255.0, scale_factor=randratio, mode="bilinear"
66 | )
67 | video = (video * 255.0).type_as(ovideo)
68 |
69 |     assert dur_t % aligned == 0, "Clip length must be a multiple of the temporal alignment size"
70 | size = size_h, size_w
71 |
72 | ## make sure that sampling will not run out of the picture
73 | hgrids = torch.LongTensor(
74 | [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)]
75 | )
76 | wgrids = torch.LongTensor(
77 | [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)]
78 | )
79 | hlength, wlength = res_h // fragments_h, res_w // fragments_w
80 |
81 | if random:
82 | print("This part is deprecated. Please remind that.")
83 | if res_h > fsize_h:
84 | rnd_h = torch.randint(
85 | res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
86 | )
87 | else:
88 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
89 | if res_w > fsize_w:
90 | rnd_w = torch.randint(
91 | res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
92 | )
93 | else:
94 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
95 | else:
96 | if hlength > fsize_h:
97 | rnd_h = torch.randint(
98 | hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
99 | )
100 | else:
101 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
102 | if wlength > fsize_w:
103 | rnd_w = torch.randint(
104 | wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
105 | )
106 | else:
107 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
108 |
109 | target_video = torch.zeros(video.shape[:-2] + size).to(video.device)
110 | # target_videos = []
111 |
112 | for i, hs in enumerate(hgrids):
113 | for j, ws in enumerate(wgrids):
114 | for t in range(dur_t // aligned):
115 | t_s, t_e = t * aligned, (t + 1) * aligned
116 | h_s, h_e = i * fsize_h, (i + 1) * fsize_h
117 | w_s, w_e = j * fsize_w, (j + 1) * fsize_w
118 | if random:
119 | h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h
120 | w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w
121 | else:
122 | h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h
123 | w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w
124 | target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
125 | :, t_s:t_e, h_so:h_eo, w_so:w_eo
126 | ]
127 | # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo])
128 | # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6)
129 | # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments
130 | return target_video
131 |
132 |
133 | @lru_cache
134 | def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False):
135 | if random_crop:
136 | return torchvision.transforms.RandomResizedCrop(
137 | (size_h, size_w), scale=(0.40, 1.0)
138 | )
139 | if target_ratio > 1:
140 | size_h = int(target_ratio * size_w)
141 | assert size_h > size_w
142 | elif target_ratio < 1:
143 | size_w = int(size_h / target_ratio)
144 | assert size_w > size_h
145 | return torchvision.transforms.Resize((size_h, size_w))
146 |
147 |
148 | def get_resized_video(
149 | video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs,
150 | ):
151 | video = video.permute(1, 0, 2, 3)
152 | resize_opt = get_resize_function(
153 | size_h, size_w, video.shape[-2] / video.shape[-1] if arp else 1, random_crop
154 | )
155 | video = resize_opt(video).permute(1, 0, 2, 3)
156 | return video
157 |
158 |
159 | def get_arp_resized_video(
160 | video, short_edge=224, train=False, **kwargs,
161 | ):
162 | if train: ## if during training, will random crop into square and then resize
163 | res_h, res_w = video.shape[-2:]
164 | ori_short_edge = min(video.shape[-2:])
165 | if res_h > ori_short_edge:
166 | rnd_h = random.randrange(res_h - ori_short_edge)
167 | video = video[..., rnd_h : rnd_h + ori_short_edge, :]
168 | elif res_w > ori_short_edge:
169 | rnd_w = random.randrange(res_w - ori_short_edge)
170 |             video = video[..., :, rnd_w : rnd_w + ori_short_edge]
171 | ori_short_edge = min(video.shape[-2:])
172 | scale_factor = short_edge / ori_short_edge
173 | ovideo = video
174 | video = torch.nn.functional.interpolate(
175 |             video / 255.0, scale_factor=scale_factor, mode="bilinear"
176 | )
177 | video = (video * 255.0).type_as(ovideo)
178 | return video
179 |
180 |
181 | def get_arp_fragment_video(
182 | video, short_fragments=7, fsize=32, train=False, **kwargs,
183 | ):
184 | if (
185 | train
186 | ): ## if during training, will random crop into square and then get fragments
187 | res_h, res_w = video.shape[-2:]
188 | ori_short_edge = min(video.shape[-2:])
189 | if res_h > ori_short_edge:
190 | rnd_h = random.randrange(res_h - ori_short_edge)
191 | video = video[..., rnd_h : rnd_h + ori_short_edge, :]
192 | elif res_w > ori_short_edge:
193 | rnd_w = random.randrange(res_w - ori_short_edge)
194 |             video = video[..., :, rnd_w : rnd_w + ori_short_edge]
195 | kwargs["fsize_h"], kwargs["fsize_w"] = fsize, fsize
196 | res_h, res_w = video.shape[-2:]
197 | if res_h > res_w:
198 | kwargs["fragments_w"] = short_fragments
199 | kwargs["fragments_h"] = int(short_fragments * res_h / res_w)
200 | else:
201 | kwargs["fragments_h"] = short_fragments
202 | kwargs["fragments_w"] = int(short_fragments * res_w / res_h)
203 | return get_spatial_fragments(video, **kwargs)
204 |
205 |
206 | def get_cropped_video(
207 | video, size_h=224, size_w=224, **kwargs,
208 | ):
209 | kwargs["fragments_h"], kwargs["fragments_w"] = 1, 1
210 | kwargs["fsize_h"], kwargs["fsize_w"] = size_h, size_w
211 | return get_spatial_fragments(video, **kwargs)
212 |
213 |
214 | def get_single_view(
215 | video, sample_type="aesthetic", **kwargs,
216 | ):
217 | if sample_type.startswith("aesthetic"):
218 | video = get_resized_video(video, **kwargs)
219 | elif sample_type.startswith("technical"):
220 | video = get_spatial_fragments(video, **kwargs)
221 | elif sample_type.startswith("clip"):
222 | video = get_resized_video(video, **kwargs)
223 | elif sample_type.startswith("time"):
224 | video = get_resized_video(video, **kwargs)
225 | elif sample_type.startswith("other"):
226 | video = get_spatial_fragments(video, **kwargs)
227 | elif "flow" in sample_type:
228 | video = get_resized_video(video, **kwargs)
229 | elif sample_type == "original":
230 | return video
231 |
232 | return video
233 |
234 |
235 | def spatial_temporal_view_decomposition(
236 | video_path, sample_types, samplers, edit_video_path=None,is_train=False, augment=False,
237 | ):
238 | video = {}
239 | if video_path.endswith(".yuv"):
240 | print("This part will be deprecated due to large memory cost.")
241 | ## This is only an adaptation to LIVE-Qualcomm
242 | ovideo = skvideo.io.vread(
243 | video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
244 | )
245 | for stype in samplers:
246 | frame_inds = samplers[stype](ovideo.shape[0], is_train)
247 | imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
248 | video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
249 | del ovideo
250 | else:
251 | decord.bridge.set_bridge("torch")
252 | vreader = VideoReader(video_path)
253 | ### Avoid duplicated video decoding!!! Important!!!!
254 | all_frame_inds = []
255 | frame_inds = {}
256 | for stype in samplers:
257 | frame_inds[stype] = samplers[stype](len(vreader), is_train)
258 | all_frame_inds.append(frame_inds[stype])
259 |
260 | ### Each frame is only decoded one time!!!
261 | all_frame_inds = np.concatenate(all_frame_inds, 0)
262 | frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
263 |
264 | for stype in samplers:
265 | imgs = [frame_dict[idx] for idx in frame_inds[stype]]
266 | video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
267 |
268 | sampled_video = {}
269 | for stype, sopt in sample_types.items():
270 | sampled_video[stype] = get_single_view(video[stype], stype, **sopt)
271 | return sampled_video, frame_inds
272 |
273 |
274 |
275 |
276 |
277 | import random
278 |
279 | import numpy as np
280 |
281 |
282 | class UnifiedFrameSampler:
283 | def __init__(
284 | self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0,
285 | ):
286 |
287 | self.fragments_t = fragments_t
288 | self.fsize_t = fsize_t
289 | self.size_t = fragments_t * fsize_t
290 | self.frame_interval = frame_interval
291 | self.num_clips = num_clips
292 | self.drop_rate = drop_rate
293 |
294 | def get_frame_indices(self, num_frames, train=False):
295 |
296 | tgrids = np.array(
297 | [num_frames // self.fragments_t * i for i in range(self.fragments_t)],
298 | dtype=np.int32,
299 | )
300 | tlength = num_frames // self.fragments_t
301 |
302 | if tlength > self.fsize_t * self.frame_interval:
303 | rnd_t = np.random.randint(
304 | 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids)
305 | )
306 | else:
307 | rnd_t = np.zeros(len(tgrids), dtype=np.int32)
308 |
309 | ranges_t = (
310 | np.arange(self.fsize_t)[None, :] * self.frame_interval
311 | + rnd_t[:, None]
312 | + tgrids[:, None]
313 | )
314 |
315 | drop = random.sample(
316 | list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate)
317 | )
318 | dropped_ranges_t = []
319 | for i, rt in enumerate(ranges_t):
320 | if i not in drop:
321 | dropped_ranges_t.append(rt)
322 | return np.concatenate(dropped_ranges_t)
323 |
324 | def __call__(self, total_frames, train=False, start_index=0):
325 | frame_inds = []
326 |
327 | for i in range(self.num_clips):
328 | frame_inds += [self.get_frame_indices(total_frames)]
329 |
330 | frame_inds = np.concatenate(frame_inds)
331 | frame_inds = np.mod(frame_inds + start_index, total_frames)
332 | return frame_inds.astype(np.int32)
333 |
334 |
335 | class Processor():
336 | def __init__(self, opt,from_src=False):
337 | ## opt is a dictionary that includes options for video sampling
338 |
339 | super().__init__()
340 |
341 | self.sample_types = opt["sample_types"]
342 | self.phase = opt["phase"]
343 | self.crop = opt.get("random_crop", False)
344 | self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
345 | self.std = torch.FloatTensor([58.395, 57.12, 57.375])
346 | self.samplers = {}
347 | for stype, sopt in opt["sample_types"].items():
348 | if "t_frag" not in sopt:
349 | # resized temporal sampling for TQE in DOVER
350 | self.samplers[stype] = UnifiedFrameSampler(
351 | sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"]
352 | )
353 | else:
354 | # temporal sampling for AQE in DOVER
355 | self.samplers[stype] = UnifiedFrameSampler(
356 | sopt["clip_len"] // sopt["t_frag"],
357 | sopt["t_frag"],
358 | sopt["frame_interval"],
359 | sopt["num_clips"],
360 | )
361 |
362 | def preprocess(self, filename):
363 | #try:
364 | ## Read Original Frames
365 | ## Process Frames
366 | data, frame_inds = spatial_temporal_view_decomposition(
367 | filename,
368 | self.sample_types,
369 | self.samplers,
370 | self.phase == "test",
371 | (self.phase == "train"),
372 | )
373 |
374 | for k, v in data.items():
375 | data[k] = ((v.permute(1, 2, 3, 0) - self.mean) / self.std).permute(
376 | 3, 0, 1, 2
377 | ).unsqueeze(0).cuda()
378 |
379 | data["num_clips"] = {}
380 | for stype, sopt in self.sample_types.items():
381 | data["num_clips"][stype] = sopt["num_clips"]
382 | data["frame_inds"] = frame_inds
383 | # except:
384 | # # exception flow
385 | # return {"name": filename}
386 | # edit_name is technical
387 | return data
388 |
--------------------------------------------------------------------------------
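
A minimal usage sketch for the sampling utilities above. It assumes a DOVER-style option dictionary whose keys mirror what Processor.__init__ and UnifiedFrameSampler read ("phase", "sample_types", clip_len, num_clips, frame_interval, optional t_frag); the concrete sizes, the import path, and the sample video are illustrative only, and preprocess() moves tensors to CUDA, so a GPU is assumed.

from vebench.preprocess import Processor  # assumed module path for the file above

opt = {
    "phase": "test",
    "sample_types": {
        "technical": {   # fragment view; no t_frag -> resized temporal sampling
            "fragments_h": 7, "fragments_w": 7, "fsize_h": 32, "fsize_w": 32,
            "clip_len": 32, "num_clips": 1, "frame_interval": 2,
        },
        "aesthetic": {   # resized view; t_frag -> fragmented temporal sampling
            "size_h": 224, "size_w": 224,
            "clip_len": 32, "t_frag": 32, "num_clips": 1, "frame_interval": 2,
        },
    },
}

processor = Processor(opt)
data = processor.preprocess("assets/dst.mp4")      # any readable video file
for stype in opt["sample_types"]:
    print(stype, data[stype].shape)                # (1, 3, clip_len, H, W) per view
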
/vebench/blip_models/blip_pretrain.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | from .med import BertConfig, BertModel, BertLMHeadModel
9 | from transformers import BertTokenizer
10 | import transformers
11 | transformers.logging.set_verbosity_error()
12 |
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F
16 |
17 | from .blip import create_vit, init_tokenizer, load_checkpoint
18 |
19 | class BLIP_Pretrain(nn.Module):
20 | def __init__(self,
21 | med_config = 'configs/bert_config.json',
22 | image_size = 224,
23 | vit = 'base',
24 | vit_grad_ckpt = False,
25 | vit_ckpt_layer = 0,
26 | embed_dim = 256,
27 | queue_size = 57600,
28 | momentum = 0.995,
29 | ):
30 | """
31 | Args:
32 | med_config (str): path for the mixture of encoder-decoder model's configuration file
33 | image_size (int): input image size
34 | vit (str): model size of vision transformer
35 | """
36 | super().__init__()
37 |
38 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
39 |
40 | if vit=='base':
41 | checkpoint = torch.hub.load_state_dict_from_url(
42 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
43 | map_location="cpu", check_hash=True)
44 | state_dict = checkpoint["model"]
45 | msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
46 | elif vit=='large':
47 | from timm.models.helpers import load_custom_pretrained
48 | from timm.models.vision_transformer import default_cfgs
49 | load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
50 |
51 | self.tokenizer = init_tokenizer()
52 | encoder_config = BertConfig.from_json_file(med_config)
53 | encoder_config.encoder_width = vision_width
54 | self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
55 | self.text_encoder.resize_token_embeddings(len(self.tokenizer))
56 |
57 | text_width = self.text_encoder.config.hidden_size
58 |
59 | self.vision_proj = nn.Linear(vision_width, embed_dim)
60 | self.text_proj = nn.Linear(text_width, embed_dim)
61 |
62 | self.itm_head = nn.Linear(text_width, 2)
63 |
64 | # create momentum encoders
65 | self.visual_encoder_m, vision_width = create_vit(vit,image_size)
66 | self.vision_proj_m = nn.Linear(vision_width, embed_dim)
67 | self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
68 | self.text_proj_m = nn.Linear(text_width, embed_dim)
69 |
70 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
71 | [self.vision_proj,self.vision_proj_m],
72 | [self.text_encoder,self.text_encoder_m],
73 | [self.text_proj,self.text_proj_m],
74 | ]
75 | self.copy_params()
76 |
77 | # create the queue
78 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
79 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
80 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
81 |
82 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
83 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
84 |
85 | self.queue_size = queue_size
86 | self.momentum = momentum
87 | self.temp = nn.Parameter(0.07*torch.ones([]))
88 |
89 | # create the decoder
90 | decoder_config = BertConfig.from_json_file(med_config)
91 | decoder_config.encoder_width = vision_width
92 | self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
93 | self.text_decoder.resize_token_embeddings(len(self.tokenizer))
94 | tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
95 |
96 |
97 | def forward(self, image, caption, alpha):
98 | with torch.no_grad():
99 | self.temp.clamp_(0.001,0.5)
100 |
101 | image_embeds = self.visual_encoder(image)
102 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
103 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
104 |
105 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
106 | return_tensors="pt").to(image.device)
107 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
108 | return_dict = True, mode = 'text')
109 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
110 |
111 | # get momentum features
112 | with torch.no_grad():
113 | self._momentum_update()
114 | image_embeds_m = self.visual_encoder_m(image)
115 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
116 | image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
117 |
118 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
119 | return_dict = True, mode = 'text')
120 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
121 | text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
122 |
123 | sim_i2t_m = image_feat_m @ text_feat_all / self.temp
124 | sim_t2i_m = text_feat_m @ image_feat_all / self.temp
125 |
126 | sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
127 | sim_targets.fill_diagonal_(1)
128 |
129 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
130 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
131 |
132 | sim_i2t = image_feat @ text_feat_all / self.temp
133 | sim_t2i = text_feat @ image_feat_all / self.temp
134 |
135 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
136 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
137 |
138 | loss_ita = (loss_i2t+loss_t2i)/2
139 |
140 | self._dequeue_and_enqueue(image_feat_m, text_feat_m)
141 |
142 | ###============== Image-text Matching ===================###
143 | encoder_input_ids = text.input_ids.clone()
144 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id
145 |
146 | # forward the positive image-text pair
147 | bs = image.size(0)
148 | output_pos = self.text_encoder(encoder_input_ids,
149 | attention_mask = text.attention_mask,
150 | encoder_hidden_states = image_embeds,
151 | encoder_attention_mask = image_atts,
152 | return_dict = True,
153 | )
154 | with torch.no_grad():
155 | weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
156 | weights_t2i.fill_diagonal_(0)
157 | weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
158 | weights_i2t.fill_diagonal_(0)
159 |
160 | # select a negative image for each text
161 | image_embeds_neg = []
162 | for b in range(bs):
163 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
164 | image_embeds_neg.append(image_embeds[neg_idx])
165 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
166 |
167 | # select a negative text for each image
168 | text_ids_neg = []
169 | text_atts_neg = []
170 | for b in range(bs):
171 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
172 | text_ids_neg.append(encoder_input_ids[neg_idx])
173 | text_atts_neg.append(text.attention_mask[neg_idx])
174 |
175 | text_ids_neg = torch.stack(text_ids_neg,dim=0)
176 | text_atts_neg = torch.stack(text_atts_neg,dim=0)
177 |
178 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
179 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
180 |
181 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
182 | image_atts_all = torch.cat([image_atts,image_atts],dim=0)
183 |
184 | output_neg = self.text_encoder(text_ids_all,
185 | attention_mask = text_atts_all,
186 | encoder_hidden_states = image_embeds_all,
187 | encoder_attention_mask = image_atts_all,
188 | return_dict = True,
189 | )
190 |
191 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
192 | vl_output = self.itm_head(vl_embeddings)
193 |
194 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
195 | dim=0).to(image.device)
196 | loss_itm = F.cross_entropy(vl_output, itm_labels)
197 |
198 | ##================= LM ========================##
199 | decoder_input_ids = text.input_ids.clone()
200 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id
201 | decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
202 |
203 | decoder_output = self.text_decoder(decoder_input_ids,
204 | attention_mask = text.attention_mask,
205 | encoder_hidden_states = image_embeds,
206 | encoder_attention_mask = image_atts,
207 | labels = decoder_targets,
208 | return_dict = True,
209 | )
210 |
211 | loss_lm = decoder_output.loss
212 | return loss_ita, loss_itm, loss_lm
213 |
214 |
215 |
216 | @torch.no_grad()
217 | def copy_params(self):
218 | for model_pair in self.model_pairs:
219 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
220 | param_m.data.copy_(param.data) # initialize
221 | param_m.requires_grad = False # not update by gradient
222 |
223 |
224 | @torch.no_grad()
225 | def _momentum_update(self):
226 | for model_pair in self.model_pairs:
227 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
228 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
229 |
230 |
231 | @torch.no_grad()
232 | def _dequeue_and_enqueue(self, image_feat, text_feat):
233 | # gather keys before updating queue
234 | image_feats = concat_all_gather(image_feat)
235 | text_feats = concat_all_gather(text_feat)
236 |
237 | batch_size = image_feats.shape[0]
238 |
239 | ptr = int(self.queue_ptr)
240 | assert self.queue_size % batch_size == 0 # for simplicity
241 |
242 | # replace the keys at ptr (dequeue and enqueue)
243 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
244 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
245 | ptr = (ptr + batch_size) % self.queue_size # move pointer
246 |
247 | self.queue_ptr[0] = ptr
248 |
249 |
250 | def blip_pretrain(**kwargs):
251 | model = BLIP_Pretrain(**kwargs)
252 | return model
253 |
254 |
255 | @torch.no_grad()
256 | def concat_all_gather(tensor):
257 | """
258 | Performs all_gather operation on the provided tensors.
259 | *** Warning ***: torch.distributed.all_gather has no gradient.
260 | """
261 | tensors_gather = [torch.ones_like(tensor)
262 | for _ in range(torch.distributed.get_world_size())]
263 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
264 |
265 | output = torch.cat(tensors_gather, dim=0)
266 | return output
267 |
268 |
269 | from typing import List
270 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
271 | uninitialized_encoder_weights: List[str] = []
272 | if decoder.__class__ != encoder.__class__:
273 | print(
274 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
275 | )
276 |
277 | def tie_encoder_to_decoder_recursively(
278 | decoder_pointer: nn.Module,
279 | encoder_pointer: nn.Module,
280 | module_name: str,
281 | uninitialized_encoder_weights: List[str],
282 | skip_key: str,
283 | depth=0,
284 | ):
285 | assert isinstance(decoder_pointer, nn.Module) and isinstance(
286 | encoder_pointer, nn.Module
287 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
288 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
289 | assert hasattr(encoder_pointer, "weight")
290 | encoder_pointer.weight = decoder_pointer.weight
291 | if hasattr(decoder_pointer, "bias"):
292 | assert hasattr(encoder_pointer, "bias")
293 | encoder_pointer.bias = decoder_pointer.bias
294 | print(module_name+' is tied')
295 | return
296 |
297 | encoder_modules = encoder_pointer._modules
298 | decoder_modules = decoder_pointer._modules
299 | if len(decoder_modules) > 0:
300 | assert (
301 | len(encoder_modules) > 0
302 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
303 |
304 | all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
305 | encoder_layer_pos = 0
306 | for name, module in decoder_modules.items():
307 | if name.isdigit():
308 | encoder_name = str(int(name) + encoder_layer_pos)
309 | decoder_name = name
310 | if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
311 | encoder_modules
312 | ) != len(decoder_modules):
313 | # this can happen if the name corresponds to the position in a list module list of layers
314 | # in this case the decoder has added a cross-attention that the encoder does not have
315 | # thus skip this step and subtract one layer pos from encoder
316 | encoder_layer_pos -= 1
317 | continue
318 | elif name not in encoder_modules:
319 | continue
320 | elif depth > 500:
321 | raise ValueError(
322 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
323 | )
324 | else:
325 | decoder_name = encoder_name = name
326 | tie_encoder_to_decoder_recursively(
327 | decoder_modules[decoder_name],
328 | encoder_modules[encoder_name],
329 | module_name + "/" + name,
330 | uninitialized_encoder_weights,
331 | skip_key,
332 | depth=depth + 1,
333 | )
334 | all_encoder_weights.remove(module_name + "/" + encoder_name)
335 |
336 | uninitialized_encoder_weights += list(all_encoder_weights)
337 |
338 | # tie weights recursively
339 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
340 |
--------------------------------------------------------------------------------
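
The contrastive part of BLIP_Pretrain rests on two small update rules: an exponential-moving-average copy of each encoder (copy_params / _momentum_update) and a fixed-size circular queue of normalized features (_dequeue_and_enqueue). Below is a minimal single-process sketch of just those two pieces, with a toy linear layer standing in for the ViT/BERT towers; the real class additionally all-gathers features across GPUs before enqueuing, and all sizes here are arbitrary.

import torch
from torch import nn

encoder = nn.Linear(16, 256)      # toy stand-in for a feature tower
encoder_m = nn.Linear(16, 256)    # its momentum copy
for p, p_m in zip(encoder.parameters(), encoder_m.parameters()):
    p_m.data.copy_(p.data)        # initialize from the online encoder
    p_m.requires_grad = False     # never updated by gradients

momentum, embed_dim, queue_size = 0.995, 256, 64
queue = nn.functional.normalize(torch.randn(embed_dim, queue_size), dim=0)
queue_ptr = 0

@torch.no_grad()
def momentum_update():
    for p, p_m in zip(encoder.parameters(), encoder_m.parameters()):
        p_m.data = p_m.data * momentum + p.data * (1.0 - momentum)

@torch.no_grad()
def dequeue_and_enqueue(feats):   # feats: (batch, embed_dim), L2-normalized
    global queue_ptr
    bs = feats.shape[0]
    assert queue_size % bs == 0   # same simplification as in the class above
    queue[:, queue_ptr:queue_ptr + bs] = feats.T
    queue_ptr = (queue_ptr + bs) % queue_size

momentum_update()
dequeue_and_enqueue(nn.functional.normalize(encoder_m(torch.randn(8, 16)), dim=-1))
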
/vebench/models/backbone/uniformer_backbone.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | from collections import OrderedDict
4 |
5 | from timm.models.layers import DropPath
6 | import torch
7 | from torch import nn
8 | from torch.nn import MultiheadAttention
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint as checkpoint
11 |
12 | # import logging as logging
13 |
14 | # logger = logging.get_logger(__name__)
15 |
16 |
17 |
18 | class LayerNorm(nn.LayerNorm):
19 | """Subclass torch's LayerNorm to handle fp16."""
20 |
21 | def forward(self, x):
22 | orig_type = x.dtype
23 | ret = super().forward(x.type(torch.float32))
24 | return ret.type(orig_type)
25 |
26 |
27 | class QuickGELU(nn.Module):
28 | def forward(self, x):
29 | return x * torch.sigmoid(1.702 * x)
30 |
31 |
32 | class Local_MHRA(nn.Module):
33 | def __init__(self, d_model, dw_reduction=1.5, pos_kernel_size=3):
34 | super().__init__()
35 |
36 | padding = pos_kernel_size // 2
37 | re_d_model = int(d_model // dw_reduction)
38 | self.pos_embed = nn.Sequential(
39 | nn.BatchNorm3d(d_model),
40 | nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0),
41 | nn.Conv3d(re_d_model, re_d_model, kernel_size=(pos_kernel_size, 1, 1), stride=(1, 1, 1),
42 | padding=(padding, 0, 0), groups=re_d_model),
43 | nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0),
44 | )
45 |
46 | # init zero
47 | # logger.info('Init zero for Conv in pos_emb')
48 | nn.init.constant_(self.pos_embed[3].weight, 0)
49 | nn.init.constant_(self.pos_embed[3].bias, 0)
50 |
51 | def forward(self, x):
52 | return self.pos_embed(x)
53 |
54 |
55 | class ResidualAttentionBlock(nn.Module):
56 | def __init__(
57 | self, d_model, n_head, attn_mask=None, drop_path=0.0,
58 | dw_reduction=1.5, no_lmhra=False, double_lmhra=True
59 | ):
60 | super().__init__()
61 |
62 | self.n_head = n_head
63 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
64 | # logger.info(f'Drop path rate: {drop_path}')
65 |
66 | self.no_lmhra = no_lmhra
67 | self.double_lmhra = double_lmhra
68 | # logger.info(f'No L_MHRA: {no_lmhra}')
69 | # logger.info(f'Double L_MHRA: {double_lmhra}')
70 | if not no_lmhra:
71 | self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction)
72 | if double_lmhra:
73 | self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction)
74 |
75 | # spatial
76 | self.attn = MultiheadAttention(d_model, n_head)
77 | self.ln_1 = LayerNorm(d_model)
78 | self.mlp = nn.Sequential(OrderedDict([
79 | ("c_fc", nn.Linear(d_model, d_model * 4)),
80 | ("gelu", QuickGELU()),
81 | ("c_proj", nn.Linear(d_model * 4, d_model))
82 | ]))
83 | self.ln_2 = LayerNorm(d_model)
84 | self.attn_mask = attn_mask
85 |
86 | def attention(self, x):
87 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
88 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
89 |
90 | def forward(self, x, T=8, use_checkpoint=False):
91 | # x: 1+HW, NT, C
92 | if not self.no_lmhra:
93 | # Local MHRA
94 | tmp_x = x[1:, :, :]
95 | L, NT, C = tmp_x.shape
96 | N = NT // T
97 | H = W = int(L ** 0.5)
98 | tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
99 | tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x))
100 | tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
101 | x = torch.cat([x[:1, :, :], tmp_x], dim=0)
102 | # MHSA
103 | if use_checkpoint:
104 | attn_out = checkpoint.checkpoint(self.attention, self.ln_1(x))
105 | x = x + self.drop_path(attn_out)
106 | else:
107 | x = x + self.drop_path(self.attention(self.ln_1(x)))
108 | # Local MHRA
109 | if not self.no_lmhra and self.double_lmhra:
110 | tmp_x = x[1:, :, :]
111 | tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
112 | tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x))
113 | tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
114 | x = torch.cat([x[:1, :, :], tmp_x], dim=0)
115 | # FFN
116 | if use_checkpoint:
117 | mlp_out = checkpoint.checkpoint(self.mlp, self.ln_2(x))
118 | x = x + self.drop_path(mlp_out)
119 | else:
120 | x = x + self.drop_path(self.mlp(self.ln_2(x)))
121 | return x
122 |
123 |
124 | class Extractor(nn.Module):
125 | def __init__(
126 | self, d_model, n_head, attn_mask=None,
127 | mlp_factor=4.0, dropout=0.0, drop_path=0.0,
128 | ):
129 | super().__init__()
130 |
131 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
132 | # logger.info(f'Drop path rate: {drop_path}')
133 | self.attn = nn.MultiheadAttention(d_model, n_head)
134 | self.ln_1 = nn.LayerNorm(d_model)
135 | d_mlp = round(mlp_factor * d_model)
136 | self.mlp = nn.Sequential(OrderedDict([
137 | ("c_fc", nn.Linear(d_model, d_mlp)),
138 | ("gelu", QuickGELU()),
139 | ("dropout", nn.Dropout(dropout)),
140 | ("c_proj", nn.Linear(d_mlp, d_model))
141 | ]))
142 | self.ln_2 = nn.LayerNorm(d_model)
143 | self.ln_3 = nn.LayerNorm(d_model)
144 | self.attn_mask = attn_mask
145 |
146 | # zero init
147 | nn.init.xavier_uniform_(self.attn.in_proj_weight)
148 | nn.init.constant_(self.attn.out_proj.weight, 0.)
149 | nn.init.constant_(self.attn.out_proj.bias, 0.)
150 | nn.init.xavier_uniform_(self.mlp[0].weight)
151 | nn.init.constant_(self.mlp[-1].weight, 0.)
152 | nn.init.constant_(self.mlp[-1].bias, 0.)
153 |
154 | def attention(self, x, y):
155 | d_model = self.ln_1.weight.size(0)
156 | q = (x @ self.attn.in_proj_weight[:d_model].T) + self.attn.in_proj_bias[:d_model]
157 |
158 | k = (y @ self.attn.in_proj_weight[d_model:-d_model].T) + self.attn.in_proj_bias[d_model:-d_model]
159 | v = (y @ self.attn.in_proj_weight[-d_model:].T) + self.attn.in_proj_bias[-d_model:]
160 | Tx, Ty, N = q.size(0), k.size(0), q.size(1)
161 | q = q.view(Tx, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
162 | k = k.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
163 | v = v.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
164 | aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim ** 0.5))
165 |
166 | aff = aff.softmax(dim=-1)
167 | out = aff @ v
168 | out = out.permute(2, 0, 1, 3).flatten(2)
169 | out = self.attn.out_proj(out)
170 | return out
171 |
172 | def forward(self, x, y):
173 | x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y)))
174 | x = x + self.drop_path(self.mlp(self.ln_2(x)))
175 | return x
176 |
177 |
178 | class Transformer(nn.Module):
179 | def __init__(
180 | self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0.,
181 | use_checkpoint=False, checkpoint_num=[0], t_size=8, dw_reduction=2,
182 | no_lmhra=False, double_lmhra=True,
183 | return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
184 | n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
185 | mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
186 | cls_dropout=0.5, num_classes=400,
187 | frozen=False,
188 | ):
189 | super().__init__()
190 | self.T = t_size
191 | self.return_list = return_list
192 | # backbone
193 | b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)]
194 | self.resblocks = nn.ModuleList([
195 | ResidualAttentionBlock(
196 | width, heads, attn_mask,
197 | drop_path=b_dpr[i],
198 | dw_reduction=dw_reduction,
199 | no_lmhra=no_lmhra,
200 | double_lmhra=double_lmhra,
201 | ) for i in range(layers)
202 | ])
203 | # checkpoint
204 | self.use_checkpoint = use_checkpoint
205 | self.checkpoint_num = checkpoint_num
206 | # logger.info(f'Use checkpoint: {self.use_checkpoint}')
207 | # logger.info(f'Checkpoint number: {self.checkpoint_num}')
208 |
209 | # global block
210 | assert n_layers == len(return_list)
211 | self.frozen = frozen
212 | # self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim))
213 | self.dpe = nn.ModuleList([
214 | nn.Conv3d(n_dim, n_dim, kernel_size=3, stride=1, padding=1, bias=True, groups=n_dim)
215 | for i in range(n_layers)
216 | ])
217 | for m in self.dpe:
218 | nn.init.constant_(m.bias, 0.)
219 | # dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
220 | # self.dec = nn.ModuleList([
221 | # Extractor(
222 | # n_dim, n_head, mlp_factor=mlp_factor,
223 | # dropout=mlp_dropout[i], drop_path=dpr[i],
224 | # ) for i in range(n_layers)
225 | # ])
226 | # projection
227 | # self.proj = nn.Sequential(
228 | # nn.LayerNorm(n_dim),
229 | # nn.Dropout(cls_dropout),
230 | # nn.Linear(n_dim, num_classes),
231 | # )
232 | # if not self.frozen:
233 | # self.balance = nn.Parameter(torch.zeros((n_dim)))
234 | # self.sigmoid = nn.Sigmoid()
235 |
236 | def forward(self, x):  # x: (1 + H*W, N*T, C), e.g. (577, 80, 1024)
237 | T_down = self.T
238 | L, NT, C = x.shape
239 | N = NT // T_down
240 | H = W = int((L - 1) ** 0.5)
241 | # cls_token = self.temporal_cls_token.repeat(1, N, 1)
242 | feature = torch.zeros(N, C, T_down, H, W).cuda()
243 | j = -1
244 | for i, resblock in enumerate(self.resblocks):
245 | if self.use_checkpoint and i < self.checkpoint_num[0]:
246 | x = resblock(x, self.T, use_checkpoint=True)
247 | else:
248 | x = resblock(x, T_down)
249 | if i in self.return_list:
250 | j += 1
251 | tmp_x = x.clone()
252 | tmp_x = tmp_x.view(L, N, T_down, C)
253 | # dpe
254 | _, tmp_feats = tmp_x[:1], tmp_x[1:]
255 | tmp_feats = tmp_feats.permute(1, 3, 2, 0).reshape(N, C, T_down, H, W)
256 | tmp_feats = self.dpe[j](tmp_feats.clone()).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()
257 | tmp_x[1:] = tmp_x[1:] + tmp_feats
258 | # global block
259 | # tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C
260 | feature += tmp_x[1:].permute(1, 3, 2, 0).reshape(N, C, T_down, H, W)
261 | return feature / 4  # mean over the 4 blocks collected via return_list (as configured in uniformerv2_b16)
262 | # cls_token = self.dec[j](cls_token, tmp_x)
263 | #
264 | # if self.frozen:
265 | # return self.proj(cls_token[0, :, :])
266 | # else:
267 | # weight = self.sigmoid(self.balance)
268 | # residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
269 | #
270 | # return torch.cat([self.proj((1 - weight) * cls_token[0, :, :] + weight * residual).reshape(N,710,1,1,1),torch.zeros(N,58,1,1,1).cuda()],dim=1)
271 |
272 |
273 | class VisionTransformer(nn.Module):
274 | def __init__(
275 | self,
276 | # backbone
277 | input_resolution, patch_size, width, layers, heads, output_dim, backbone_drop_path_rate=0.,
278 | use_checkpoint=False, checkpoint_num=[0], t_size=8, kernel_size=3, dw_reduction=1.5,
279 | temporal_downsample=True,
280 | no_lmhra=False, double_lmhra=True,
281 | # global block
282 | return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
283 | n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
284 | mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
285 | cls_dropout=0.5, num_classes=400,
286 | frozen=False
287 | ):
288 | super().__init__()
289 | self.input_resolution = input_resolution
290 | self.output_dim = output_dim
291 | padding = (kernel_size - 1) // 2
292 | if temporal_downsample:
293 | self.conv1 = nn.Conv3d(3, width, (kernel_size, patch_size, patch_size), (2, patch_size, patch_size),
294 | (padding, 0, 0), bias=False)
295 | t_size = t_size // 2
296 | else:
297 | self.conv1 = nn.Conv3d(3, width, (1, patch_size, patch_size), (1, patch_size, patch_size), (0, 0, 0),
298 | bias=False)
299 |
300 | scale = width ** -0.5
301 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
302 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
303 | self.ln_pre = LayerNorm(width)
304 |
305 | self.transformer = Transformer(
306 | width, layers, heads, dw_reduction=dw_reduction,
307 | backbone_drop_path_rate=backbone_drop_path_rate,
308 | use_checkpoint=use_checkpoint, checkpoint_num=checkpoint_num, t_size=t_size,
309 | no_lmhra=no_lmhra, double_lmhra=double_lmhra,
310 | return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head,
311 | mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout,
312 | cls_dropout=cls_dropout, num_classes=num_classes,
313 | frozen=frozen,
314 | )
315 |
316 | def forward(self, x):  # x: (N, C, T, H, W), e.g. (5, 3, 64, 336, 336)
317 | x = self.conv1(x) # shape = [*, width, grid, grid]
318 | N, C, T, H, W = x.shape
319 | x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
320 |
321 | x = torch.cat(
322 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
323 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
324 | x = x + self.positional_embedding.to(x.dtype)
325 | x = self.ln_pre(x)
326 |
327 | x = x.permute(1, 0, 2) # NLD -> LND #577,160,1024
328 | out = self.transformer(x)  # (N, width, T', H // patch_size, W // patch_size) feature volume
329 | return out
330 |
331 |
332 | def inflate_weight(weight_2d, time_dim, center=True):
333 | # logger.info(f'Init center: {center}')
334 | if center:
335 | weight_3d = torch.zeros(*weight_2d.shape)
336 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
337 | middle_idx = time_dim // 2
338 | weight_3d[:, :, middle_idx, :, :] = weight_2d
339 | else:
340 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
341 | weight_3d = weight_3d / time_dim
342 | return weight_3d
343 |
344 |
345 | def load_state_dict(model, state_dict):
346 | state_dict_3d = model.state_dict()
347 | new_state_dict = OrderedDict()
348 | for k in state_dict.keys():
349 | if k[9:] not in state_dict_3d:  # match keys after dropping a 9-character checkpoint prefix
350 | continue
351 | if state_dict[k].shape != state_dict_3d[k[9:]].shape:
352 | if len(state_dict_3d[k[9:]].shape) <= 2:
353 | new_state_dict[k[9:]] = state_dict[k]
354 | # logger.info(f'Ignore: {k}')
355 | continue
356 | # logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
357 | time_dim = state_dict_3d[k[9:]].shape[2]
358 | new_state_dict[k[9:]] = inflate_weight(state_dict[k], time_dim)
359 | else:
360 | new_state_dict[k[9:]] = state_dict[k]
361 | model.load_state_dict(new_state_dict, strict=False)
362 |
363 |
364 | def uniformerv2_b16(
365 | pretrained=True, use_checkpoint=False, checkpoint_num=[0],
366 | t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
367 | temporal_downsample=True,
368 | no_lmhra=False, double_lmhra=True,
369 | return_list=[8, 9, 10, 11],
370 | n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
371 | mlp_dropout=[0.5, 0.5, 0.5, 0.5],
372 | cls_dropout=0.5, num_classes=400,
373 | frozen=False,
374 | ):
375 | model = VisionTransformer(
376 | input_resolution=224,
377 | patch_size=16,
378 | width=768,
379 | layers=12,
380 | heads=12,
381 | output_dim=512,
382 | use_checkpoint=use_checkpoint,
383 | checkpoint_num=checkpoint_num,
384 | t_size=t_size,
385 | dw_reduction=dw_reduction,
386 | backbone_drop_path_rate=backbone_drop_path_rate,
387 | temporal_downsample=temporal_downsample,
388 | no_lmhra=no_lmhra,
389 | double_lmhra=double_lmhra,
390 | return_list=return_list,
391 | n_layers=n_layers,
392 | n_dim=n_dim,
393 | n_head=n_head,
394 | mlp_factor=mlp_factor,
395 | drop_path_rate=drop_path_rate,
396 | mlp_dropout=mlp_dropout,
397 | cls_dropout=cls_dropout,
398 | num_classes=num_classes,
399 | frozen=frozen,
400 | )
401 |
402 | if pretrained:
403 | # logger.info('load pretrained weights')
404 | state_dict = torch.load(os.path.join(pretrained, 'k400+k710_uniformerv2_b16_8x224.pth'), map_location='cpu')
405 | load_state_dict(model, state_dict)
406 | return model
407 |
--------------------------------------------------------------------------------
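
A hedged smoke test for the backbone above: passing pretrained=False skips checkpoint loading (when truthy, pretrained is treated as a directory holding k400+k710_uniformerv2_b16_8x224.pth), and a CUDA device is required because Transformer.forward allocates its accumulator with .cuda(). The expected shape follows from the patch-16 stem and the temporal downsampling of t_size=16 to 8; the import path is assumed from the file location.

import torch
from vebench.models.backbone.uniformer_backbone import uniformerv2_b16

model = uniformerv2_b16(pretrained=False, t_size=16).cuda().eval()
clip = torch.randn(1, 3, 16, 224, 224, device="cuda")   # (N, C, T, H, W)
with torch.no_grad():
    feat = model(clip)
print(feat.shape)   # expected: torch.Size([1, 768, 8, 14, 14])
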
/vebench/blip_models/vit.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | * Based on timm code base
8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from functools import partial
15 |
16 | from timm.models.vision_transformer import _cfg, PatchEmbed
17 | from timm.models.registry import register_model
18 | from timm.models.layers import trunc_normal_, DropPath
19 | from timm.models.helpers import named_apply, adapt_input_conv
20 | from timm.models.vision_transformer import Attention as TemporalAttention
21 |
22 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
23 |
24 | class BottleNeckAdapter(nn.Module):
25 | def __init__(self,in_feature,downsample_rate):
26 | super().__init__()
27 | hidden_state_1=in_feature//downsample_rate
28 | hidden_state_2=hidden_state_1//downsample_rate
29 | self.downsample_1=nn.Linear(in_feature,hidden_state_1)
30 | self.downsample_2=nn.Linear(hidden_state_1,hidden_state_2)
31 | #self.gelu=nn.functional.gelu()
32 | self.upsample_1 = nn.Linear(hidden_state_2, hidden_state_1)
33 | self.upsample_2=nn.Linear(hidden_state_1,in_feature)
34 |
35 | def forward(self,x):
36 | y=self.downsample_1(x)
37 | y = self.downsample_2(y)
38 | y=nn.functional.gelu(y)
39 | y=self.upsample_1(y)
40 | y = self.upsample_2(y)
41 | return x+y
42 |
43 | from einops import rearrange
44 |
45 |
46 | class TreeDConvAdapter(nn.Module):
47 | def __init__(
48 | self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
49 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
50 | super().__init__()
51 | self.norm1 = norm_layer(dim)
52 | if ws is None:
53 | self.attn = TemporalAttention(dim, num_heads,attn_drop=attn_drop,proj_drop=drop)
54 | # elif ws == 1:
55 | # self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio)
56 | # else:
57 | # self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws)
58 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
59 | self.norm2 = norm_layer(dim)
60 | mlp_hidden_dim = int(dim * mlp_ratio)
61 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
62 | self.temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1,groups=dim)
63 | self.apply(self._init_weights)
64 |
65 | def _init_weights(self, m):
66 | if hasattr(m,"weight")and m.weight is not None:
67 | trunc_normal_(m.weight, mean=0.0, std=0.01)
68 | if hasattr(m,"bias") and m.bias is not None:
69 | nn.init.constant_(m.bias, 0)
70 |
71 | def forward(self, x,B):
72 | # x: (B*T, h*w, C)
73 | origin_x=x
74 | x = x + self.drop_path(self.attn(self.norm1(x)))
75 | # spatial
76 | x = self.mlp(self.norm2(x))
77 | #
78 | # temporal
79 | x = rearrange(x, '(b t) l c -> (b l) c t', b=B)
80 | x = self.temporal_conv(x)
81 | x = rearrange(x, '(b l) c t -> (b t) l c', b=B)
82 | #
83 | # # output
84 | x = origin_x + self.drop_path(x)
85 | return x
86 |
87 | class Mlp(nn.Module):
88 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
89 | """
90 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
91 | super().__init__()
92 | out_features = out_features or in_features
93 | hidden_features = hidden_features or in_features
94 | self.fc1 = nn.Linear(in_features, hidden_features)
95 | self.act = act_layer()
96 | self.fc2 = nn.Linear(hidden_features, out_features)
97 | self.drop = nn.Dropout(drop)
98 |
99 | def forward(self, x):
100 | x = self.fc1(x)
101 | x = self.act(x)
102 | x = self.drop(x)
103 | x = self.fc2(x)
104 | x = self.drop(x)
105 | return x
106 |
107 |
108 | class Attention(nn.Module):
109 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
110 | super().__init__()
111 | self.num_heads = num_heads
112 | head_dim = dim // num_heads
113 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
114 | self.scale = qk_scale or head_dim ** -0.5
115 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
116 | self.attn_drop = nn.Dropout(attn_drop)
117 | self.proj = nn.Linear(dim, dim)
118 | self.proj_drop = nn.Dropout(proj_drop)
119 | self.attn_gradients = None
120 | self.attention_map = None
121 |
122 | def save_attn_gradients(self, attn_gradients):
123 | self.attn_gradients = attn_gradients
124 |
125 | def get_attn_gradients(self):
126 | return self.attn_gradients
127 |
128 | def save_attention_map(self, attention_map):
129 | self.attention_map = attention_map
130 |
131 | def get_attention_map(self):
132 | return self.attention_map
133 |
134 | def forward(self, x, register_hook=False):
135 | B, N, C = x.shape
136 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
137 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
138 |
139 | attn = (q @ k.transpose(-2, -1)) * self.scale
140 | attn = attn.softmax(dim=-1)
141 | attn = self.attn_drop(attn)
142 |
143 | if register_hook:
144 | self.save_attention_map(attn)
145 | attn.register_hook(self.save_attn_gradients)
146 |
147 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
148 | x = self.proj(x)
149 | x = self.proj_drop(x)
150 | return x
151 |
152 |
153 | class Block(nn.Module):
154 |
155 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
156 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False,depth=-1):
157 | super().__init__()
158 | self.norm1 = norm_layer(dim)
159 | self.attn = Attention(
160 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
161 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163 | self.norm2 = norm_layer(dim)
164 | mlp_hidden_dim = int(dim * mlp_ratio)
165 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166 | self.depth=depth
167 | # if self.depth in [17,19,21,22,23]:
168 | # self.adapter = TreeDConvAdapter(dim=dim,num_heads=8,mlp_ratio=0.5)
169 |
170 | if use_grad_checkpointing:
171 | self.attn = checkpoint_wrapper(self.attn)
172 | self.mlp = checkpoint_wrapper(self.mlp)
173 |
174 | def forward(self, x, number,B=8,register_hook=False):
175 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
176 | x = x + self.drop_path(self.mlp(self.norm2(x)))
177 | # if self.depth in [17,19,21,22,23]:
178 | # x=self.adapter(x,B)
179 | return x
180 |
181 |
182 | class VisionTransformer(nn.Module):
183 | """ Vision Transformer
184 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
185 | https://arxiv.org/abs/2010.11929
186 | """
187 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
188 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
189 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
190 | use_grad_checkpointing=False, ckpt_layer=0):
191 | """
192 | Args:
193 | img_size (int, tuple): input image size
194 | patch_size (int, tuple): patch size
195 | in_chans (int): number of input channels
196 | num_classes (int): number of classes for classification head
197 | embed_dim (int): embedding dimension
198 | depth (int): depth of transformer
199 | num_heads (int): number of attention heads
200 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
201 | qkv_bias (bool): enable bias for qkv if True
202 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
203 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
204 | drop_rate (float): dropout rate
205 | attn_drop_rate (float): attention dropout rate
206 | drop_path_rate (float): stochastic depth rate
207 | norm_layer: (nn.Module): normalization layer
208 | """
209 | super().__init__()
210 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
211 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
212 |
213 | self.patch_embed = PatchEmbed(
214 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
215 |
216 | num_patches = self.patch_embed.num_patches
217 |
218 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
219 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
220 | self.pos_drop = nn.Dropout(p=drop_rate)
221 |
222 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
223 | self.blocks = nn.ModuleList([
224 | Block(
225 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
226 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
227 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer),
228 | depth=i
229 | )
230 | for i in range(depth)])
231 | self.norm = norm_layer(embed_dim)
232 |
233 | trunc_normal_(self.pos_embed, std=.02)
234 | trunc_normal_(self.cls_token, std=.02)
235 | self.apply(self._init_weights)
236 |
237 | def _init_weights(self, m):
238 | if isinstance(m, nn.Linear):
239 | trunc_normal_(m.weight, std=.02)
240 | if isinstance(m, nn.Linear) and m.bias is not None:
241 | nn.init.constant_(m.bias, 0)
242 | elif isinstance(m, nn.LayerNorm):
243 | nn.init.constant_(m.bias, 0)
244 | nn.init.constant_(m.weight, 1.0)
245 |
246 | @torch.jit.ignore
247 | def no_weight_decay(self):
248 | return {'pos_embed', 'cls_token'}
249 |
250 | def forward(self, video, register_blk=-1):  # video: (B, C, T, H, W); frames are folded into the batch
251 | B,C,L,W,H = video.shape
252 | temporal = []
253 | video=self.patch_embed(video.reshape(-1,C,W,H))
254 | cls_tokens = self.cls_token.expand(B*L, -1, -1)
255 | video=torch.cat((cls_tokens,video), dim=1)
256 |
257 | video = video + self.pos_embed[:, :video.size(1), :]
258 | x=self.pos_drop(video)
259 | #x=video.reshape(B,L,video.shape[-2],video.shape[-1])
260 | for i,blk in enumerate(self.blocks):
261 | x = blk(x, i,B, register_blk==i)
262 | x = self.norm(x).reshape(B,L,x.shape[-2],x.shape[-1])
263 | # video_mean=rearrange(x, '(b t) l c -> b t l c', b=B).mean(1)
264 | # frame=rearrange(x, '(b t) l c -> b t l c', b=B)
265 | # for i in range(video.shape[2]):
266 | # image = video[:, :, i, ...]
267 | # image_embeds = self.visual_encoder(image)
268 | # temporal.append(image_embeds.unsqueeze(1))
269 | # temporal = torch.cat(temporal, dim=1)
270 |
271 |
272 |
273 | # x = self.patch_embed(x)
274 | #
275 | # cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
276 | # x = torch.cat((cls_tokens, x), dim=1)
277 | #
278 | # x = x + self.pos_embed[:,:x.size(1),:]
279 | # x = self.pos_drop(x)
280 | #
281 | # for i,blk in enumerate(self.blocks):
282 | # x = blk(x, register_blk==i)
283 | # x = self.norm(x)
284 |
285 | return x
286 |
287 | @torch.jit.ignore()
288 | def load_pretrained(self, checkpoint_path, prefix=''):
289 | _load_weights(self, checkpoint_path, prefix)
290 |
291 |
292 | @torch.no_grad()
293 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
294 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
295 | """
296 | import numpy as np
297 |
298 | def _n2p(w, t=True):
299 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
300 | w = w.flatten()
301 | if t:
302 | if w.ndim == 4:
303 | w = w.transpose([3, 2, 0, 1])
304 | elif w.ndim == 3:
305 | w = w.transpose([2, 0, 1])
306 | elif w.ndim == 2:
307 | w = w.transpose([1, 0])
308 | return torch.from_numpy(w)
309 |
310 | w = np.load(checkpoint_path)
311 | if not prefix and 'opt/target/embedding/kernel' in w:
312 | prefix = 'opt/target/'
313 |
314 | if hasattr(model.patch_embed, 'backbone'):
315 | # hybrid
316 | backbone = model.patch_embed.backbone
317 | stem_only = not hasattr(backbone, 'stem')
318 | stem = backbone if stem_only else backbone.stem
319 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
320 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
321 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
322 | if not stem_only:
323 | for i, stage in enumerate(backbone.stages):
324 | for j, block in enumerate(stage.blocks):
325 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
326 | for r in range(3):
327 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
328 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
329 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
330 | if block.downsample is not None:
331 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
332 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
333 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
334 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
335 | else:
336 | embed_conv_w = adapt_input_conv(
337 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
338 | model.patch_embed.proj.weight.copy_(embed_conv_w)
339 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
340 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
341 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
342 | if pos_embed_w.shape != model.pos_embed.shape:
343 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
344 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
345 | model.pos_embed.copy_(pos_embed_w)
346 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
347 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
348 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
349 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
350 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
351 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
352 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
353 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
354 | for i, block in enumerate(model.blocks.children()):
355 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
356 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
357 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
358 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
359 | block.attn.qkv.weight.copy_(torch.cat([
360 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
361 | block.attn.qkv.bias.copy_(torch.cat([
362 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
363 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
364 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
365 | for r in range(2):
366 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
367 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
368 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
369 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
370 |
371 |
372 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
373 | # interpolate position embedding
374 | embedding_size = pos_embed_checkpoint.shape[-1]
375 | num_patches = visual_encoder.patch_embed.num_patches
376 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
377 | # height (== width) for the checkpoint position embedding
378 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
379 | # height (== width) for the new position embedding
380 | new_size = int(num_patches ** 0.5)
381 |
382 | if orig_size!=new_size:
383 | # class_token and dist_token are kept unchanged
384 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
385 | # only the position tokens are interpolated
386 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
387 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
388 | pos_tokens = torch.nn.functional.interpolate(
389 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
390 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
391 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
392 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
393 |
394 | return new_pos_embed
395 | else:
396 | return pos_embed_checkpoint
--------------------------------------------------------------------------------
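
A shape sketch for the video-adapted BLIP ViT above: frames are folded into the batch before patch embedding, so a (B, C, T, H, W) clip comes back as per-frame token features of shape (B, T, num_patches + 1, embed_dim). The sizes and import path below are illustrative, and importing the module also requires fairscale for checkpoint_wrapper.

import torch
from vebench.blip_models.vit import VisionTransformer

vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12).eval()
clip = torch.randn(2, 3, 4, 224, 224)   # (B, C, T, H, W)
with torch.no_grad():
    tokens = vit(clip)
print(tokens.shape)   # expected: torch.Size([2, 4, 197, 768]) -> cls + 196 patch tokens per frame
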
/vebench/models/backbone/conv_backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import trunc_normal_, DropPath
5 | import os
6 |
7 |
8 | class GRN(nn.Module):
9 | """ GRN (Global Response Normalization) layer
10 | """
11 | def __init__(self, dim):
12 | super().__init__()
13 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
14 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
15 |
16 | def forward(self, x):
17 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
18 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
19 | return self.gamma * (x * Nx) + self.beta + x
20 |
21 | class Block(nn.Module):
22 | r""" ConvNeXt Block. There are two equivalent implementations:
23 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
24 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
25 | We use (2) as we find it slightly faster in PyTorch
26 |
27 | Args:
28 | dim (int): Number of input channels.
29 | drop_path (float): Stochastic depth rate. Default: 0.0
30 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
31 | """
32 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
33 | super().__init__()
34 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
35 | self.norm = LayerNorm(dim, eps=1e-6)
36 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
37 | self.act = nn.GELU()
38 | self.pwconv2 = nn.Linear(4 * dim, dim)
39 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
40 | requires_grad=True) if layer_scale_init_value > 0 else None
41 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
42 |
43 | def forward(self, x):
44 | input = x
45 | x = self.dwconv(x)
46 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
47 | x = self.norm(x)
48 | x = self.pwconv1(x)
49 | x = self.act(x)
50 | x = self.pwconv2(x)
51 | if self.gamma is not None:
52 | x = self.gamma * x
53 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
54 |
55 | x = input + self.drop_path(x)
56 | return x
57 |
58 | class ConvNeXt(nn.Module):
59 | r""" ConvNeXt
60 | A PyTorch impl of : `A ConvNet for the 2020s` -
61 | https://arxiv.org/pdf/2201.03545.pdf
62 | Args:
63 | in_chans (int): Number of input image channels. Default: 3
64 | num_classes (int): Number of classes for classification head. Default: 1000
65 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
66 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
67 | drop_path_rate (float): Stochastic depth rate. Default: 0.
68 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
69 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
70 | """
71 | def __init__(self, in_chans=3, num_classes=1000,
72 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
73 | layer_scale_init_value=1e-6, head_init_scale=1.,
74 | ):
75 | super().__init__()
76 |
77 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
78 | stem = nn.Sequential(
79 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
80 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
81 | )
82 | self.downsample_layers.append(stem)
83 | for i in range(3):
84 | downsample_layer = nn.Sequential(
85 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
86 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
87 | )
88 | self.downsample_layers.append(downsample_layer)
89 |
90 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
91 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
92 | cur = 0
93 | for i in range(4):
94 | stage = nn.Sequential(
95 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
96 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
97 | )
98 | self.stages.append(stage)
99 | cur += depths[i]
100 |
101 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
102 | self.head = nn.Linear(dims[-1], num_classes)
103 |
104 | self.apply(self._init_weights)
105 | self.head.weight.data.mul_(head_init_scale)
106 | self.head.bias.data.mul_(head_init_scale)
107 |
108 | def _init_weights(self, m):
109 | if isinstance(m, (nn.Conv2d, nn.Linear)):
110 | trunc_normal_(m.weight, std=.02)
111 | nn.init.constant_(m.bias, 0)
112 |
113 | def forward_features(self, x):
114 | for i in range(4):
115 | x = self.downsample_layers[i](x)
116 | x = self.stages[i](x)
117 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
118 |
119 | def forward(self, x):
120 | x = self.forward_features(x)
121 | x = self.head(x)
122 | return x
123 |
124 | class LayerNorm(nn.Module):
125 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
126 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
127 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
128 | with shape (batch_size, channels, height, width).
129 | """
130 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
131 | super().__init__()
132 | self.weight = nn.Parameter(torch.ones(normalized_shape))
133 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
134 | self.eps = eps
135 | self.data_format = data_format
136 | if self.data_format not in ["channels_last", "channels_first"]:
137 | raise NotImplementedError
138 | self.normalized_shape = (normalized_shape, )
139 |
140 | def forward(self, x):
141 | if self.data_format == "channels_last":
142 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
143 | elif self.data_format == "channels_first":
144 | u = x.mean(1, keepdim=True)
145 | s = (x - u).pow(2).mean(1, keepdim=True)
146 | x = (x - u) / torch.sqrt(s + self.eps)
147 | if len(x.shape) == 4:
148 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
149 | elif len(x.shape) == 5:
150 | x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
151 | return x
152 |
153 |
154 | class Block3D(nn.Module):
155 | r""" ConvNeXt Block. There are two equivalent implementations:
156 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
157 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
158 | We use (2) as we find it slightly faster in PyTorch
159 |
160 | Args:
161 | dim (int): Number of input channels.
162 | drop_path (float): Stochastic depth rate. Default: 0.0
163 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
164 | """
165 | def __init__(self, dim, drop_path=0., inflate_len=3, layer_scale_init_value=1e-6):
166 | super().__init__()
167 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv
168 | self.norm = LayerNorm(dim, eps=1e-6)
169 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
170 | self.act = nn.GELU()
171 | self.pwconv2 = nn.Linear(4 * dim, dim)
172 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
173 | requires_grad=True) if layer_scale_init_value > 0 else None
174 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
175 |
176 | def forward(self, x):
177 | input = x
178 | x = self.dwconv(x)
179 | x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C)
180 | x = self.norm(x)
181 | x = self.pwconv1(x)
182 | x = self.act(x)
183 | x = self.pwconv2(x)
184 | if self.gamma is not None:
185 | x = self.gamma * x
186 | x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W)
187 |
188 | x = input + self.drop_path(x)
189 | return x
190 |
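def _block3d_shape_sketch():
    """Illustrative sketch (not part of the original API): Block3D is a
    residual block, so it preserves the (N, C, T, H, W) shape of its input.
    The sizes below are arbitrary."""
    blk = Block3D(dim=96, inflate_len=3)
    x = torch.randn(1, 96, 4, 14, 14)
    assert blk(x).shape == x.shape
    return blk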
191 | class BlockV2(nn.Module):
192 | """ ConvNeXtV2 Block.
193 |
194 | Args:
195 | dim (int): Number of input channels.
196 | drop_path (float): Stochastic depth rate. Default: 0.0
197 | """
198 | def __init__(self, dim, drop_path=0.):
199 | super().__init__()
200 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
201 | self.norm = LayerNorm(dim, eps=1e-6)
202 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
203 | self.act = nn.GELU()
204 | self.grn = GRN(4 * dim)
205 | self.pwconv2 = nn.Linear(4 * dim, dim)
206 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
207 |
208 | def forward(self, x):
209 | input = x
210 | x = self.dwconv(x)
211 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
212 | x = self.norm(x)
213 | x = self.pwconv1(x)
214 | x = self.act(x)
215 | x = self.grn(x)
216 | x = self.pwconv2(x)
217 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
218 |
219 | x = input + self.drop_path(x)
220 | return x
221 |
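# For reference, the GRN (global response normalization) layer used by the V2
# blocks, as described in the ConvNeXt V2 paper, computes a per-channel gain
# from the spatial L2 norm. A sketch of the idea in channels_last layout (the
# actual GRN module used above is defined elsewhere in this file and may
# differ in details):
#
#     gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)   # (N, 1, 1, C): spatial L2 norm per channel
#     nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)     # divisive normalization across channels
#     out = gamma * (x * nx) + beta + x                    # learnable affine plus residual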
222 | class BlockV23D(nn.Module):
223 | """ ConvNeXtV2 Block inflated to 3D for video inputs, operating on (N, C, T, H, W).
224 | Args:
225 | dim (int): Number of input channels.
226 | drop_path (float): Stochastic depth rate. Default: 0.0
227 | inflate_len (int): Temporal kernel size of the depthwise conv. Default: 3
228 | """
229 | def __init__(self, dim, drop_path=0., inflate_len=3,):
230 | super().__init__()
231 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv
232 | self.norm = LayerNorm(dim, eps=1e-6)
233 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
234 | self.act = nn.GELU()
235 | self.grn = GRN(4 * dim)
236 | self.pwconv2 = nn.Linear(4 * dim, dim)
237 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
238 |
239 | def forward(self, x):
240 | input = x
241 | x = self.dwconv(x)
242 | x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C)
243 | x = self.norm(x)
244 | x = self.pwconv1(x)
245 | x = self.act(x)
246 | x = self.grn(x)
247 | x = self.pwconv2(x)
248 | x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W)
249 |
250 | x = input + self.drop_path(x)
251 | return x
252 |
253 | class ConvNeXtV2(nn.Module):
254 | """ ConvNeXt V2
255 |
256 | Args:
257 | in_chans (int): Number of input image channels. Default: 3
258 | num_classes (int): Number of classes for classification head. Default: 1000
259 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
260 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
261 | drop_path_rate (float): Stochastic depth rate. Default: 0.
262 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
263 | """
264 | def __init__(self, in_chans=3, num_classes=1000,
265 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
266 | drop_path_rate=0., head_init_scale=1.
267 | ):
268 | super().__init__()
269 | self.depths = depths
270 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
271 | stem = nn.Sequential(
272 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
273 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
274 | )
275 | self.downsample_layers.append(stem)
276 | for i in range(3):
277 | downsample_layer = nn.Sequential(
278 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
279 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
280 | )
281 | self.downsample_layers.append(downsample_layer)
282 |
283 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
284 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
285 | cur = 0
286 | for i in range(4):
287 | stage = nn.Sequential(
288 | *[BlockV2(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
289 | )
290 | self.stages.append(stage)
291 | cur += depths[i]
292 |
293 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
294 | self.head = nn.Linear(dims[-1], num_classes)
295 |
296 | self.apply(self._init_weights)
297 | self.head.weight.data.mul_(head_init_scale)
298 | self.head.bias.data.mul_(head_init_scale)
299 |
300 | def _init_weights(self, m):
301 | if isinstance(m, (nn.Conv2d, nn.Linear)):
302 | trunc_normal_(m.weight, std=.02)
303 | nn.init.constant_(m.bias, 0)
304 |
305 | def forward_features(self, x):
306 | for i in range(4):
307 | x = self.downsample_layers[i](x)
308 | x = self.stages[i](x)
309 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
310 |
311 | def forward(self, x):
312 | x = self.forward_features(x)
313 | x = self.head(x)
314 | return x
315 |
316 | def convnextv2_atto(**kwargs):
317 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
318 | return model
319 |
320 | def convnextv2_femto(**kwargs):
321 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
322 | return model
323 |
324 | def convnext_pico(**kwargs):
325 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
326 | return model
327 |
328 | def convnextv2_nano(**kwargs):
329 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
330 | return model
331 |
332 | def convnextv2_tiny(**kwargs):
333 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
334 | return model
335 |
336 | def convnextv2_base(**kwargs):
337 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
338 | return model
339 |
340 | def convnextv2_large(**kwargs):
341 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
342 | return model
343 |
344 | def convnextv2_huge(**kwargs):
345 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
346 | return model
347 |
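def _convnextv2_factory_sketch():
    """Illustrative sketch (not part of the original API): the factory
    functions above differ only in depths/dims. The input resolution and
    class count below are arbitrary assumptions."""
    model = convnextv2_tiny(num_classes=400)
    logits = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 400)
    return logits
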
348 | class ConvNeXt3D(nn.Module):
349 | r""" ConvNeXt inflated to 3D, used here as a headless video feature extractor.
350 | Based on the PyTorch impl of `A ConvNet for the 2020s` -
351 | https://arxiv.org/pdf/2201.03545.pdf
352 | Args:
353 | in_chans (int): Number of input image channels. Default: 3
354 | inflate_strategy (str): Per-block temporal kernel sizes, cycled within each stage. Default: '131'
355 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
356 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
357 | drop_path_rate (float): Stochastic depth rate. Default: 0.
358 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
359 | num_classes, head_init_scale: accepted for interface compatibility; no classification head is built.
360 | """
361 | def __init__(self, in_chans=3, num_classes=1000,
362 | inflate_strategy='131',
363 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
364 | layer_scale_init_value=1e-6, head_init_scale=1.,
365 | ):
366 | super().__init__()
367 |
368 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
369 | stem = nn.Sequential(
370 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)),
371 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
372 | )
373 | self.downsample_layers.append(stem)
374 | for i in range(3):
375 | downsample_layer = nn.Sequential(
376 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
377 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)),
378 | )
379 | self.downsample_layers.append(downsample_layer)
380 |
381 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
382 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
383 | cur = 0
384 | for i in range(4):
385 | stage = nn.Sequential(
386 | *[Block3D(dim=dims[i], inflate_len=int(inflate_strategy[j%len(inflate_strategy)]),
387 | drop_path=dp_rates[cur + j],
388 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
389 | )
390 | self.stages.append(stage)
391 | cur += depths[i]
392 |
393 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
394 |
395 | self.apply(self._init_weights)
396 |
397 | def inflate_weights(self, s_state_dict):
398 | # Copy matching weights from a 2D ConvNeXt state dict; kernels whose shapes
399 | # differ are inflated along time (repeated t times and divided by t).
400 | t_state_dict = self.state_dict()
401 | for key in t_state_dict.keys():
402 | if key not in s_state_dict:
403 | continue
404 | if t_state_dict[key].shape != s_state_dict[key].shape:
405 | t = t_state_dict[key].shape[2]
406 | s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1, 1, t, 1, 1) / t
407 | self.load_state_dict(s_state_dict, strict=False)
408 |
409 | def _init_weights(self, m):
410 | if isinstance(m, (nn.Conv3d, nn.Linear)):
411 | trunc_normal_(m.weight, std=.02)
412 | nn.init.constant_(m.bias, 0)
413 |
414 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1):
415 | if multi:
416 | xs = []
417 | for i in range(4):
418 | x = self.downsample_layers[i](x)
419 | x = self.stages[i](x)
420 | if multi:
421 | xs.append(x)
422 | if return_spatial:
423 | if multi:
424 | shape = xs[-1].shape[2:]
425 | return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1)
426 | elif layer > -1:
427 | return xs[layer]
428 | else:
429 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
430 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C)
431 |
432 | def forward(self, x, multi=False, layer=-1):
433 | x = self.forward_features(x, True, multi=multi, layer=layer)
434 | return x
435 |
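def _inflate_kernel_sketch():
    """Illustrative sketch (not part of the original API) of the weight
    inflation performed by ConvNeXt3D.inflate_weights above: a 2D kernel of
    shape (out, in, kH, kW) is repeated along a new temporal axis and divided
    by its length, so the inflated 3D conv initially reproduces the 2D
    response averaged over time."""
    w2d = torch.randn(96, 3, 4, 4)                    # e.g. a 2D stem kernel
    t = 2                                             # temporal kernel size of the 3D stem above
    w3d = w2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t  # -> (96, 3, 2, 4, 4)
    return w3d.shape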
436 |
437 | class ConvNeXtV23D(nn.Module):
438 | """ ConvNeXt V2 inflated to 3D for video inputs.
439 | Args:
440 | in_chans (int): Number of input image channels. Default: 3
441 | inflate_strategy (str): Per-block temporal kernel sizes, cycled within each stage. Default: '131'
442 | num_classes (int): Number of classes for classification head. Default: 1000
443 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
444 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
445 | drop_path_rate (float): Stochastic depth rate. Default: 0.
446 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
447 | """
448 | def __init__(self, in_chans=3, num_classes=1000,
449 | inflate_strategy='131',
450 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
451 | drop_path_rate=0., head_init_scale=1.
452 | ):
453 | super().__init__()
454 | self.depths = depths
455 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
456 | stem = nn.Sequential(
457 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)),
458 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
459 | )
460 | self.downsample_layers.append(stem)
461 | for i in range(3):
462 | downsample_layer = nn.Sequential(
463 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
464 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)),
465 | )
466 | self.downsample_layers.append(downsample_layer)
467 |
468 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
469 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
470 | cur = 0
471 | for i in range(4):
472 | stage = nn.Sequential(
473 | *[BlockV23D(dim=dims[i], drop_path=dp_rates[cur + j],
474 | inflate_len=int(inflate_strategy[j%len(inflate_strategy)]),
475 | ) for j in range(depths[i])]
476 | )
477 | self.stages.append(stage)
478 | cur += depths[i]
479 |
480 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
481 | self.head = nn.Linear(dims[-1], num_classes)
482 |
483 | self.apply(self._init_weights)
484 | self.head.weight.data.mul_(head_init_scale)
485 | self.head.bias.data.mul_(head_init_scale)
486 |
487 | def inflate_weights(self, pretrained_path):
488 | # Load a 2D ConvNeXt V2 checkpoint (state dict stored under the "model" key)
489 | # and copy matching weights; kernels whose shapes differ are inflated along
490 | # time (repeated t times and divided by t).
491 | t_state_dict = self.state_dict()
492 | s_state_dict = torch.load(pretrained_path, map_location="cpu")["model"]
493 | for key in t_state_dict.keys():
494 | if key not in s_state_dict:
495 | continue
496 | if t_state_dict[key].shape != s_state_dict[key].shape:
497 | t = t_state_dict[key].shape[2]
498 | s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1, 1, t, 1, 1) / t
499 | self.load_state_dict(s_state_dict, strict=False)
500 |
501 | def _init_weights(self, m):
502 | if isinstance(m, (nn.Conv3d, nn.Linear)):
503 | trunc_normal_(m.weight, std=.02)
504 | nn.init.constant_(m.bias, 0)
505 |
506 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1):
507 | if multi:
508 | xs = []
509 | for i in range(4):
510 | x = self.downsample_layers[i](x)
511 | x = self.stages[i](x)
512 | if multi:
513 | xs.append(x)
514 | if return_spatial:
515 | if multi:
516 | shape = xs[-1].shape[2:]
517 | return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1)
518 | elif layer > -1:
519 | return xs[layer]
520 | else:
521 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
522 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C)
523 |
524 | def forward(self, x, multi=False, layer=-1):
525 | x = self.forward_features(x, True, multi=multi, layer=layer)
526 | return x
527 |
528 |
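# Illustrative usage sketch for the V2 3D backbone above (the checkpoint path
# is a placeholder; the state dict is expected under the "model" key, matching
# inflate_weights above):
#
#     model = ConvNeXtV23D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
#     model.inflate_weights("/path/to/convnextv2_2d_checkpoint.pt")
#     feats = model(torch.randn(1, 3, 16, 224, 224))   # spatial feature map (forward uses return_spatial=True)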
529 |
530 | def convnext_3d_tiny(pretrained, in_22k=False, **kwargs):
531 | # `pretrained` is a directory containing 'convnext_tiny_1k_224_ema.pth'; `in_22k` is accepted but unused here.
532 | model = ConvNeXt3D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
533 | checkpoint = torch.load(os.path.join(pretrained, 'convnext_tiny_1k_224_ema.pth'), map_location="cpu")
534 | model.inflate_weights(checkpoint["model"])
535 |
536 | return model
537 |
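# Illustrative usage sketch (the weights directory is a placeholder):
#
#     backbone = convnext_3d_tiny(pretrained="/path/to/weights")   # dir with convnext_tiny_1k_224_ema.pth
#     clip = torch.randn(1, 3, 16, 224, 224)                       # (N, C, T, H, W)
#     feats = backbone(clip)                                       # with default strides: (1, 768, 8, 7, 7)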
538 |
--------------------------------------------------------------------------------