├── checkpoints
│   └── .keep
├── third_party
│   ├── gmflow
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── position.py
│   │   │   ├── utils.py
│   │   │   ├── geometry.py
│   │   │   ├── trident_conv.py
│   │   │   ├── matching.py
│   │   │   ├── backbone.py
│   │   │   ├── gmflow.py
│   │   │   └── transformer.py
│   │   ├── config.yaml
│   │   └── utils.py
│   ├── dover
│   │   ├── __init__.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── head.py
│   │   │   ├── evaluator.py
│   │   │   └── conv_backbone.py
│   │   ├── config.yaml
│   │   └── dataset.py
│   └── viclip
│       ├── __init__.py
│       ├── image_transform.py
│       ├── simple_tokenizer.py
│       ├── viclip.py
│       ├── viclip_text.py
│       └── viclip_vision.py
├── asset
│   └── taxonomy-repo.png
├── requirements.txt
├── doc
│   ├── install.sh
│   ├── leaderboard.md
│   └── README.md
├── LICENSE
├── v2vbench
│   ├── dover_score.py
│   ├── clip_consistency.py
│   ├── dino_consistency.py
│   ├── dover_utils.py
│   ├── aesthetic_score.py
│   ├── clip_text_alignment.py
│   ├── pick_score.py
│   ├── viclip_text_alignment.py
│   ├── dino_image_alignment.py
│   ├── utils.py
│   ├── base_evaluator.py
│   ├── motion_alignment.py
│   └── __init__.py
└── .gitignore
/checkpoints/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/third_party/gmflow/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .gmflow import GMFlow
--------------------------------------------------------------------------------
/asset/taxonomy-repo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenhao728/awesome-diffusion-v2v/HEAD/asset/taxonomy-repo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | decord
2 | torch~=1.13
3 | torchvision
4 | transformers>=4.32.1
5 | scipy
6 | numpy
7 | tqdm
8 | timm
9 | einops
10 | pandas
11 | simple-aesthetics-predictor
12 | omegaconf
--------------------------------------------------------------------------------
/third_party/dover/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 20:47:36
5 | @Desc :
6 | @Ref : https://github.com/VQAssessment/DOVER
7 | '''
8 |
--------------------------------------------------------------------------------
/third_party/gmflow/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | num_scales: 1
3 | upsample_factor: 8
4 | feature_channels: 128
5 | attention_type: swin
6 | num_transformer_layers: 6
7 | ffn_dim_expansion: 4
8 | num_head: 1
9 |
10 | data:
11 | dims:
12 | - 3
13 | - 512
14 | - 512
--------------------------------------------------------------------------------
/third_party/viclip/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 20:39:45
5 | @Desc :
6 | @Ref :
7 | https://github.com/OpenGVLab/InternVideo/tree/main
8 | https://github.com/Vchitect/VBench/tree/master
9 | '''
10 | from .image_transform import clip_image_transform
11 | from .simple_tokenizer import SimpleTokenizer
12 | from .viclip import ViCLIP
13 |
--------------------------------------------------------------------------------
/third_party/viclip/image_transform.py:
--------------------------------------------------------------------------------
1 |
2 | from torchvision.transforms import (
3 | CenterCrop,
4 | Compose,
5 | InterpolationMode,
6 | Normalize,
7 | Resize,
8 | ToTensor,
9 | )
10 |
11 |
12 | def clip_image_transform(n_px=224):
13 | return Compose([
14 | Resize(n_px, interpolation=InterpolationMode.BICUBIC),
15 | CenterCrop(n_px),
16 | ToTensor(),
17 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
18 | ])
--------------------------------------------------------------------------------
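A minimal usage sketch for `clip_image_transform` (the frame path is a placeholder): it resizes and center-crops a PIL image to 224x224 and applies CLIP's mean/std normalization, returning a `[3, 224, 224]` float tensor.

```python
from PIL import Image

from third_party.viclip import clip_image_transform

transform = clip_image_transform(224)

# placeholder frame path; any RGB image works
frame = Image.open("frame_0000.png").convert("RGB")
tensor = transform(frame)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 224, 224]) torch.float32
```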
/third_party/dover/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv_backbone import convnext_3d_small, convnext_3d_tiny
2 | from .evaluator import DOVER, BaseEvaluator, BaseImageEvaluator
3 | from .head import IQAHead, VARHead, VQAHead
4 | from .swin_backbone import SwinTransformer2D as IQABackbone
5 | from .swin_backbone import SwinTransformer3D as VQABackbone
6 | from .swin_backbone import swin_3d_small, swin_3d_tiny
7 |
8 | __all__ = [
9 | "VQABackbone",
10 | "IQABackbone",
11 | "VQAHead",
12 | "IQAHead",
13 | "VARHead",
14 | "BaseEvaluator",
15 | "BaseImageEvaluator",
16 | "DOVER",
17 | ]
18 |
--------------------------------------------------------------------------------
/third_party/dover/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | technical:
3 | fragments_h: 7
4 | fragments_w: 7
5 | fsize_h: 32
6 | fsize_w: 32
7 | aligned: 32
8 | clip_len: 32
9 | frame_interval: 2
10 | num_clips: 3
11 | aesthetic:
12 | size_h: 224
13 | size_w: 224
14 | clip_len: 32
15 | frame_interval: 2
16 | t_frag: 32
17 | num_clips: 1
18 |
19 | model:
20 | backbone:
21 | technical:
22 | type: swin_tiny_grpb
23 | checkpoint: true
24 | pretrained:
25 | aesthetic:
26 | type: conv_tiny
27 | backbone_preserve_keys: technical,aesthetic
28 | divide_head: true
29 | vqa_head:
30 | in_channels: 768
31 | hidden_channels: 64
--------------------------------------------------------------------------------
/doc/install.sh:
--------------------------------------------------------------------------------
1 | # DOVER
2 | # install git lfs to pull the checkpoints from huggingface
3 | git lfs install
4 | git clone https://huggingface.co/teowu/DOVER checkpoints/DOVER
5 |
6 | # ViCLIP
7 | # tokenizers
8 | wget https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz \
9 | -P checkpoints/ViCLIP
10 | # model weights
11 | wget https://huggingface.co/OpenGVLab/VBench_Used_Models/resolve/main/ViClip-InternVid-10M-FLT.pth \
12 | -P checkpoints/ViCLIP
13 |
14 | # GMFlow
15 | # download the pretrained model from google drive
16 | gdown 1d5C5cgHIxWGsFR1vYs5XrQbbUiZl9TX2 -O checkpoints/
17 | # unzip the model and move it to the correct directory
18 | unzip -n checkpoints/pretrained.zip -d checkpoints/
19 | mv checkpoints/pretrained checkpoints/gmflow
20 | rm checkpoints/pretrained.zip
--------------------------------------------------------------------------------
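After the script finishes, the evaluators look for the checkpoints at the paths below (taken from `v2vbench/dover_score.py`, `v2vbench/viclip_text_alignment.py`, and the GMFlow step above). A quick sanity-check sketch, assuming it is run from the repository root:

```python
from pathlib import Path

# checkpoint locations referenced elsewhere in the repository
expected = [
    Path("checkpoints/DOVER/DOVER.pth"),
    Path("checkpoints/ViCLIP/bpe_simple_vocab_16e6.txt.gz"),
    Path("checkpoints/ViCLIP/ViClip-InternVid-10M-FLT.pth"),
    Path("checkpoints/gmflow"),  # directory produced by the GMFlow download step
]

for path in expected:
    status = "ok" if path.exists() else "missing"
    print(f"{status:>7}  {path}")
```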
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Wenhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/gmflow/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 |
10 | def __init__(self, dims, mode='sintel', padding_factor=8):
11 | self.ht, self.wd = dims[-2:]
12 | pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
13 | pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
14 | if mode == 'sintel':
15 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
16 | else:
17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
18 |
19 | def pad(self, *inputs):
20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
21 |
22 | def unpad(self, x):
23 | ht, wd = x.shape[-2:]
24 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
25 | return x[..., c[0]:c[1], c[2]:c[3]]
26 |
27 |
28 | def write_flow(flow: torch.Tensor, filename: Path) -> None:
29 | """Write optical flow to file
30 |
31 | Args:
32 | flow (torch.Tensor): Shape (T, 2, H, W)
33 | filename (Path): File to write optical flow
34 | """
35 | filename.parent.mkdir(parents=True, exist_ok=True)
36 | flow = flow.cpu()
37 | torch.save(flow, filename)
--------------------------------------------------------------------------------
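A shape-only round trip for `InputPadder`: pad a pair of frames so both spatial dimensions are divisible by 8, then crop a flow predicted at the padded resolution back to the original size. The tensors here are random stand-ins; no GMFlow weights are needed.

```python
import torch

from third_party.gmflow.utils import InputPadder

# two dummy frames, [B, 3, H, W] with H and W not divisible by 8
img0 = torch.rand(1, 3, 500, 700)
img1 = torch.rand(1, 3, 500, 700)

padder = InputPadder(img0.shape, padding_factor=8)
img0_p, img1_p = padder.pad(img0, img1)
print(img0_p.shape)  # torch.Size([1, 3, 504, 704])

# stand-in for a flow field predicted on the padded input
flow_padded = torch.zeros(1, 2, *img0_p.shape[-2:])
print(padder.unpad(flow_padded).shape)  # torch.Size([1, 2, 500, 700])
```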
/third_party/gmflow/models/position.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | import math
7 |
8 |
9 | class PositionEmbeddingSine(nn.Module):
10 | """
11 | This is a more standard version of the position embedding, very similar to the one
12 | used by the Attention is all you need paper, generalized to work on images.
13 | """
14 |
15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
16 | super().__init__()
17 | self.num_pos_feats = num_pos_feats
18 | self.temperature = temperature
19 | self.normalize = normalize
20 | if scale is not None and normalize is False:
21 | raise ValueError("normalize should be True if scale is passed")
22 | if scale is None:
23 | scale = 2 * math.pi
24 | self.scale = scale
25 |
26 | def forward(self, x):
27 | # x = tensor_list.tensors # [B, C, H, W]
28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
29 | b, c, h, w = x.size()
30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
31 | y_embed = mask.cumsum(1, dtype=torch.float32)
32 | x_embed = mask.cumsum(2, dtype=torch.float32)
33 | if self.normalize:
34 | eps = 1e-6
35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
37 |
38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
40 |
41 | pos_x = x_embed[:, :, :, None] / dim_t
42 | pos_y = y_embed[:, :, :, None] / dim_t
43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
46 | return pos
--------------------------------------------------------------------------------
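A shape sketch of the sinusoidal encoding, assuming a 128-channel feature map as configured in `third_party/gmflow/config.yaml` (`feature_channels: 128`, hence 64 positional features per axis): the output concatenates the y and x embeddings back to the input channel count.

```python
import torch

from third_party.gmflow.models.position import PositionEmbeddingSine

feature = torch.rand(2, 128, 32, 32)               # [B, C, H, W]
pos_enc = PositionEmbeddingSine(num_pos_feats=64)  # 64 sine/cosine features per axis

pos = pos_enc(feature)
print(pos.shape)  # torch.Size([2, 128, 32, 32]); added element-wise to the feature map
```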
/v2vbench/dover_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 16:31:28
5 | @Desc :
6 | @Ref : https://github.com/VQAssessment/DOVER/tree/master
7 | '''
8 | import logging
9 | from pathlib import Path
10 |
11 | import torch
12 | from omegaconf import OmegaConf
13 |
14 | from third_party.dover.models import DOVER
15 |
16 | from .base_evaluator import BaseEvaluator
17 | from .dover_utils import DoverPreprocessor, fuse_results
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | class DoverScore(BaseEvaluator):
22 | pretrained_config_file = (Path(__file__) / '../../third_party/dover/config.yaml').resolve().absolute()
23 | pretrained_checkpoint = (Path(__file__) / '../../checkpoints/DOVER/DOVER.pth').resolve().absolute()
24 |
25 | def __init__(
26 | self,
27 | index_file: Path,
28 | edit_video_dir: Path,
29 | reference_video_dir: Path,
30 | edit_prompt: str,
31 | device: torch.device,
32 | pretrained_config_file: str = None,
33 | pretrained_checkpoint: str = None,
34 | ):
35 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
36 |
37 | pretrained_config_file = pretrained_config_file or self.pretrained_config_file
38 | pretrained_checkpoint = pretrained_checkpoint or self.pretrained_checkpoint
39 |
40 |         logger.debug(f"Loading model {pretrained_checkpoint}")
41 | config = OmegaConf.to_container(OmegaConf.load(pretrained_config_file))
42 | self.preprocessor = DoverPreprocessor(config["data"])
43 | self.model = DOVER(**config['model'])
44 | self.model.to(self.device)
45 | self.model.load_state_dict(torch.load(pretrained_checkpoint, map_location=self.device))
46 | self.model.eval()
47 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
48 |
49 | def range(self):
50 | return 0, 1
51 |
52 | def preprocess(self, sample):
53 | video = sample['edit_video']
54 |
55 | views = self.preprocessor(video)
56 | for k, v in views.items():
57 | views[k] = v.to(self.device)
58 | return views
59 |
60 | @torch.no_grad()
61 | def evaluate(self, views) -> float:
62 | results = [r.mean().item() for r in self.model(views)]
63 | # score
64 | scores = fuse_results(results)
65 | return scores
--------------------------------------------------------------------------------
/v2vbench/clip_consistency.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/07 16:06:35
5 | @Desc :
6 | @Ref :
7 | https://github.com/openai/CLIP
8 | https://github.com/mlfoundations/open_clip
9 | https://huggingface.co/docs/transformers/model_doc/clip#clip
10 | '''
11 | import logging
12 | from pathlib import Path
13 |
14 | import torch
15 | from transformers import CLIPImageProcessor, CLIPVisionModel
16 |
17 | from .base_evaluator import BaseEvaluator
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | class ClipConsistency(BaseEvaluator):
22 | pretrained_model_name = 'openai/clip-vit-large-patch14'
23 |
24 | def __init__(
25 | self,
26 | index_file: Path,
27 | edit_video_dir: Path,
28 | reference_video_dir: Path,
29 | edit_prompt: str,
30 | device: torch.device,
31 | pretrained_model_name: str = None,
32 | ):
33 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name
35 |         logger.debug(f"Loading model {pretrained_model_name}")
36 | self.preprocessor = CLIPImageProcessor.from_pretrained(pretrained_model_name)
37 | self.model = CLIPVisionModel.from_pretrained(pretrained_model_name)
38 | self.model.to(self.device)
39 | self.model.eval()
40 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
41 |
42 | def range(self):
43 | return 0, 1
44 |
45 | def preprocess(self, sample):
46 | video = sample['edit_video']
47 | frames = []
48 | for i, frame in enumerate(video):
49 | frames.append(self.preprocessor(frame, return_tensors='pt').pixel_values)
50 | return frames
51 |
52 | @torch.no_grad()
53 | def evaluate(self, frames) -> float:
54 | similarity = []
55 | former_feature = None
56 | for i, frame in enumerate(frames):
57 | frame = frame.to(self.device)
58 | feature: torch.Tensor = self.model(pixel_values=frame).pooler_output
59 | feature = feature / torch.norm(feature, dim=-1, keepdim=True)
60 |
61 | if i > 0:
62 | sim = max(0, (feature @ former_feature.T).cpu().squeeze().item())
63 | similarity.append(sim)
64 | former_feature = feature
65 | return sum(similarity) / len(similarity)
--------------------------------------------------------------------------------
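The metric reduces to the mean cosine similarity between consecutive frame embeddings, clamped at zero. A minimal sketch of that reduction on stand-in features (the feature dimension is illustrative):

```python
import torch
import torch.nn.functional as F

# stand-in for per-frame pooled CLIP features, [T, D], L2-normalized
features = F.normalize(torch.rand(8, 1024), dim=-1)

# similarity between each frame and its predecessor, clamped at 0 as in evaluate()
sims = (features[1:] * features[:-1]).sum(dim=-1).clamp(min=0)
print(sims.mean().item())
```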
/v2vbench/dino_consistency.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/07 16:06:35
5 | @Desc :
6 | @Ref :
7 | https://github.com/facebookresearch/dinov2
8 | https://huggingface.co/docs/transformers/model_doc/dinov2#dinov2
9 | '''
10 | import logging
11 | from pathlib import Path
12 |
13 | import torch
14 | from transformers import BitImageProcessor, Dinov2Model
15 |
16 | from .base_evaluator import BaseEvaluator
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | class DinoConsistency(BaseEvaluator):
21 | pretrained_model_name = 'facebook/dinov2-base'
22 |
23 | def __init__(
24 | self,
25 | index_file: Path,
26 | edit_video_dir: Path,
27 | reference_video_dir: Path,
28 | edit_prompt: str,
29 | device: torch.device,
30 | pretrained_model_name: str = None,
31 | ):
32 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
33 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name
34 |         logger.debug(f"Loading model {pretrained_model_name}")
35 | self.preprocessor = BitImageProcessor.from_pretrained(pretrained_model_name)
36 | self.model = Dinov2Model.from_pretrained(pretrained_model_name)
37 | self.model.to(self.device)
38 | self.model.eval()
39 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
40 |
41 | def range(self):
42 | return 0, 1
43 |
44 | def preprocess(self, sample):
45 | video = sample['edit_video']
46 | frames = []
47 | for i, frame in enumerate(video):
48 | frames.append(self.preprocessor(frame, return_tensors='pt').pixel_values)
49 | return frames
50 |
51 | @torch.no_grad()
52 | def evaluate(self, frames) -> float:
53 | similarity = []
54 | former_feature = None
55 | for i, frame in enumerate(frames):
56 | frame = frame.to(self.device)
57 | # pooled output (first image token)
58 | feature = self.model(pixel_values=frame).pooler_output
59 | feature: torch.Tensor = feature / torch.norm(feature, dim=-1, keepdim=True)
60 |
61 | if i > 0:
62 | sim = max(0, (feature @ former_feature.T).cpu().squeeze().item())
63 | similarity.append(sim)
64 | former_feature = feature
65 | return sum(similarity) / len(similarity)
--------------------------------------------------------------------------------
/v2vbench/dover_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 16:38:16
5 | @Desc :
6 | @Ref : https://github.com/VQAssessment/DOVER/tree/master
7 | '''
8 | import logging
9 | from typing import Dict, List
10 |
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 |
15 | from third_party.dover.dataset import (
16 | UnifiedFrameSampler,
17 | spatial_temporal_view_decomposition,
18 | )
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | class DoverPreprocessor:
23 | mean = torch.FloatTensor([123.675, 116.28, 103.53])
24 | std = torch.FloatTensor([58.395, 57.12, 57.375])
25 |
26 | def __init__(
27 | self,
28 | sample_types: Dict[str, Dict[str, int]],
29 | ):
30 | self.sample_types = sample_types
31 | self.temporal_samplers = {}
32 | for stype, sopt in sample_types.items():
33 | if "t_frag" not in sopt:
34 | # resized temporal sampling for TQE in DOVER
35 | self.temporal_samplers[stype] = UnifiedFrameSampler(
36 | sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"]
37 | )
38 | else:
39 | # temporal sampling for AQE in DOVER
40 | self.temporal_samplers[stype] = UnifiedFrameSampler(
41 | sopt["clip_len"] // sopt["t_frag"],
42 | sopt["t_frag"],
43 | sopt["frame_interval"],
44 | sopt["num_clips"],
45 | )
46 |
47 | def __call__(self, frames: List[Image.Image]):
48 | views, _ = spatial_temporal_view_decomposition(
49 | frames, self.sample_types, self.temporal_samplers
50 | )
51 |
52 | for k, v in views.items():
53 | v: torch.Tensor
54 | num_clips = self.sample_types[k].get("num_clips", 1)
55 | views[k] = ((v.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2).reshape(
56 | v.shape[0], num_clips, -1, *v.shape[2:]).transpose(0, 1)
57 | return views
58 |
59 |
60 | def fuse_results(results: list):
61 | logger.debug(f'Before fuse: {results}')
62 | means = [0.1107, 0.08285]
63 | stds = [0.07355, 0.03774]
64 | weights = [0.6104, 0.3896]
65 | x = (results[0] - means[0]) / stds[0] * weights[0] + (results[1] + means[1]) / stds[1] * weights[1]
66 | return 1 / (1 + np.exp(-x)) # sigmoid
--------------------------------------------------------------------------------
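A quick worked example of `fuse_results` with the constants above: when the two branch outputs sit exactly at the reference operating point, the standardized sum is zero and the sigmoid returns 0.5; larger branch outputs push the fused score monotonically toward 1. (Assumes the repository root is on `PYTHONPATH` so the package imports resolve.)

```python
from v2vbench.dover_utils import fuse_results

# at the reference operating point the standardized sum is 0, so the sigmoid gives 0.5
print(fuse_results([0.1107, -0.08285]))  # 0.5

# larger raw technical/aesthetic outputs increase the fused score
print(fuse_results([0.2, 0.0]))  # ~0.83
```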
/v2vbench/aesthetic_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 15:54:42
5 | @Desc :
6 | @Ref : https://github.com/christophschuhmann/improved-aesthetic-predictor
7 | '''
8 | import logging
9 | from pathlib import Path
10 |
11 | import torch
12 | from aesthetics_predictor import AestheticsPredictorV1, AestheticsPredictorV2Linear
13 | from transformers import CLIPProcessor
14 |
15 | from .base_evaluator import BaseEvaluator
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | class AestheticScore(BaseEvaluator):
20 | # pretrained_model_name = 'shunk031/aesthetics-predictor-v1-vit-large-patch14' # v1
21 | pretrained_model_name = 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' # v2
22 |
23 | def __init__(
24 | self,
25 | index_file: Path,
26 | edit_video_dir: Path,
27 | reference_video_dir: Path,
28 | edit_prompt: str,
29 | device: torch.device,
30 | pretrained_model_name: str = None,
31 | ):
32 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
33 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name
34 | logger.debug(f"Loading model {pretrained_model_name}")
35 | self.preprocessor = CLIPProcessor.from_pretrained(pretrained_model_name)
36 | # self.model = AestheticsPredictorV1.from_pretrained(pretrained_model_name)
37 | self.model = AestheticsPredictorV2Linear.from_pretrained(pretrained_model_name)
38 | self.model.to(self.device)
39 | self.model.eval()
40 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
41 |
42 | def range(self):
43 | return 0, 10
44 |
45 | def preprocess(self, sample):
46 | video = sample['edit_video']
47 | frames = []
48 | for i, frame in enumerate(video):
49 | frames.append(self.preprocessor(images=frame, return_tensors='pt').pixel_values)
50 |
51 | return frames
52 |
53 | @torch.no_grad()
54 | def evaluate(self, frames) -> float:
55 | score = []
56 | for i, frame in enumerate(frames):
57 | if i == 0:
58 | logger.debug(f"Input shape: {frame.shape}")
59 | frame = frame.to(self.device)
60 | prediction = self.model(pixel_values=frame).logits
61 | if i == 0:
62 | logger.debug(f"Output shape: {prediction.shape}")
63 | score.append(prediction.squeeze().cpu().item())
64 | return sum(score) / len(score)
--------------------------------------------------------------------------------
/v2vbench/clip_text_alignment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 18:58:20
5 | @Desc :
6 | @Ref :
7 | https://github.com/openai/CLIP
8 | https://github.com/mlfoundations/open_clip
9 | https://huggingface.co/docs/transformers/model_doc/clip#clip
10 | '''
11 | import logging
12 | from pathlib import Path
13 |
14 | import torch
15 | from transformers import CLIPModel, CLIPProcessor
16 |
17 | from .base_evaluator import BaseEvaluator
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | class ClipTextAlignment(BaseEvaluator):
22 | pretrained_model_name = 'openai/clip-vit-large-patch14'
23 |
24 | def __init__(
25 | self,
26 | index_file: Path,
27 | edit_video_dir: Path,
28 | reference_video_dir: Path,
29 | edit_prompt: str,
30 | device: torch.device,
31 | pretrained_model_name: str = None,
32 | ):
33 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name
35 |         logger.debug(f"Loading model {pretrained_model_name}")
36 | self.preprocessor = CLIPProcessor.from_pretrained(pretrained_model_name)
37 | self.model = CLIPModel.from_pretrained(pretrained_model_name)
38 | self.model.to(self.device)
39 | self.model.eval()
40 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
41 |
42 | def range(self):
43 | return 0, 100
44 |
45 | def preprocess(self, sample):
46 | text = sample['edit_prompt']
47 | video = sample['edit_video']
48 |
49 | text_inputs = self.preprocessor(
50 | text=text, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device)
51 |
52 | image_inputs = []
53 | for frame in video:
54 | image_inputs.append(self.preprocessor(
55 | images=frame, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device))
56 |
57 | return text_inputs, image_inputs
58 |
59 | @torch.no_grad()
60 | def evaluate(self, args) -> float:
61 | text_inputs, image_inputs = args
62 | text_embs = self.model.get_text_features(**text_inputs)
63 | text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
64 |
65 | scores = []
66 | for image_input in image_inputs:
67 | image_embs = self.model.get_image_features(**image_input)
68 | image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
69 | score = (self.model.logit_scale.exp() * (text_embs @ image_embs.T)).cpu().squeeze().item()
70 | scores.append(score)
71 |
72 | return sum(scores) / len(scores)
--------------------------------------------------------------------------------
/v2vbench/pick_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 16:07:53
5 | @Desc :
6 | @Ref : https://github.com/yuvalkirstain/PickScore
7 | '''
8 | import logging
9 | from pathlib import Path
10 |
11 | import torch
12 | from transformers import CLIPModel, CLIPProcessor
13 |
14 | from .base_evaluator import BaseEvaluator
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | class PickScore(BaseEvaluator):
19 | pretrained_processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
20 | pretrained_model_name = 'yuvalkirstain/PickScore_v1'
21 |
22 | def __init__(
23 | self,
24 | index_file: Path,
25 | edit_video_dir: Path,
26 | reference_video_dir: Path,
27 | edit_prompt: str,
28 | device: torch.device,
29 | pretrained_processor_name: str = None,
30 | pretrained_model_name: str = None,
31 | ):
32 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
33 | pretrained_processor_name = pretrained_processor_name or self.pretrained_processor_name
34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name
35 |
36 |         logger.debug(f"Loading model {pretrained_model_name}")
37 | self.preprocessor = CLIPProcessor.from_pretrained(pretrained_processor_name)
38 | self.model = CLIPModel.from_pretrained(pretrained_model_name)
39 | self.model.to(self.device)
40 | self.model.eval()
41 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
42 |
43 | def range(self):
44 | return 0, 100
45 |
46 | def preprocess(self, sample):
47 | text = sample['edit_prompt']
48 | video = sample['edit_video']
49 |
50 | text_inputs = self.preprocessor(
51 | text=text, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device)
52 |
53 | image_inputs = []
54 | for frame in video:
55 | image_inputs.append(self.preprocessor(
56 | images=frame, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device))
57 |
58 | return text_inputs, image_inputs
59 |
60 | @torch.no_grad()
61 | def evaluate(self, args) -> float:
62 | text_inputs, image_inputs = args
63 | text_embs = self.model.get_text_features(**text_inputs)
64 | text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
65 |
66 | scores = []
67 | for image_input in image_inputs:
68 | image_embs = self.model.get_image_features(**image_input)
69 | image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
70 | score = (self.model.logit_scale.exp() * (text_embs @ image_embs.T)).cpu().squeeze().item()
71 | scores.append(score)
72 |
73 | return sum(scores) / len(scores)
--------------------------------------------------------------------------------
/v2vbench/viclip_text_alignment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 20:58:36
5 | @Desc :
6 | @Ref :
7 | '''
8 | import logging
9 | from pathlib import Path
10 |
11 | import torch
12 |
13 | from third_party.viclip import SimpleTokenizer, ViCLIP, clip_image_transform
14 |
15 | from .base_evaluator import BaseEvaluator
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | class ViclipTextAlignment(BaseEvaluator):
20 | pretrained_tokenizer = (
21 | Path(__file__) / '../../checkpoints/ViCLIP/bpe_simple_vocab_16e6.txt.gz').resolve().absolute()
22 | pretrained_checkpoint = (
23 | Path(__file__) / '../../checkpoints/ViCLIP/ViClip-InternVid-10M-FLT.pth').resolve().absolute()
24 |
25 | def __init__(
26 | self,
27 | index_file: Path,
28 | edit_video_dir: Path,
29 | reference_video_dir: Path,
30 | edit_prompt: str,
31 | device: torch.device,
32 | pretrained_tokenizer: str = None,
33 | pretrained_checkpoint: str = None,
34 | stride: int = 3, # accept 8 frames as input
35 | ):
36 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device, stride=stride)
37 | pretrained_tokenizer = pretrained_tokenizer or self.pretrained_tokenizer
38 | pretrained_checkpoint = pretrained_checkpoint or self.pretrained_checkpoint
39 |
40 |         logger.debug(f"Loading model {pretrained_checkpoint}")
41 | tokenizer = SimpleTokenizer(bpe_path=pretrained_tokenizer)
42 | self.model = ViCLIP(tokenizer=tokenizer, pretrain=pretrained_checkpoint)
43 | self.model.to(self.device)
44 | self.model.eval()
45 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
46 |
47 | self.image_transform = clip_image_transform(224)
48 |
49 | def range(self):
50 | return 0, 1
51 |
52 | def preprocess(self, sample):
53 | text = sample['edit_prompt']
54 | video = sample['edit_video']
55 |
56 | text_inputs = text
57 |
58 | frames = []
59 | for frame in video:
60 | frames.append(self.image_transform(frame))
61 | video_inputs = torch.stack(frames).to(self.device)[None] # (1, T, C, H, W)
62 | logger.debug(f"video_inputs shape: {video_inputs.shape}")
63 |
64 | return text_inputs, video_inputs
65 |
66 | @torch.no_grad()
67 | def evaluate(self, args) -> float:
68 | text_inputs, video_inputs = args
69 |
70 | text_embs: torch.Tensor = self.model.encode_text(text_inputs).float()
71 | text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
72 | logger.debug(f"text_embs shape: {text_embs.shape}")
73 |
74 | video_embs: torch.Tensor = self.model.encode_vision(video_inputs, test=True).float()
75 | video_embs = video_embs / torch.norm(video_embs, dim=-1, keepdim=True)
76 | logger.debug(f"video_embs shape: {video_embs.shape}")
77 |
78 | score = (text_embs @ video_embs.T).cpu().squeeze().item()
79 |
80 | return score
--------------------------------------------------------------------------------
/v2vbench/dino_image_alignment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/07 16:06:35
5 | @Desc :
6 | @Ref :
7 | https://github.com/facebookresearch/dinov2
8 | https://huggingface.co/docs/transformers/model_doc/dinov2#dinov2
9 | '''
10 | import logging
11 | from pathlib import Path
12 |
13 | import numpy as np
14 | import torch
15 | from transformers import BitImageProcessor, Dinov2Model
16 |
17 | from .base_evaluator import BaseEvaluator
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | class DinoImageAlignment(BaseEvaluator):
22 | pretrained_model_name = 'facebook/dinov2-base'
23 |
24 | def __init__(
25 | self,
26 | index_file: Path,
27 | edit_video_dir: Path,
28 | reference_video_dir: Path,
29 | edit_prompt: str,
30 | device: torch.device,
31 | pretrained_model_name: str = None,
32 | ):
33 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device)
34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name
35 |         logger.debug(f"Loading model {pretrained_model_name}")
36 | self.preprocessor = BitImageProcessor.from_pretrained(pretrained_model_name)
37 | self.model = Dinov2Model.from_pretrained(pretrained_model_name)
38 | self.model.to(self.device)
39 | self.model.eval()
40 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
41 |
42 | def range(self):
43 | return 0, 1
44 |
45 | def preprocess(self, sample):
46 | video = sample['edit_video']
47 | frames = []
48 | for i, frame in enumerate(video):
49 | frames.append(self.preprocessor(frame, return_tensors='pt').pixel_values)
50 | reference_image = self.preprocessor(sample['reference_image'], return_tensors='pt').pixel_values
51 | return frames, reference_image
52 |
53 | @torch.no_grad()
54 | def evaluate(self, args) -> float:
55 | frames, reference_image = args
56 | similarity = []
57 | conformity = []
58 |
59 | reference_image = reference_image.to(self.device)
60 | reference_feature = self.model(pixel_values=reference_image).pooler_output
61 | reference_feature: torch.Tensor = reference_feature / torch.norm(reference_feature, dim=-1, keepdim=True)
62 |
63 | former_feature = None
64 | for i, frame in enumerate(frames):
65 | frame = frame.to(self.device)
66 | # pooled output (first image token)
67 | feature = self.model(pixel_values=frame).pooler_output
68 | feature: torch.Tensor = feature / torch.norm(feature, dim=-1, keepdim=True)
69 |
70 | if i > 0:
71 | sim = max(0, (feature @ former_feature.T).cpu().squeeze().item())
72 | similarity.append(sim)
73 | former_feature = feature
74 |
75 | conformity.append(max(0, (feature @ reference_feature.T).cpu().squeeze().item()))
76 |
77 | max_weight = 0.4
78 | mean_weight = 0.3
79 | min_weight = 0.3
80 | return float(
81 | max_weight * np.max(conformity) +
82 | mean_weight * np.mean(similarity) +
83 | min_weight * np.min(similarity)
84 | )
--------------------------------------------------------------------------------
/third_party/gmflow/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .position import PositionEmbeddingSine
3 |
4 |
5 | def split_feature(feature,
6 | num_splits=2,
7 | channel_last=False,
8 | ):
9 | if channel_last: # [B, H, W, C]
10 | b, h, w, c = feature.size()
11 | assert h % num_splits == 0 and w % num_splits == 0
12 |
13 | b_new = b * num_splits * num_splits
14 | h_new = h // num_splits
15 | w_new = w // num_splits
16 |
17 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
18 | ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
19 | else: # [B, C, H, W]
20 | b, c, h, w = feature.size()
21 | assert h % num_splits == 0 and w % num_splits == 0
22 |
23 | b_new = b * num_splits * num_splits
24 | h_new = h // num_splits
25 | w_new = w // num_splits
26 |
27 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
28 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
29 |
30 | return feature
31 |
32 |
33 | def merge_splits(splits,
34 | num_splits=2,
35 | channel_last=False,
36 | ):
37 | if channel_last: # [B*K*K, H/K, W/K, C]
38 | b, h, w, c = splits.size()
39 | new_b = b // num_splits // num_splits
40 |
41 | splits = splits.view(new_b, num_splits, num_splits, h, w, c)
42 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
43 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
44 | else: # [B*K*K, C, H/K, W/K]
45 | b, c, h, w = splits.size()
46 | new_b = b // num_splits // num_splits
47 |
48 | splits = splits.view(new_b, num_splits, num_splits, c, h, w)
49 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
50 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
51 |
52 | return merge
53 |
54 |
55 | def normalize_img(img0, img1):
56 | # loaded images are in [0, 255]
57 | # normalize by ImageNet mean and std
58 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
59 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
60 | img0 = (img0 / 255. - mean) / std
61 | img1 = (img1 / 255. - mean) / std
62 |
63 | return img0, img1
64 |
65 |
66 | def feature_add_position(feature0, feature1, attn_splits, feature_channels):
67 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
68 |
69 |     if attn_splits > 1:  # add position within split windows
70 | feature0_splits = split_feature(feature0, num_splits=attn_splits)
71 | feature1_splits = split_feature(feature1, num_splits=attn_splits)
72 |
73 | position = pos_enc(feature0_splits)
74 |
75 | feature0_splits = feature0_splits + position
76 | feature1_splits = feature1_splits + position
77 |
78 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
79 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
80 | else:
81 | position = pos_enc(feature0)
82 |
83 | feature0 = feature0 + position
84 | feature1 = feature1 + position
85 |
86 | return feature0, feature1
--------------------------------------------------------------------------------
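A round-trip check for the window helpers, assuming spatial dimensions divisible by `num_splits`: `split_feature` tiles a `[B, C, H, W]` map into `B*K*K` windows and `merge_splits` reassembles them exactly.

```python
import torch

from third_party.gmflow.models.utils import merge_splits, split_feature

feature = torch.rand(2, 128, 32, 32)           # [B, C, H, W]

splits = split_feature(feature, num_splits=2)  # [B*K*K, C, H/K, W/K]
print(splits.shape)                            # torch.Size([8, 128, 16, 16])

merged = merge_splits(splits, num_splits=2)
print(torch.equal(merged, feature))            # True: the round trip is lossless
```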
/v2vbench/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/07 14:49:35
5 | @Desc :
6 | @Ref :
7 | '''
8 | import logging
9 | from pathlib import Path
10 | from typing import List
11 |
12 | import decord
13 | from PIL import Image, ImageSequence
14 |
15 | decord.bridge.set_bridge('native')
16 | logger = logging.getLogger(__name__)
17 | _supported_image_suffix = ['.jpg', '.jpeg', '.png']
18 | _supported_video_suffix = ['.mp4', '.gif']
19 |
20 |
21 | def _load_video_from_image_dir(video_dir: Path) -> List[Image.Image]:
22 | logger.debug(f'Loading video from image directory: {video_dir}')
23 | frames = []
24 | for file in sorted(video_dir.iterdir()):
25 | if file.suffix not in _supported_image_suffix:
26 | logger.debug(f'Skipping file: {file}')
27 | continue
28 | frame = Image.open(file).convert('RGB')
29 | frames.append(frame)
30 | if not frames:
31 | raise FileNotFoundError(f'No image found in {video_dir}, supported image suffix: {_supported_image_suffix}')
32 |
33 | return frames
34 |
35 |
36 | def _load_video_from_video_file(video_file: Path) -> List[Image.Image]:
37 | logger.debug(f'Loading video from video file: {video_file}')
38 | if video_file.suffix == '.mp4':
39 | video_reader = decord.VideoReader(str(video_file), num_threads=1)
40 | frames = []
41 | for i in range(len(video_reader)):
42 | frames.append(Image.fromarray(video_reader[i].asnumpy()))
43 | return frames
44 |
45 | elif video_file.suffix == '.gif':
46 | frames = []
47 | for f in ImageSequence.Iterator(Image.open(video_file)):
48 | frame = f.convert('RGB')
49 | frames.append(frame)
50 | return frames
51 |
52 | else:
53 | raise NotImplementedError(
54 | f'Unsupported video file: {video_file}, supported suffix: {_supported_video_suffix}')
55 |
56 |
57 | def load_video(video_file_or_dir: Path, start_frame: int = 0, stride: int = 1) -> List[Image.Image]:
58 | """
59 | Args:
60 | video_file_or_dir (Path): path to video file or directory containing images
61 | start_frame (int): start frame index
62 | stride (int): stride for frame sampling
63 | Returns:
64 | List[Image.Image]: list of frames, RGB
65 | """
66 | if not video_file_or_dir.exists():
67 | logger.debug(f'Video file or directory does not exist: {video_file_or_dir}, trying to find alternative')
68 | for suffix in _supported_video_suffix:
69 | if (video_file_or_dir.with_suffix(suffix)).exists():
70 | video_file_or_dir = video_file_or_dir.with_suffix(suffix)
71 | logger.debug(f'Found video file: {video_file_or_dir}')
72 | break
73 | else:
74 | raise FileNotFoundError(f'Reference video: {video_file_or_dir} does not exist')
75 |
76 | if video_file_or_dir.is_dir():
77 | buffer = _load_video_from_image_dir(video_file_or_dir)
78 | elif video_file_or_dir.is_file():
79 | buffer = _load_video_from_video_file(video_file_or_dir)
80 | else:
81 | # should not reach here
82 | raise NotImplementedError(f'{video_file_or_dir} is not a valid file or directory')
83 |
84 | video = buffer[start_frame::stride]
85 | logger.debug(f'Raw video frames: {len(buffer)}, sampled video frames: {len(video)}')
86 | logger.debug(f'Frame size: {video[0].size}')
87 | return video
--------------------------------------------------------------------------------
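A minimal call sketch for `load_video` (the path is a placeholder): it accepts either a directory of frames or an `.mp4`/`.gif` file, fills in a missing suffix when possible, and returns RGB `PIL.Image` frames sampled with the given stride.

```python
from pathlib import Path

from v2vbench.utils import load_video

# placeholder path: a frame directory or an .mp4/.gif (the suffix may be omitted)
frames = load_video(Path("data/edit_videos/sample"), start_frame=0, stride=1)
print(len(frames), frames[0].size)  # frame count and (width, height)
```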
/third_party/gmflow/models/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def coords_grid(b, h, w, homogeneous=False, device=None):
6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
7 |
8 | stacks = [x, y]
9 |
10 | if homogeneous:
11 | ones = torch.ones_like(x) # [H, W]
12 | stacks.append(ones)
13 |
14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
15 |
16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
17 |
18 | if device is not None:
19 | grid = grid.to(device)
20 |
21 | return grid
22 |
23 |
24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
25 | assert device is not None
26 |
27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
28 | torch.linspace(h_min, h_max, len_h, device=device)],
29 | )
30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
31 |
32 | return grid
33 |
34 |
35 | def normalize_coords(coords, h, w):
36 | # coords: [B, H, W, 2]
37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
38 | return (coords - c) / c # [-1, 1]
39 |
40 |
41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
42 | # img: [B, C, H, W]
43 | # sample_coords: [B, 2, H, W] in image scale
44 | if sample_coords.size(1) != 2: # [B, H, W, 2]
45 | sample_coords = sample_coords.permute(0, 3, 1, 2)
46 |
47 | b, _, h, w = sample_coords.shape
48 |
49 | # Normalize to [-1, 1]
50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
52 |
53 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
54 |
55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
56 |
57 | if return_mask:
58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
59 |
60 | return img, mask
61 |
62 | return img
63 |
64 |
65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
66 | b, c, h, w = feature.size()
67 | assert flow.size(1) == 2
68 |
69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
70 |
71 | return bilinear_sample(feature, grid, padding_mode=padding_mode,
72 | return_mask=mask)
73 |
74 |
75 | def forward_backward_consistency_check(fwd_flow, bwd_flow,
76 | alpha=0.01,
77 | beta=0.5
78 | ):
79 | # fwd_flow, bwd_flow: [B, 2, H, W]
80 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
81 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
82 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
83 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
84 |
85 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
86 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
87 |
88 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
89 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
90 |
91 | threshold = alpha * flow_mag + beta
92 |
93 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
94 | bwd_occ = (diff_bwd > threshold).float()
95 |
96 | return fwd_occ, bwd_occ
--------------------------------------------------------------------------------
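A toy check of the forward-backward consistency test: identical zero flows are perfectly consistent, so no pixel exceeds the `alpha * |flow| + beta` threshold and both occlusion masks stay empty.

```python
import torch

from third_party.gmflow.models.geometry import forward_backward_consistency_check

fwd_flow = torch.zeros(1, 2, 48, 64)  # [B, 2, H, W]
bwd_flow = torch.zeros(1, 2, 48, 64)

fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)
print(fwd_occ.shape, fwd_occ.sum().item(), bwd_occ.sum().item())
# torch.Size([1, 48, 64]) 0.0 0.0
```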
/third_party/gmflow/models/trident_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn.modules.utils import _pair
8 |
9 |
10 | class MultiScaleTridentConv(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride=1,
17 | strides=1,
18 | paddings=0,
19 | dilations=1,
20 | dilation=1,
21 | groups=1,
22 | num_branch=1,
23 | test_branch_idx=-1,
24 | bias=False,
25 | norm=None,
26 | activation=None,
27 | ):
28 | super(MultiScaleTridentConv, self).__init__()
29 | self.in_channels = in_channels
30 | self.out_channels = out_channels
31 | self.kernel_size = _pair(kernel_size)
32 | self.num_branch = num_branch
33 | self.stride = _pair(stride)
34 | self.groups = groups
35 | self.with_bias = bias
36 | self.dilation = dilation
37 | if isinstance(paddings, int):
38 | paddings = [paddings] * self.num_branch
39 | if isinstance(dilations, int):
40 | dilations = [dilations] * self.num_branch
41 | if isinstance(strides, int):
42 | strides = [strides] * self.num_branch
43 | self.paddings = [_pair(padding) for padding in paddings]
44 | self.dilations = [_pair(dilation) for dilation in dilations]
45 | self.strides = [_pair(stride) for stride in strides]
46 | self.test_branch_idx = test_branch_idx
47 | self.norm = norm
48 | self.activation = activation
49 |
50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
51 |
52 | self.weight = nn.Parameter(
53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
54 | )
55 | if bias:
56 | self.bias = nn.Parameter(torch.Tensor(out_channels))
57 | else:
58 | self.bias = None
59 |
60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
61 | if self.bias is not None:
62 | nn.init.constant_(self.bias, 0)
63 |
64 | def forward(self, inputs):
65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
66 | assert len(inputs) == num_branch
67 |
68 | if self.training or self.test_branch_idx == -1:
69 | outputs = [
70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
71 | for input, stride, padding in zip(inputs, self.strides, self.paddings)
72 | ]
73 | else:
74 | outputs = [
75 | F.conv2d(
76 | inputs[0],
77 | self.weight,
78 | self.bias,
79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
81 | self.dilation,
82 | self.groups,
83 | )
84 | ]
85 |
86 | if self.norm is not None:
87 | outputs = [self.norm(x) for x in outputs]
88 | if self.activation is not None:
89 | outputs = [self.activation(x) for x in outputs]
90 | return outputs
--------------------------------------------------------------------------------
/third_party/dover/models/head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class VQAHead(nn.Module):
5 | """MLP Regression Head for VQA.
6 | Args:
7 | in_channels: input channels for MLP
8 | hidden_channels: hidden channels for MLP
9 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5)
10 | pre_pool: whether pre-pool the features or not (True for Aesthetic Attributes, False for Technical Attributes)
11 | """
12 |
13 | def __init__(
14 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, **kwargs
15 | ):
16 | super().__init__()
17 | self.dropout_ratio = dropout_ratio
18 | self.in_channels = in_channels
19 | self.hidden_channels = hidden_channels
20 | self.pre_pool = pre_pool
21 | if self.dropout_ratio != 0:
22 | self.dropout = nn.Dropout(p=self.dropout_ratio)
23 | else:
24 | self.dropout = None
25 | self.fc_hid = nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1))
26 | self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1))
27 | self.gelu = nn.GELU()
28 |
29 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
30 |
31 | def forward(self, x, rois=None):
32 | if self.pre_pool:
33 | x = self.avg_pool(x)
34 | x = self.dropout(x)
35 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x))))
36 | return qlt_score
37 |
38 |
39 |
40 |
41 |
42 | class VARHead(nn.Module):
43 | """MLP Regression Head for Video Action Recognition.
44 | Args:
45 | in_channels: input channels for MLP
46 |         out_channels: output channels for the classifier (number of action classes, default 400)
47 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5)
48 | """
49 |
50 | def __init__(self, in_channels=768, out_channels=400, dropout_ratio=0.5, **kwargs):
51 | super().__init__()
52 | self.dropout_ratio = dropout_ratio
53 | self.in_channels = in_channels
54 | self.out_channels = out_channels
55 | if self.dropout_ratio != 0:
56 | self.dropout = nn.Dropout(p=self.dropout_ratio)
57 | else:
58 | self.dropout = None
59 | self.fc = nn.Conv3d(self.in_channels, self.out_channels, (1, 1, 1))
60 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
61 |
62 | def forward(self, x, rois=None):
63 | x = self.dropout(x)
64 | x = self.avg_pool(x)
65 | out = self.fc(x)
66 | return out
67 |
68 |
69 | class IQAHead(nn.Module):
70 | """MLP Regression Head for IQA.
71 | Args:
72 | in_channels: input channels for MLP
73 | hidden_channels: hidden channels for MLP
74 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5)
75 | """
76 |
77 | def __init__(
78 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, **kwargs
79 | ):
80 | super().__init__()
81 | self.dropout_ratio = dropout_ratio
82 | self.in_channels = in_channels
83 | self.hidden_channels = hidden_channels
84 | if self.dropout_ratio != 0:
85 | self.dropout = nn.Dropout(p=self.dropout_ratio)
86 | else:
87 | self.dropout = None
88 | self.fc_hid = nn.Linear(self.in_channels, self.hidden_channels)
89 | self.fc_last = nn.Linear(self.hidden_channels, 1)
90 | self.gelu = nn.GELU()
91 |
92 | def forward(self, x):
93 | x = self.dropout(x)
94 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x))))
95 | return qlt_score
96 |
--------------------------------------------------------------------------------
/third_party/gmflow/models/matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from .geometry import coords_grid, generate_window_grid, normalize_coords
5 |
6 |
7 | def global_correlation_softmax(feature0, feature1,
8 | pred_bidir_flow=False,
9 | ):
10 | # global correlation
11 | b, c, h, w = feature0.shape
12 | feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
13 | feature1 = feature1.view(b, c, -1) # [B, C, H*W]
14 |
15 | correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
16 |
17 | # flow from softmax
18 | init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
19 | grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
20 |
21 | correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
22 |
23 | if pred_bidir_flow:
24 | correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
25 | init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
26 | grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
27 | b = b * 2
28 |
29 | prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
30 |
31 | correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
32 |
33 | # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
34 | flow = correspondence - init_grid
35 |
36 | return flow, prob
37 |
38 |
39 | def local_correlation_softmax(feature0, feature1, local_radius,
40 | padding_mode='zeros',
41 | ):
42 | b, c, h, w = feature0.size()
43 | coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
44 | coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
45 |
46 | local_h = 2 * local_radius + 1
47 | local_w = 2 * local_radius + 1
48 |
49 | window_grid = generate_window_grid(-local_radius, local_radius,
50 | -local_radius, local_radius,
51 | local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
52 | window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
53 | sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
54 |
55 | sample_coords_softmax = sample_coords
56 |
57 | # exclude coords that are out of image space
58 | valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
59 | valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
60 |
61 | valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
62 |
63 | # normalize coordinates to [-1, 1]
64 | sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
65 | window_feature = F.grid_sample(feature1, sample_coords_norm,
66 | padding_mode=padding_mode, align_corners=True
67 | ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
68 | feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
69 |
70 | corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
71 |
72 | # mask invalid locations
73 | corr[~valid] = -1e9
74 |
75 | prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
76 |
77 | correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
78 | b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
79 |
80 | flow = correspondence - coords_init
81 | match_prob = prob
82 |
83 | return flow, match_prob
--------------------------------------------------------------------------------
/doc/leaderboard.md:
--------------------------------------------------------------------------------
1 | # 📈 Leaderboard
2 | The globally best value for each metric is marked with an asterisk (*).
3 | 
4 | | Category | Method | Frames Quality⬆ | Semantic Consistency⬆ | Object Consistency⬆ | Video Quality⬆ | Frames Text Alignment⬆ | Frames Pick Score⬆ | Video Text Alignment⬆ | Motion Alignment⬆ |
5 | |---|---|---|---|---|---|---|---|---|---|
6 | | Network and Training Paradigm | Tune-A-Video | 5.001 | 0.934 | 0.917 | 0.527 | 27.513 | 20.701 | 0.254 | -5.599 |
7 | | | SimDA | 4.988 | 0.940 | 0.929 | 0.569 | 26.773 | 20.512 | 0.248 | -4.756 |
8 | | | VidToMe | 4.988 | 0.949 | 0.945 | 0.656 | 26.813 | 20.546 | 0.240 | -3.203 |
9 | | | VideoComposer | 4.429 | 0.914 | 0.905 | 0.370 | 28.001 | 20.272 | 0.262 | -8.095 |
10 | | | MotionDirector | 4.984 | 0.940 | 0.951 | 0.617 | 27.845 | 20.923 | 0.262 | -3.088 |
11 | | Attention Feature Injection | Video-P2P | 4.907 | 0.943 | 0.926 | 0.471 | 23.550 | 19.751 | 0.193 | -5.974 |
12 | | | Vid2Vid-Zero | 5.103 | 0.919 | 0.912 | 0.638 | 28.789 | 20.950 | 0.270 | -4.175 |
13 | | | Fate-Zero | 5.036 | 0.951 | 0.952* | 0.704 | 25.065 | 20.707 | 0.225 | -1.439* |
14 | | | TokenFlow | 5.068 | 0.947 | 0.943 | 0.715 | 27.522 | 20.757 | 0.254 | -1.572 |
15 | | | FLATTEN | 4.965 | 0.943 | 0.949 | 0.645 | 27.156 | 20.745 | 0.251 | -1.446 |
16 | | | FRESCO | 5.127 | 0.908 | 0.896 | 0.689 | 25.639 | 20.239 | 0.223 | -5.241 |
17 | | Diffusion Latent Manipulation | Text2Video-Zero | 5.097 | 0.899 | 0.894 | 0.613 | 29.124* | 20.568 | 0.265 | -17.226 |
18 | | | Pix2Video | 5.075 | 0.946 | 0.944 | 0.638 | 28.731 | 21.054* | 0.271* | -2.889 |
19 | | | ControlVideo | 5.404* | 0.959* | 0.948 | 0.674 | 28.551 | 20.961 | 0.261 | -9.396 |
20 | | | Rerender | 5.002 | 0.872 | 0.863 | 0.724* | 27.379 | 20.460 | 0.261 | -4.959 |
21 | | | RAVE | 5.077 | 0.926 | 0.936 | 0.664 | 28.190 | 20.865 | 0.255 | -2.398 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # General
165 | .DS_Store
166 | .AppleDouble
167 | .LSOverride
168 |
169 | # Icon must end with two \r
170 | Icon
171 |
172 | # Thumbnails
173 | ._*
174 |
175 | # Files that might appear in the root of a volume
176 | .DocumentRevisions-V100
177 | .fseventsd
178 | .Spotlight-V100
179 | .TemporaryItems
180 | .Trashes
181 | .VolumeIcon.icns
182 | .com.apple.timemachine.donotpresent
183 |
184 | # Directories potentially created on remote AFP share
185 | .AppleDB
186 | .AppleDesktop
187 | Network Trash Folder
188 | Temporary Items
189 | .apdisk
190 |
--------------------------------------------------------------------------------
/v2vbench/base_evaluator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/26 13:48:08
5 | @Desc :
6 | @Ref :
7 | '''
8 | import logging
9 | from abc import ABC, abstractmethod
10 | from collections import defaultdict
11 | from pathlib import Path
12 | from typing import Tuple, Union
13 |
14 | import torch
15 | from omegaconf import OmegaConf
16 | from tqdm.auto import tqdm
17 |
18 | from .utils import load_video
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class BaseEvaluator(ABC):
24 | def __init__(
25 | self,
26 | index_file: Path,
27 | edit_video_dir: Path,
28 | reference_video_dir: Path,
29 | edit_prompt: str,
30 | device: torch.device,
31 | stride: int = 1,
32 | ):
33 | self.index_file = index_file
34 | self.edit_video_dir = edit_video_dir
35 | self.reference_video_dir = reference_video_dir
36 | self.edit_prompt = edit_prompt
37 | self.device = device
38 | # stride for loading video frames, default is 1 - load all frames
39 | # if stride is 2, load every other frame, if stride is 3, load every third frame, etc.
40 | self.stride = stride
41 |
42 | @property
43 | @abstractmethod
44 |     def range(self) -> Tuple[Union[float, int], Union[float, int]]:
45 |         """
46 |         Returns:
47 |             Tuple[Union[float, int], Union[float, int]]: the worst and best values of the metric
48 | """
49 | raise NotImplementedError
50 |
51 | @property
52 | def direction(self) -> str:
53 | return 'maximize' if self.range[0] < self.range[1] else 'minimize'
54 |
55 | def _sample_data_iter(self):
56 | # for single video, the video_dir is the file path to the video
57 | if self.reference_video_dir is not None:
58 | reference_video_path = self.reference_video_dir
59 | reference_video = load_video(reference_video_path, stride=self.stride)
60 | else:
61 | reference_video_path = None
62 | reference_video = None
63 | edit_video = load_video(self.edit_video_dir, stride=self.stride)
64 | yield {
65 | # for single video, the reference video may not be available
66 | # use the edit_video name as video_id
67 | 'video_id': self.edit_video_dir.with_suffix('').name,
68 | # 'prompt': sample['prompt'],
69 | 'reference_video': reference_video,
70 | 'reference_video_path': reference_video_path,
71 | 'edit_id': 0,
72 | 'edit_prompt': self.edit_prompt,
73 | 'edit_video': edit_video,
74 | }
75 |
76 | def _batch_data_iter(self):
77 | config: OmegaConf = OmegaConf.load(self.index_file)
78 | for sample in config['data']:
79 | # reference video is optional
80 | reference_video_path = None
81 | reference_video = None
82 |
83 | if self.reference_video_dir is not None:
84 | reference_video_path = self.reference_video_dir / sample['video_id']
85 | reference_video = load_video(reference_video_path, stride=self.stride)
86 |
87 | # iterate over edit
88 | for edit_id, edit in enumerate(sample['edit']):
89 | edit_video_path = self.edit_video_dir / sample['video_id'] / str(edit_id)
90 | edit_video = load_video(edit_video_path, stride=self.stride)
91 | yield {
92 | 'video_id': sample['video_id'],
93 | # 'prompt': sample['prompt'],
94 | 'reference_video': reference_video,
95 | 'reference_video_path': reference_video_path if reference_video is not None else None,
96 | 'edit_id': edit_id,
97 | 'edit_prompt': edit['prompt'],
98 | 'edit_video': edit_video,
99 | }
100 |
101 | def _data_iter(self):
102 | if self.index_file is not None:
103 | return self._batch_data_iter()
104 | else:
105 | return self._sample_data_iter()
106 |
107 | def preprocess(self, sample):
108 | return sample
109 |
110 | @abstractmethod
111 | def evaluate(self, sample) -> float:
112 | raise NotImplementedError
113 |
114 | def __call__(self):
115 | results = defaultdict(list)
116 | for sample in tqdm(self._data_iter()):
117 | score = self.evaluate(self.preprocess(sample))
118 |
119 | # results['method'].append(self.method)
120 | results['video_id'].append(sample['video_id'])
121 | results['edit_id'].append(sample['edit_id'])
122 | results['score'].append(score)
123 | return results
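124 | 
125 | # Example (hypothetical, for illustration only): a concrete metric subclasses
126 | # BaseEvaluator, implements `range` and `evaluate`, and optionally overrides
127 | # `preprocess`. The class below is a sketch, not part of the benchmark metrics,
128 | # and assumes `load_video` yields a sequence of frames.
129 | #
130 | # class FrameCount(BaseEvaluator):
131 | #     @property
132 | #     def range(self):
133 | #         return 0, float('inf')
134 | #
135 | #     def evaluate(self, sample) -> float:
136 | #         return float(len(list(sample['edit_video'])))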
--------------------------------------------------------------------------------
/third_party/gmflow/models/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .trident_conv import MultiScaleTridentConv
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
8 | ):
9 | super(ResidualBlock, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
12 | dilation=dilation, padding=dilation, stride=stride, bias=False)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | self.norm1 = norm_layer(planes)
18 | self.norm2 = norm_layer(planes)
19 | if not stride == 1 or in_planes != planes:
20 | self.norm3 = norm_layer(planes)
21 |
22 | if stride == 1 and in_planes == planes:
23 | self.downsample = None
24 | else:
25 | self.downsample = nn.Sequential(
26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
27 |
28 | def forward(self, x):
29 | y = x
30 | y = self.relu(self.norm1(self.conv1(y)))
31 | y = self.relu(self.norm2(self.conv2(y)))
32 |
33 | if self.downsample is not None:
34 | x = self.downsample(x)
35 |
36 | return self.relu(x + y)
37 |
38 |
39 | class CNNEncoder(nn.Module):
40 | def __init__(self, output_dim=128,
41 | norm_layer=nn.InstanceNorm2d,
42 | num_output_scales=1,
43 | **kwargs,
44 | ):
45 | super(CNNEncoder, self).__init__()
46 | self.num_branch = num_output_scales
47 |
48 | feature_dims = [64, 96, 128]
49 |
50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
51 | self.norm1 = norm_layer(feature_dims[0])
52 | self.relu1 = nn.ReLU(inplace=True)
53 |
54 | self.in_planes = feature_dims[0]
55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
57 |
58 | # highest resolution 1/4 or 1/8
59 | stride = 2 if num_output_scales == 1 else 1
60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride,
61 | norm_layer=norm_layer,
62 | ) # 1/4 or 1/8
63 |
64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
65 |
66 | if self.num_branch > 1:
67 | if self.num_branch == 4:
68 | strides = (1, 2, 4, 8)
69 | elif self.num_branch == 3:
70 | strides = (1, 2, 4)
71 | elif self.num_branch == 2:
72 | strides = (1, 2)
73 | else:
74 | raise ValueError
75 |
76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
77 | kernel_size=3,
78 | strides=strides,
79 | paddings=1,
80 | num_branch=self.num_branch,
81 | )
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
87 | if m.weight is not None:
88 | nn.init.constant_(m.weight, 1)
89 | if m.bias is not None:
90 | nn.init.constant_(m.bias, 0)
91 |
92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
95 |
96 | layers = (layer1, layer2)
97 |
98 | self.in_planes = dim
99 | return nn.Sequential(*layers)
100 |
101 | def forward(self, x):
102 | x = self.conv1(x)
103 | x = self.norm1(x)
104 | x = self.relu1(x)
105 |
106 | x = self.layer1(x) # 1/2
107 | x = self.layer2(x) # 1/4
108 | x = self.layer3(x) # 1/8 or 1/4
109 |
110 | x = self.conv2(x)
111 |
112 | if self.num_branch > 1:
113 | out = self.trident_conv([x] * self.num_branch) # high to low res
114 | else:
115 | out = [x]
116 |
117 | return out
--------------------------------------------------------------------------------
/v2vbench/motion_alignment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/27 12:30:53
5 | @Desc :
6 | @Ref :
7 | '''
8 | import logging
9 | from pathlib import Path
10 | from typing import List
11 |
12 | import numpy as np
13 | import torch
14 | from omegaconf import OmegaConf
15 |
16 | from third_party.gmflow.models import GMFlow
17 | from third_party.gmflow.utils import InputPadder, write_flow
18 |
19 | from .base_evaluator import BaseEvaluator
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | class MotionAlignment(BaseEvaluator):
24 | pretrained_config_file = (
25 | Path(__file__) / '../../third_party/gmflow/config.yaml').resolve().absolute()
26 | pretrained_checkpoint = (
27 | Path(__file__) / '../../checkpoints/gmflow/gmflow_sintel-0c07dcb3.pth').resolve().absolute()
28 |
29 | def __init__(
30 | self,
31 | index_file: Path,
32 | edit_video_dir: Path,
33 | reference_video_dir: Path,
34 | edit_prompt: str,
35 | device: torch.device,
36 | pretrained_config_file: str = None,
37 | pretrained_checkpoint: str = None,
38 | cache_flow: bool = False,
39 | ):
40 | super().__init__(index_file, edit_video_dir, reference_video_dir, edit_prompt, device)
41 | pretrained_config_file = pretrained_config_file or self.pretrained_config_file
42 | pretrained_checkpoint = pretrained_checkpoint or self.pretrained_checkpoint
43 |
44 |         logger.debug(f"Loading model {pretrained_checkpoint}")
45 | config = OmegaConf.to_container(OmegaConf.load(pretrained_config_file))
46 | self.padder = InputPadder(**config['data'])
47 | self.model = GMFlow(**config['model'])
48 | self.model.to(self.device)
49 | self.model.load_state_dict(torch.load(pretrained_checkpoint, map_location=self.device)['model'])
50 | self.model.eval()
51 | logger.debug(f"Model {self.model.__class__.__name__} loaded")
52 |
53 | self.cache_flow = cache_flow
54 |
55 |     @property
56 |     def range(self):
57 |         return -1 * np.inf, 0
58 | def get_flow_file(self, reference_video_path: Path):
59 | if reference_video_path is None:
60 | return Path('')
61 |
62 | video_id = reference_video_path.stem
63 | flow_dir = reference_video_path.parent.parent / 'flow'
64 | return flow_dir / f'{video_id}.pt'
65 |
66 | def preprocess(self, sample):
67 | # load edit frames
68 | edit_frames = []
69 | for i, frame in enumerate(sample['edit_video']):
70 | frame = torch.from_numpy(np.array(frame).astype(np.uint8)).permute(2, 0, 1).float().to(self.device)
71 | frame = self.padder.pad(frame)[0][None] # (B, C, H, W)
72 | edit_frames.append(frame)
73 |
74 | reference_frames = []
75 | reference_flow = None
76 | reference_flow_file = self.get_flow_file(sample['reference_video_path'])
77 |
78 | # try to load cached flow, only for batch evaluation
79 | try:
80 | logger.debug(f"Loading cached flow from {reference_flow_file}")
81 | reference_flow = torch.load(reference_flow_file, map_location=self.device)
82 | # cached flow not found, compute flow
83 |         except (FileNotFoundError, AttributeError):
84 |             # keep reference_flow_file so the computed flow can be cached below
85 |             logger.debug("Cached flow not found, loading reference frames instead.")
86 | for i, frame in enumerate(sample['reference_video']):
87 | frame = torch.from_numpy(np.array(frame).astype(np.uint8)).permute(2, 0, 1).float().to(self.device)
88 | frame = self.padder.pad(frame)[0][None] # (B, C, H, W)
89 | reference_frames.append(frame)
90 |
91 | return edit_frames, reference_frames, reference_flow_file, reference_flow
92 |
93 | @torch.no_grad()
94 | def extract_flow(self, frames: List[torch.Tensor]):
95 | flows = []
96 |
97 | for i, frame in enumerate(frames):
98 | if i > 0:
99 | flow_pr = self.model(
100 | prev_frame, frame,
101 | attn_splits_list=[2],
102 | corr_radius_list=[-1],
103 | prop_radius_list=[-1],
104 | )['flow_preds'][-1][0] # (2, H, W)
105 | flows.append(flow_pr)
106 | if i == 1:
107 | logger.debug(f"Flow shape: {flow_pr.shape}")
108 | prev_frame = frame
109 |
110 | return torch.stack(flows, dim=0) # (T-1, 2, H, W)
111 |
112 | @torch.no_grad()
113 | def evaluate(self, args) -> float:
114 | edit_frames, reference_frames, reference_flow_file, reference_flow = args
115 | edit_flow = self.extract_flow(edit_frames)
116 |
117 | # calculate flow for reference video if not cached
118 |         if reference_flow is None:
119 | reference_flow = self.extract_flow(reference_frames)
120 | if self.cache_flow:
121 | write_flow(reference_flow, reference_flow_file)
122 | logger.debug(f"Flow saved to {reference_flow_file}")
123 |
124 | # calculate alignment score: EPE
125 | score = torch.sum((edit_flow - reference_flow) ** 2, dim=1).sqrt().mean().cpu().item()
126 | return - score
127 |
--------------------------------------------------------------------------------
/third_party/viclip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | import subprocess
5 | from functools import lru_cache
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | # def default_bpe():
11 | # tokenizer_file = os.path.join('.checkpoints', "ViCLIP/bpe_simple_vocab_16e6.txt.gz")
12 | # if not os.path.exists(tokenizer_file):
13 | # print(f'Downloading ViCLIP tokenizer to {tokenizer_file}')
14 | # wget_command = ['wget', 'https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz', '-P', os.path.dirname(tokenizer_file)]
15 | # subprocess.run(wget_command)
16 | # return tokenizer_file
17 |
18 |
19 | @lru_cache()
20 | def bytes_to_unicode():
21 | """
22 | Returns list of utf-8 byte and a corresponding list of unicode strings.
23 | The reversible bpe codes work on unicode strings.
24 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
25 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
26 |     This is a significant percentage of your normal, say, 32K bpe vocab.
27 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
28 | And avoids mapping to whitespace/control characters the bpe code barfs on.
29 | """
30 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
31 | cs = bs[:]
32 | n = 0
33 | for b in range(2**8):
34 | if b not in bs:
35 | bs.append(b)
36 | cs.append(2**8+n)
37 | n += 1
38 | cs = [chr(n) for n in cs]
39 | return dict(zip(bs, cs))
40 |
41 |
42 | def get_pairs(word):
43 | """Return set of symbol pairs in a word.
44 | Word is represented as tuple of symbols (symbols being variable-length strings).
45 | """
46 | pairs = set()
47 | prev_char = word[0]
48 | for char in word[1:]:
49 | pairs.add((prev_char, char))
50 | prev_char = char
51 | return pairs
52 |
53 |
54 | def basic_clean(text):
55 | text = ftfy.fix_text(text)
56 | text = html.unescape(html.unescape(text))
57 | return text.strip()
58 |
59 |
60 | def whitespace_clean(text):
61 | text = re.sub(r'\s+', ' ', text)
62 | text = text.strip()
63 | return text
64 |
65 |
66 | class SimpleTokenizer(object):
67 | def __init__(self, bpe_path: str):
68 | self.byte_encoder = bytes_to_unicode()
69 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
70 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
71 | merges = merges[1:49152-256-2+1]
72 | merges = [tuple(merge.split()) for merge in merges]
73 | vocab = list(bytes_to_unicode().values())
74 |         vocab = vocab + [v+'</w>' for v in vocab]
75 | for merge in merges:
76 | vocab.append(''.join(merge))
77 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
78 | self.encoder = dict(zip(vocab, range(len(vocab))))
79 | self.decoder = {v: k for k, v in self.encoder.items()}
80 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
83 |
84 | def bpe(self, token):
85 | if token in self.cache:
86 | return self.cache[token]
87 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
88 | pairs = get_pairs(word)
89 |
90 | if not pairs:
91 |             return token+'</w>'
92 |
93 | while True:
94 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
95 | if bigram not in self.bpe_ranks:
96 | break
97 | first, second = bigram
98 | new_word = []
99 | i = 0
100 | while i < len(word):
101 | try:
102 | j = word.index(first, i)
103 | new_word.extend(word[i:j])
104 | i = j
105 |                 except ValueError:
106 | new_word.extend(word[i:])
107 | break
108 |
109 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
110 | new_word.append(first+second)
111 | i += 2
112 | else:
113 | new_word.append(word[i])
114 | i += 1
115 | new_word = tuple(new_word)
116 | word = new_word
117 | if len(word) == 1:
118 | break
119 | else:
120 | pairs = get_pairs(word)
121 | word = ' '.join(word)
122 | self.cache[token] = word
123 | return word
124 |
125 | def encode(self, text):
126 | bpe_tokens = []
127 | text = whitespace_clean(basic_clean(text)).lower()
128 | for token in re.findall(self.pat, text):
129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
131 | return bpe_tokens
132 |
133 | def decode(self, tokens):
134 | text = ''.join([self.decoder[token] for token in tokens])
135 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
136 | return text
--------------------------------------------------------------------------------
/doc/README.md:
--------------------------------------------------------------------------------
1 | # 📈 V2VBench
2 |
3 | V2VBench is a comprehensive benchmark designed to evaluate video editing methods. It consists of:
4 | - 50 standardized videos across 5 categories, and
5 | - 3 editing prompts per video, encompassing 4 editing tasks: [Huggingface Datasets](https://huggingface.co/datasets/Wenhao-Sun/V2VBench)
6 | - 8 evaluation metrics to assess the quality of edited videos: [Evaluation Metrics](doc/README.md)
7 |
8 | For detailed information, please refer to the accompanying paper.
9 |
10 |
11 | ## Installation
12 | Clone the repository:
13 | ```bash
14 | git clone https://github.com/wenhao728/awesome-diffusion-v2v.git
15 | ```
16 |
17 | (Optional) We recommend using a virtual environment to manage the dependencies. You can refer to [virtualenv](https://virtualenv.pypa.io/en/latest/user_guide.html) or [conda](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#activating-an-environment) to create a virtual environment.
18 |
19 | Install the required packages:
20 | ```bash
21 | pip install -r requirements.txt
22 | ```
23 |
24 | Download the pre-trained models:
25 | ```bash
26 | sh doc/install.sh
27 | ```
28 | The `./checkpoints` directory should contain the following files:
29 | ```diff
30 | checkpoints
31 | + ├── DOVER
32 | + │ └── DOVER.pth
33 | + ├── gmflow
34 | + │ └── gmflow_sintel-0c07dcb3.pth
35 | + └── ViCLIP
36 | + ├── bpe_simple_vocab_16e6.txt.gz
37 | + └── ViClip-InternVid-10M-FLT.pth
38 | ```
39 |
40 | ## Run Evaluation
41 | ### Video format
42 | Videos can be provided either as "video files" (e.g. `video_name.mp4`, `video_name.gif`) or "directories containing video frame images" with filenames indicating the frame number (e.g. `video_name/01.jpg`, `video_name/02.jpg`, ...).
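
For example, either of the following layouts is accepted (the names and extensions here are only illustrative):
```
videos
├── video_name.mp4        # a single video file, or
└── video_name            # a directory of frame images
    ├── 01.jpg
    ├── 02.jpg
    └── ...
```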
43 |
44 | ### Single Video Evaluation
45 | Run the following snippet to evaluate one edited video:
46 | ```python
48 | import logging
49 | import sys
49 | from v2vbench import EvaluatorWrapper
50 |
51 | logging.basicConfig(
52 | level=logging.INFO, # Change it to logging.DEBUG if you want to troubleshoot
53 | format="%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)d)",
54 | handlers=[logging.StreamHandler(sys.stdout),]
55 | )
56 |
57 | evaluation = EvaluatorWrapper(metrics='all') # 'all' for all metrics
58 | # print(EvaluatorWrapper.all_metrics) # list all available metrics
59 |
60 | results = evaluation.evaluate(
61 | edit_video='/path/to/edited_video',
62 | reference_video='/path/to/source_video',
63 | edit_prompt='',
64 | output_dir='/path/to/save/results',
65 | )
66 | ```
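
The returned `results` maps each metric name to per-sample records (`video_id`, `edit_id`, `score`). As a minimal sketch, the per-metric averages can be printed with `pandas` (already a dependency); this mirrors the `summary.csv` written when an `output_dir` is provided:
```python
import pandas as pd

# average each metric over the evaluated samples (a single video here)
for metric, record in results.items():
    print(metric, pd.DataFrame(record)['score'].mean())
```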
67 |
68 | ### Batch Evaluation
69 | Check the [Prepare Data](#optional-prepare-data) section to prepare the data for batch evaluation.
70 | Then, simply run the following snippet to evaluate a batch of edited videos:
71 | ```python
72 | import logging
73 | import sys
74 | from v2vbench import EvaluatorWrapper
75 |
76 | logging.basicConfig(
77 | level=logging.INFO, # Change it to logging.DEBUG if you want to troubleshoot
78 | format="%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)d)",
79 | handlers=[logging.StreamHandler(sys.stdout),]
80 | )
81 |
82 | evaluation = EvaluatorWrapper(metrics='all') # 'all' for all metrics
83 | # print(EvaluatorWrapper.all_metrics) # list all available metrics
84 |
85 | results = evaluation.evaluate(
86 | edit_video='/path/to/edited_videos',
87 | reference_video='/path/to/source_videos',
88 | index_file='/path/to/config.yaml',
89 | output_dir='/path/to/save/results',
90 | # it is recommended to cache the flow of source videos for motion_alignment to avoid redundant computation
91 | evaluator_kwargs={'motion_alignment': {'cache_flow': True}},
92 | )
93 | ```
94 |
95 | ## (Optional) Prepare Data
96 | To evaluate a batch of editing results, the data should be organized in specified formats.
97 | You can download the V2VBench dataset from [Huggingface Datasets](https://huggingface.co/datasets/Wenhao-Sun/V2VBench) or use your customized data with the following format.
98 |
99 | ### Configuration
100 | The configuration file is a YAML file, and each source video is one entry in its `data` list. The following is an example of one entry:
101 | ```yaml
102 | video_id: hike # source video id
103 | edit:
104 | - prompt: a superman with a backpack hikes through the desert # edit prompt 0
105 | - prompt: a man with a backpack hikes through the dolomites, pixel art # edit prompt 1
106 | # ... more edit prompts
107 | ```
108 | ### Source Videos
109 | The source videos for reference should be named according to the `video_id` specified in the configuration file.
110 |
111 | ### Edited Videos
112 | For evaluation, the edited videos of each source video should be placed in a directory named after its `video_id`, and each edited video should be named after the index of its edit prompt, as illustrated below.
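
For example, with the `hike` entry above (two edit prompts), the reference and edited videos could be organized as follows; the frame-image layout is shown here, and the options from the [Video format](#video-format) section apply as well:
```
source_videos
└── hike                  # reference video, named by video_id
    ├── 01.jpg
    └── ...

edited_videos
└── hike
    ├── 0                 # edited video for edit prompt 0
    │   ├── 01.jpg
    │   └── ...
    └── 1                 # edited video for edit prompt 1
        └── ...
```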
113 |
114 |
115 | ## Shoutouts
116 | This repository is inspired by the following open-source projects:
117 | - [transformers](https://github.com/huggingface/transformers) by [Huggingface](https://github.com/huggingface)
118 | - [LAION Aesthetic Predictor](https://github.com/LAION-AI/aesthetic-predictor) (and its [out-of-the-box implementation](https://github.com/shunk031/simple-aesthetics-predictor))
119 | - [CLIP](https://github.com/openai/CLIP)
120 | - [DINO-v2](https://github.com/facebookresearch/dinov2)
121 | - [DOVER](https://github.com/VQAssessment/DOVER)
122 | - [GMFlow](https://github.com/haofeixu/gmflow)
123 | - [ViCLIP](https://github.com/OpenGVLab/InternVideo)
124 | - [PickScore](https://github.com/yuvalkirstain/PickScore)
125 | - [VBench](https://github.com/Vchitect/VBench)
126 |
127 | We extend our gratitude to the authors for their exceptional work and open-source contributions.
--------------------------------------------------------------------------------
/third_party/gmflow/models/gmflow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .backbone import CNNEncoder
6 | from .transformer import FeatureTransformer, FeatureFlowAttention
7 | from .matching import global_correlation_softmax, local_correlation_softmax
8 | from .geometry import flow_warp
9 | from .utils import normalize_img, feature_add_position
10 |
11 |
12 | class GMFlow(nn.Module):
13 | def __init__(self,
14 | num_scales=1,
15 | upsample_factor=8,
16 | feature_channels=128,
17 | attention_type='swin',
18 | num_transformer_layers=6,
19 | ffn_dim_expansion=4,
20 | num_head=1,
21 | **kwargs,
22 | ):
23 | super(GMFlow, self).__init__()
24 |
25 | self.num_scales = num_scales
26 | self.feature_channels = feature_channels
27 | self.upsample_factor = upsample_factor
28 | self.attention_type = attention_type
29 | self.num_transformer_layers = num_transformer_layers
30 |
31 | # CNN backbone
32 | self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
33 |
34 | # Transformer
35 | self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
36 | d_model=feature_channels,
37 | nhead=num_head,
38 | attention_type=attention_type,
39 | ffn_dim_expansion=ffn_dim_expansion,
40 | )
41 |
42 | # flow propagation with self-attn
43 | self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)
44 |
45 | # convex upsampling: concat feature0 and flow as input
46 | self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
49 |
50 | def extract_feature(self, img0, img1):
51 | concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
52 | features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
53 |
54 | # reverse: resolution from low to high
55 | features = features[::-1]
56 |
57 | feature0, feature1 = [], []
58 |
59 | for i in range(len(features)):
60 | feature = features[i]
61 | chunks = torch.chunk(feature, 2, 0) # tuple
62 | feature0.append(chunks[0])
63 | feature1.append(chunks[1])
64 |
65 | return feature0, feature1
66 |
67 | def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
68 | ):
69 | if bilinear:
70 | up_flow = F.interpolate(flow, scale_factor=upsample_factor,
71 | mode='bilinear', align_corners=True) * upsample_factor
72 |
73 | else:
74 | # convex upsampling
75 | concat = torch.cat((flow, feature), dim=1)
76 |
77 | mask = self.upsampler(concat)
78 | b, flow_channel, h, w = flow.shape
79 | mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
80 | mask = torch.softmax(mask, dim=2)
81 |
82 | up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
83 | up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
84 |
85 | up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
86 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
87 | up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h,
88 | self.upsample_factor * w) # [B, 2, K*H, K*W]
89 |
90 | return up_flow
91 |
92 | def forward(self, img0, img1,
93 | attn_splits_list=None,
94 | corr_radius_list=None,
95 | prop_radius_list=None,
96 | pred_bidir_flow=False,
97 | **kwargs,
98 | ):
99 |
100 | results_dict = {}
101 | flow_preds = []
102 |
103 | img0, img1 = normalize_img(img0, img1) # [B, 3, H, W]
104 |
105 | # resolution low to high
106 | feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
107 |
108 | flow = None
109 |
110 | assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
111 |
112 | for scale_idx in range(self.num_scales):
113 | feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
114 |
115 | if pred_bidir_flow and scale_idx > 0:
116 | # predicting bidirectional flow with refinement
117 | feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
118 |
119 | upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
120 |
121 | if scale_idx > 0:
122 | flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
123 |
124 | if flow is not None:
125 | flow = flow.detach()
126 | feature1 = flow_warp(feature1, flow) # [B, C, H, W]
127 |
128 | attn_splits = attn_splits_list[scale_idx]
129 | corr_radius = corr_radius_list[scale_idx]
130 | prop_radius = prop_radius_list[scale_idx]
131 |
132 | # add position to features
133 | feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
134 |
135 | # Transformer
136 | feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
137 |
138 | # correlation and softmax
139 | if corr_radius == -1: # global matching
140 | flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
141 | else: # local matching
142 | flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
143 |
144 | # flow or residual flow
145 | flow = flow + flow_pred if flow is not None else flow_pred
146 |
147 |             # upsample to the original resolution for supervision
148 | if self.training: # only need to upsample intermediate flow predictions at training time
149 | flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor)
150 | flow_preds.append(flow_bilinear)
151 |
152 | # flow propagation with self-attn
153 | if pred_bidir_flow and scale_idx == 0:
154 | feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation
155 | flow = self.feature_flow_attn(feature0, flow.detach(),
156 | local_window_attn=prop_radius > 0,
157 | local_window_radius=prop_radius)
158 |
159 | # bilinear upsampling at training time except the last one
160 | if self.training and scale_idx < self.num_scales - 1:
161 | flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor)
162 | flow_preds.append(flow_up)
163 |
164 | if scale_idx == self.num_scales - 1:
165 | flow_up = self.upsample_flow(flow, feature0)
166 | flow_preds.append(flow_up)
167 |
168 | results_dict.update({'flow_preds': flow_preds})
169 |
170 | return results_dict
--------------------------------------------------------------------------------
/v2vbench/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @Created : 2024/05/07 13:48:51
5 | @Desc :
6 | @Ref :
7 | '''
8 | import gc
9 | import importlib
10 | import logging
11 | import sys
12 | import warnings
13 | from pathlib import Path
14 | from typing import Dict, List, Optional, Union
15 |
16 | import pandas as pd
17 | import torch
18 | from tqdm.auto import tqdm
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | cur_dir = Path('..').resolve().absolute()
23 | if str(cur_dir) not in sys.path:
24 | sys.path.append(str(cur_dir))
25 |
26 |
27 | class EvaluatorWrapper:
28 | _quality_metrics = ['aesthetic_score', 'clip_consistency', 'dino_consistency', 'dover_score']
29 | _alignment_metrics = ['clip_text_alignment', 'pick_score', 'viclip_text_alignment', 'motion_alignment']
30 | all_metrics = _quality_metrics + _alignment_metrics
31 |
32 | def __init__(
33 | self,
34 | device: Optional[torch.device] = None,
35 | metrics: Union[str, List[str]] = 'all',
36 | ):
37 | self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38 | self.metrics = metrics
39 |
40 | def _check_args(self, edit_video, reference_video, index_file, edit_prompt):
41 | if isinstance(self.metrics, str):
42 | if self.metrics == 'all':
43 | self.metrics = self.all_metrics
44 | else:
45 | self.metrics = self.metrics.split(',')
46 | not_supported_metrics = set(self.metrics) - set(self.all_metrics)
47 | self.metrics = list(set(self.metrics) - not_supported_metrics)
48 | if not_supported_metrics:
49 | warnings.warn(f'Unsupported metrics: {not_supported_metrics}')
50 | if self.metrics:
51 | logger.info(f'Using metrics: {self.metrics}')
52 | else:
53 | raise ValueError('No supported metrics provided')
54 |
55 | if not edit_video.exists():
56 | raise FileNotFoundError(f'Edit video: {edit_video} does not exist')
57 | if reference_video and not reference_video.exists():
58 | raise FileNotFoundError(f'Reference video: {reference_video} does not exist')
59 | if index_file:
60 | if not index_file.exists():
61 | raise FileNotFoundError(f'Index file: {index_file} does not exist')
62 | else:
63 | if not edit_prompt:
64 | raise ValueError(
65 | 'Edit prompt (for single video evaluation) OR index file (for batch evaluation) is required')
66 |
67 | def merge_results(self, results: Dict[str, Dict[str, List]]) -> pd.DataFrame:
68 | merged_results_df = None
69 | for metric, result in results.items():
70 | result_df = pd.DataFrame(result)
71 | result_df.rename(columns={'score': metric}, inplace=True)
72 | if merged_results_df is None:
73 | merged_results_df = result_df
74 | else:
75 | merged_results_df = pd.merge(merged_results_df, result_df, on=['video_id', 'edit_id'])
76 | return merged_results_df
77 |
78 | def evaluate(
79 | self,
80 | edit_video: Path,
81 | reference_video: Optional[Path] = None,
82 | index_file: Optional[Path] = None,
83 | edit_prompt: Optional[str] = None,
84 | output_dir: Optional[Path] = None,
85 | save_results: bool = True,
86 | save_summary: bool = True,
87 | evaluator_kwargs: Optional[Dict[str, Dict]] = None,
88 | ) -> Dict[str, Dict[str, List]]:
89 | """Evaluate the editing video(s) using the specified metrics
90 |
91 | Args:
92 | edit_video (Path): Path to the editing video (single video) or directory containing editing videos (batch).
93 | reference_video (Optional[Path], optional): Path to the reference video (single video) or directory
94 | containing reference videos (batch). Defaults to None.
95 | index_file (Optional[Path], optional): Index file containing metadata for batch evaluation. Defaults to
96 | None.
97 | edit_prompt (Optional[str], optional): Edit prompt for single video evaluation. Defaults to None.
98 | output_dir (Optional[Path], optional): Directory to save the results and summary files. Defaults to None.
99 | save_results (bool, optional): Save the results in a CSV file. Only applicable when output_dir is provided.
100 | Defaults to True.
101 | save_summary (bool, optional): Save the summary in a CSV file. Only applicable when output_dir is provided.
102 | Defaults to True.
103 | evaluator_kwargs (Optional[Dict[str, Dict]], optional): Additional keyword arguments for the evaluators.
104 | Defaults to None.
105 |
106 | Returns:
107 | Dict[str, Dict[str, List]]: Results for each metric for each sample, formatted as below:
108 | {
109 | 'metric1_name': {
110 | 'video_id': [video_id1, video_id2, ...],
111 |                     'edit_id': [edit_id1, edit_id2, ...],
112 | 'score': [score1, score2, ...]
113 | },
114 | ...
115 | }
116 | """
117 | reference_video = Path(reference_video).resolve() if reference_video else None
118 | edit_video = Path(edit_video).resolve()
119 | index_file = Path(index_file).resolve() if index_file else None
120 | edit_prompt = edit_prompt
121 | self._check_args(edit_video, reference_video, index_file, edit_prompt)
122 |
123 | # evaluate results for each sample
124 | results = {}
125 | for metric in tqdm(self.metrics):
126 | try:
127 | module = importlib.import_module(f'.{metric}', package='v2vbench')
128 | evaluator_cls_name = ''.join([w[0].upper() + w[1:] for w in metric.split('_')])
129 | evaluator_cls = getattr(module, evaluator_cls_name)
130 | kwargs = evaluator_kwargs[metric] if evaluator_kwargs and metric in evaluator_kwargs else {}
131 |
132 | evaluator = evaluator_cls(
133 | index_file=index_file,
134 | edit_video_dir=edit_video,
135 | reference_video_dir=reference_video,
136 | edit_prompt=edit_prompt,
137 | device=self.device,
138 | **kwargs,
139 | )
140 | except Exception as e:
141 |                 raise NotImplementedError(f'Failed to initialize {metric} metric: {e}')
142 |
143 | result = evaluator()
144 | results[metric] = result
145 |
146 | # release memory after evaluating each metric
147 | del evaluator
148 | gc.collect()
149 | if self.device.type == 'cuda':
150 | torch.cuda.empty_cache()
151 |
152 | # merge the results into a single dataframe
153 | merged_results_df = self.merge_results(results)
154 | # aggregate into a single-row summary
155 | summary_df = merged_results_df[self.metrics].mean(axis=0).to_frame().T
156 |         logger.info(reference_video.name if reference_video is not None else edit_video.name)
157 | logger.info(summary_df)
158 |
159 | if output_dir is not None and (save_results or save_summary):
160 | output_dir = Path(output_dir).resolve()
161 | output_dir.mkdir(parents=True, exist_ok=True)
162 |
163 | if save_results:
164 | merged_results_df.to_csv(output_dir / 'results.csv', index=False)
165 | logger.info(f'Results saved at: {output_dir / "results.csv"}')
166 | if save_summary:
167 | summary_df.to_csv(output_dir / 'summary.csv', index=False)
168 | logger.info(f'Summary saved at: {output_dir / "summary.csv"}')
169 |
170 | return results
--------------------------------------------------------------------------------
/third_party/viclip/viclip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | import torch
5 | from einops import rearrange
6 | from torch import nn
7 | import math
8 |
9 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
10 | from .viclip_vision import clip_joint_l14
11 | from .viclip_text import clip_text_l14
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class ViCLIP(nn.Module):
17 | """docstring for ViCLIP"""
18 |
19 | def __init__(self, tokenizer=None, pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"), freeze_text=True):
20 | super(ViCLIP, self).__init__()
21 | if tokenizer:
22 | self.tokenizer = tokenizer
23 | else:
24 | self.tokenizer = _Tokenizer()
25 | self.max_txt_l = 32
26 |
27 | self.vision_encoder_name = 'vit_l14'
28 |
29 | self.vision_encoder_pretrained = False
30 | self.inputs_image_res = 224
31 | self.vision_encoder_kernel_size = 1
32 | self.vision_encoder_center = True
33 | self.video_input_num_frames = 8
34 | self.vision_encoder_drop_path_rate = 0.1
35 | self.vision_encoder_checkpoint_num = 24
36 | self.is_pretrain = pretrain
37 | self.vision_width = 1024
38 | self.text_width = 768
39 | self.embed_dim = 768
40 | self.masking_prob = 0.9
41 |
42 | self.text_encoder_name = 'vit_l14'
43 | self.text_encoder_pretrained = False #'bert-base-uncased'
44 | self.text_encoder_d_model = 768
45 |
46 | self.text_encoder_vocab_size = 49408
47 |
48 |
49 | # create modules.
50 | self.vision_encoder = self.build_vision_encoder()
51 | self.text_encoder = self.build_text_encoder()
52 |
53 | self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0)
54 | self.temp_min = 1 / 100.0
55 |
56 | if pretrain:
57 | logger.info(f"Load pretrained weights from {pretrain}")
58 | state_dict = torch.load(pretrain, map_location='cpu')['model']
59 | self.load_state_dict(state_dict)
60 |
61 | # Freeze weights
62 | if freeze_text:
63 | self.freeze_text()
64 |
65 |
66 |
67 | def freeze_text(self):
68 | """freeze text encoder"""
69 | for p in self.text_encoder.parameters():
70 | p.requires_grad = False
71 |
72 | def no_weight_decay(self):
73 | ret = {"temp"}
74 | ret.update(
75 | {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
76 | )
77 | ret.update(
78 | {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
79 | )
80 |
81 | return ret
82 |
83 | def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False):
84 | """forward and calculate loss.
85 |
86 | Args:
87 | image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
88 | text (dict): TODO
89 | idx (torch.Tensor): TODO
90 |
91 | Returns: TODO
92 |
93 | """
94 | self.clip_contrastive_temperature()
95 |
96 | vision_embeds = self.encode_vision(image)
97 | text_embeds = self.encode_text(raw_text)
98 | if return_sims:
99 | sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \
100 | torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1)
101 | return sims
102 |
103 | # calculate loss
104 |
105 | ## VTC loss
106 | loss_vtc = self.clip_loss.vtc_loss(
107 | vision_embeds, text_embeds, idx, self.temp, all_gather=True
108 | )
109 |
110 | return dict(
111 | loss_vtc=loss_vtc,
112 | )
113 |
114 | def encode_vision(self, image, test=False):
115 | """encode image / videos as features.
116 |
117 | Args:
118 | image (torch.Tensor): The input images.
119 | test (bool): Whether testing.
120 |
121 | Returns: tuple.
122 | - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C].
123 | - pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C].
124 |
125 | """
126 | if image.ndim == 5:
127 | image = image.permute(0, 2, 1, 3, 4).contiguous()
128 | else:
129 | image = image.unsqueeze(2)
130 |
131 | if not test and self.masking_prob > 0.0:
132 | return self.vision_encoder(
133 | image, masking_prob=self.masking_prob
134 | )
135 |
136 | return self.vision_encoder(image)
137 |
138 | def encode_text(self, text):
139 | """encode text.
140 | Args:
141 | text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
142 | - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
143 | - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
144 | - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
145 | Returns: tuple.
146 | - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
147 | - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
148 |
149 | """
150 | device = next(self.text_encoder.parameters()).device
151 | text = self.text_encoder.tokenize(
152 | text, context_length=self.max_txt_l
153 | ).to(device)
154 | text_embeds = self.text_encoder(text)
155 | return text_embeds
156 |
157 | @torch.no_grad()
158 | def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
159 | """Seems only used during pre-training"""
160 | self.temp.clamp_(min=self.temp_min)
161 |
162 | def build_vision_encoder(self):
163 | """build vision encoder
164 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
165 |
166 | """
167 | encoder_name = self.vision_encoder_name
168 | if encoder_name != "vit_l14":
169 | raise ValueError(f"Not implemented: {encoder_name}")
170 | vision_encoder = clip_joint_l14(
171 | pretrained=self.vision_encoder_pretrained,
172 | input_resolution=self.inputs_image_res,
173 | kernel_size=self.vision_encoder_kernel_size,
174 | center=self.vision_encoder_center,
175 | num_frames=self.video_input_num_frames,
176 | drop_path=self.vision_encoder_drop_path_rate,
177 | checkpoint_num=self.vision_encoder_checkpoint_num,
178 | )
179 | return vision_encoder
180 |
181 | def build_text_encoder(self):
182 |         """build text_encoder and possibly video-to-text multimodal fusion encoder.
183 | Returns: nn.Module. The text encoder
184 |
185 | """
186 | encoder_name = self.text_encoder_name
187 | if encoder_name != "vit_l14":
188 | raise ValueError(f"Not implemented: {encoder_name}")
189 | text_encoder = clip_text_l14(
190 | pretrained=self.text_encoder_pretrained,
191 | embed_dim=self.text_encoder_d_model,
192 | context_length=self.max_txt_l,
193 | vocab_size=self.text_encoder_vocab_size,
194 | checkpoint_num=0,
195 | tokenizer=self.tokenizer,
196 | )
197 |
198 | return text_encoder
199 |
200 | def get_text_encoder(self):
201 | """get text encoder, used for text and cross-modal encoding"""
202 | encoder = self.text_encoder
203 | return encoder.bert if hasattr(encoder, "bert") else encoder
204 |
205 | def get_text_features(self, input_text, tokenizer, text_feature_dict={}):
206 | if input_text in text_feature_dict:
207 | return text_feature_dict[input_text]
208 | text_template= f"{input_text}"
209 | with torch.no_grad():
210 | # text_token = tokenizer.encode(text_template).cuda()
211 | text_features = self.encode_text(text_template).float()
212 | text_features /= text_features.norm(dim=-1, keepdim=True)
213 | text_feature_dict[input_text] = text_features
214 | return text_features
215 |
216 | def get_vid_features(self, input_frames):
217 | with torch.no_grad():
218 | clip_feat = self.encode_vision(input_frames,test=True).float()
219 | clip_feat /= clip_feat.norm(dim=-1, keepdim=True)
220 | return clip_feat
221 |
222 | def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
223 | label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
224 | top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
225 | return top_probs, top_labels
--------------------------------------------------------------------------------
/third_party/dover/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from functools import lru_cache
3 |
4 | import numpy as np
5 | import torch
6 | import torchvision
7 |
8 | random.seed(42)
9 |
10 |
11 | def get_spatial_fragments(
12 | video,
13 | fragments_h=7,
14 | fragments_w=7,
15 | fsize_h=32,
16 | fsize_w=32,
17 | aligned=32,
18 | nfrags=1,
19 | random=False,
20 | random_upsample=False,
21 | fallback_type="upsample",
22 | upsample=-1,
23 | **kwargs,
24 | ):
25 | if upsample > 0:
26 | old_h, old_w = video.shape[-2], video.shape[-1]
27 | if old_h >= old_w:
28 | w = upsample
29 | h = int(upsample * old_h / old_w)
30 | else:
31 | h = upsample
32 | w = int(upsample * old_w / old_h)
33 |
34 | video = get_resized_video(video, h, w)
35 | size_h = fragments_h * fsize_h
36 | size_w = fragments_w * fsize_w
37 | ## video: [C,T,H,W]
38 | ## situation for images
39 | if video.shape[1] == 1:
40 | aligned = 1
41 |
42 | dur_t, res_h, res_w = video.shape[-3:]
43 | ratio = min(res_h / size_h, res_w / size_w)
44 | if fallback_type == "upsample" and ratio < 1:
45 |
46 | ovideo = video
47 | video = torch.nn.functional.interpolate(
48 | video / 255.0, scale_factor=1 / ratio, mode="bilinear"
49 | )
50 | video = (video * 255.0).type_as(ovideo)
51 |
52 | if random_upsample:
53 |
54 | randratio = random.random() * 0.5 + 1
55 | video = torch.nn.functional.interpolate(
56 | video / 255.0, scale_factor=randratio, mode="bilinear"
57 | )
58 | video = (video * 255.0).type_as(ovideo)
59 |
60 |     assert dur_t % aligned == 0, "The clip length must be divisible by the alignment size"
61 | size = size_h, size_w
62 |
63 | ## make sure that sampling will not run out of the picture
64 | hgrids = torch.LongTensor(
65 | [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)]
66 | )
67 | wgrids = torch.LongTensor(
68 | [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)]
69 | )
70 | hlength, wlength = res_h // fragments_h, res_w // fragments_w
71 |
72 | if random:
73 |         print("This code path is deprecated. Please be aware of that.")
74 | if res_h > fsize_h:
75 | rnd_h = torch.randint(
76 | res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
77 | )
78 | else:
79 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
80 | if res_w > fsize_w:
81 | rnd_w = torch.randint(
82 | res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
83 | )
84 | else:
85 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
86 | else:
87 | if hlength > fsize_h:
88 | rnd_h = torch.randint(
89 | hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
90 | )
91 | else:
92 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
93 | if wlength > fsize_w:
94 | rnd_w = torch.randint(
95 | wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
96 | )
97 | else:
98 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
99 |
100 | target_video = torch.zeros(video.shape[:-2] + size).to(video.device)
101 | # target_videos = []
102 |
103 | for i, hs in enumerate(hgrids):
104 | for j, ws in enumerate(wgrids):
105 | for t in range(dur_t // aligned):
106 | t_s, t_e = t * aligned, (t + 1) * aligned
107 | h_s, h_e = i * fsize_h, (i + 1) * fsize_h
108 | w_s, w_e = j * fsize_w, (j + 1) * fsize_w
109 | if random:
110 | h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h
111 | w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w
112 | else:
113 | h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h
114 | w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w
115 | target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
116 | :, t_s:t_e, h_so:h_eo, w_so:w_eo
117 | ]
118 | # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo])
119 | # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6)
120 | # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments
121 | return target_video
122 |
123 |
124 | @lru_cache
125 | def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False):
126 | if random_crop:
127 | return torchvision.transforms.RandomResizedCrop(
128 | (size_h, size_w), scale=(0.40, 1.0)
129 | )
130 | if target_ratio > 1:
131 | size_h = int(target_ratio * size_w)
132 | assert size_h > size_w
133 | elif target_ratio < 1:
134 | size_w = int(size_h / target_ratio)
135 | assert size_w > size_h
136 | return torchvision.transforms.Resize((size_h, size_w))
137 |
138 |
139 | def get_resized_video(
140 | video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs,
141 | ):
142 | video = video.permute(1, 0, 2, 3)
143 | resize_opt = get_resize_function(
144 | size_h, size_w, video.shape[-2] / video.shape[-1] if arp else 1, random_crop
145 | )
146 | video = resize_opt(video).permute(1, 0, 2, 3)
147 | return video
148 |
149 |
150 | def get_single_view(
151 | video, sample_type="aesthetic", **kwargs,
152 | ):
153 | if sample_type.startswith("aesthetic"):
154 | video = get_resized_video(video, **kwargs)
155 | elif sample_type.startswith("technical"):
156 | video = get_spatial_fragments(video, **kwargs)
157 | elif sample_type == "original":
158 | return video
159 |
160 | return video
161 |
162 |
163 | def spatial_temporal_view_decomposition(
164 | frames: list, sample_types, samplers, is_train=False, augment=False,
165 | ):
166 | video = {}
167 |
168 | ### Avoid duplicated video decoding!!! Important!!!!
169 | all_frame_inds = []
170 | frame_inds = {}
171 | for stype in samplers:
172 | frame_inds[stype] = samplers[stype](len(frames), is_train)
173 | all_frame_inds.append(frame_inds[stype])
174 |
175 | ### Each frame is only decoded one time!!!
176 | all_frame_inds = np.concatenate(all_frame_inds, 0)
177 | frame_dict = {idx: frames[idx] for idx in np.unique(all_frame_inds)}
178 |
179 | for stype in samplers:
180 | imgs = [frame_dict[idx] for idx in frame_inds[stype]]
181 | imgs = [torch.from_numpy(np.array(img).astype(np.uint8)).float() for img in imgs] # List of [H,W,C]
182 | video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
183 |
184 | sampled_video = {}
185 | for stype, sopt in sample_types.items():
186 | sampled_video[stype] = get_single_view(video[stype], stype, **sopt)
187 | return sampled_video, frame_inds
188 |
189 |
190 |
191 | class UnifiedFrameSampler:
192 | def __init__(
193 | self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0,
194 | ):
195 |
196 | self.fragments_t = fragments_t
197 | self.fsize_t = fsize_t
198 | self.size_t = fragments_t * fsize_t
199 | self.frame_interval = frame_interval
200 | self.num_clips = num_clips
201 | self.drop_rate = drop_rate
202 |
203 | def get_frame_indices(self, num_frames, train=False):
204 |
205 | tgrids = np.array(
206 | [num_frames // self.fragments_t * i for i in range(self.fragments_t)],
207 | dtype=np.int32,
208 | )
209 | tlength = num_frames // self.fragments_t
210 |
211 | if tlength > self.fsize_t * self.frame_interval:
212 | rnd_t = np.random.randint(
213 | 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids)
214 | )
215 | else:
216 | rnd_t = np.zeros(len(tgrids), dtype=np.int32)
217 |
218 | ranges_t = (
219 | np.arange(self.fsize_t)[None, :] * self.frame_interval
220 | + rnd_t[:, None]
221 | + tgrids[:, None]
222 | )
223 |
224 | drop = random.sample(
225 | list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate)
226 | )
227 | dropped_ranges_t = []
228 | for i, rt in enumerate(ranges_t):
229 | if i not in drop:
230 | dropped_ranges_t.append(rt)
231 | return np.concatenate(dropped_ranges_t)
232 |
233 | def __call__(self, total_frames, train=False, start_index=0):
234 | frame_inds = []
235 |
236 | for i in range(self.num_clips):
237 | frame_inds += [self.get_frame_indices(total_frames)]
238 |
239 | frame_inds = np.concatenate(frame_inds)
240 | frame_inds = np.mod(frame_inds + start_index, total_frames)
241 | return frame_inds.astype(np.int32)
--------------------------------------------------------------------------------
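A minimal sketch of how the sampler above is typically driven. UnifiedFrameSampler splits the timeline into fragments_t chunks and draws fsize_t frames from each, so all indices are known up front and every frame is decoded only once; the import path and option values below are illustrative assumptions, not taken from this repository:

    import numpy as np
    from third_party.dover.dataset import UnifiedFrameSampler  # assumed import path

    # one clip of 32 frames, every 2nd frame, from a single temporal fragment
    sampler = UnifiedFrameSampler(fsize_t=32, fragments_t=1, frame_interval=2, num_clips=1)
    idx = sampler(total_frames=300)          # np.int32 indices, shape (32,)
    assert idx.shape == (32,)
    assert np.all((0 <= idx) & (idx < 300))  # indices are wrapped into [0, total_frames)

A dict of such samplers, keyed by view type (e.g. "aesthetic", "technical"), is what spatial_temporal_view_decomposition expects as its samplers argument.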
/third_party/dover/models/evaluator.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .conv_backbone import convnext_3d_small, convnext_3d_tiny, convnextv2_3d_pico, convnextv2_3d_femto
7 | from .head import IQAHead, VARHead, VQAHead
8 | from .swin_backbone import SwinTransformer2D as ImageBackbone
9 | from .swin_backbone import SwinTransformer3D as VideoBackbone
10 | from .swin_backbone import swin_3d_small, swin_3d_tiny
11 |
12 |
13 | class BaseEvaluator(nn.Module):
14 | def __init__(
15 | self, backbone=dict(), vqa_head=dict(),
16 | ):
17 | super().__init__()
18 | self.backbone = VideoBackbone(**backbone)
19 | self.vqa_head = VQAHead(**vqa_head)
20 |
21 | def forward(self, vclip, inference=True, **kwargs):
22 | if inference:
23 | self.eval()
24 | with torch.no_grad():
25 | feat = self.backbone(vclip)
26 | score = self.vqa_head(feat)
27 | self.train()
28 | return score
29 | else:
30 | feat = self.backbone(vclip)
31 | score = self.vqa_head(feat)
32 | return score
33 |
34 | def forward_with_attention(self, vclip):
35 | self.eval()
36 | with torch.no_grad():
37 | feat, avg_attns = self.backbone(vclip, require_attn=True)
38 | score = self.vqa_head(feat)
39 | return score, avg_attns
40 |
41 |
42 | class DOVER(nn.Module):
43 | def __init__(
44 | self,
45 | backbone_size="divided",
46 | backbone_preserve_keys="fragments,resize",
47 | multi=False,
48 | layer=-1,
49 | backbone=dict(
50 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)}
51 | ),
52 | divide_head=False,
53 | vqa_head=dict(in_channels=768),
54 | var=False,
55 | ):
56 | self.backbone_preserve_keys = backbone_preserve_keys.split(",")
57 | self.multi = multi
58 | self.layer = layer
59 | super().__init__()
60 | for key, hypers in backbone.items():
61 | # print(backbone_size)
62 | if key not in self.backbone_preserve_keys:
63 | continue
64 | if backbone_size == "divided":
65 | t_backbone_size = hypers["type"]
66 | else:
67 | t_backbone_size = backbone_size
68 | if t_backbone_size == "swin_tiny":
69 | b = swin_3d_tiny(**backbone[key])
70 | elif t_backbone_size == "swin_tiny_grpb":
71 | # to reproduce fast-vqa
72 | b = VideoBackbone()
73 | elif t_backbone_size == "swin_tiny_grpb_m":
74 | # to reproduce fast-vqa-m
75 | b = VideoBackbone(window_size=(4, 4, 4), frag_biases=[0, 0, 0, 0])
76 | elif t_backbone_size == "swin_small":
77 | b = swin_3d_small(**backbone[key])
78 | elif t_backbone_size == "conv_tiny":
79 | b = convnext_3d_tiny(pretrained=True)
80 | elif t_backbone_size == "conv_small":
81 | b = convnext_3d_small(pretrained=True)
82 | elif t_backbone_size == "conv_femto":
83 | b = convnextv2_3d_femto(pretrained=True)
84 | elif t_backbone_size == "conv_pico":
85 | b = convnextv2_3d_pico(pretrained=True)
86 | elif t_backbone_size == "xclip":
87 | raise NotImplementedError
88 | # b = build_x_clip_model(**backbone[key])
89 | else:
90 | raise NotImplementedError
91 | # print("Setting backbone:", key + "_backbone")
92 | setattr(self, key + "_backbone", b)
93 | if divide_head:
94 | for key in backbone:
95 | pre_pool = False #if key == "technical" else True
96 | if key not in self.backbone_preserve_keys:
97 | continue
98 | b = VQAHead(pre_pool=pre_pool, **vqa_head)
99 | # print("Setting head:", key + "_head")
100 | setattr(self, key + "_head", b)
101 | else:
102 | if var:
103 | self.vqa_head = VARHead(**vqa_head)
104 | # print(b)
105 | else:
106 | self.vqa_head = VQAHead(**vqa_head)
107 |
108 | def forward(
109 | self,
110 | vclips,
111 | inference=True,
112 | return_pooled_feats=False,
113 | return_raw_feats=False,
114 | reduce_scores=False,
115 | pooled=False,
116 | **kwargs
117 | ):
118 |         assert not (return_pooled_feats and return_raw_feats), "Please choose only one kind of features to return"
119 | if inference:
120 | self.eval()
121 | with torch.no_grad():
122 | scores = []
123 | feats = {}
124 | for key in vclips:
125 | feat = getattr(self, key.split("_")[0] + "_backbone")(
126 | vclips[key], multi=self.multi, layer=self.layer, **kwargs
127 | )
128 | if hasattr(self, key.split("_")[0] + "_head"):
129 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
130 | else:
131 | scores += [getattr(self, "vqa_head")(feat)]
132 | if return_pooled_feats:
133 | feats[key] = feat
134 | if return_raw_feats:
135 | feats[key] = feat
136 | if reduce_scores:
137 | if len(scores) > 1:
138 | scores = reduce(lambda x, y: x + y, scores)
139 | else:
140 | scores = scores[0]
141 | if pooled:
142 | scores = torch.mean(scores, (1, 2, 3, 4))
143 | self.train()
144 | if return_pooled_feats or return_raw_feats:
145 | return scores, feats
146 | return scores
147 | else:
148 | self.train()
149 | scores = []
150 | feats = {}
151 | for key in vclips:
152 | feat = getattr(self, key.split("_")[0] + "_backbone")(
153 | vclips[key], multi=self.multi, layer=self.layer, **kwargs
154 | )
155 | if hasattr(self, key.split("_")[0] + "_head"):
156 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
157 | else:
158 | scores += [getattr(self, "vqa_head")(feat)]
159 | if return_pooled_feats:
160 | feats[key] = feat.mean((-3, -2, -1))
161 | if reduce_scores:
162 | if len(scores) > 1:
163 | scores = reduce(lambda x, y: x + y, scores)
164 | else:
165 | scores = scores[0]
166 | if pooled:
167 | # print(scores.shape)
168 | scores = torch.mean(scores, (1, 2, 3, 4))
169 | # print(scores.shape)
170 |
171 | if return_pooled_feats:
172 | return scores, feats
173 | return scores
174 |
175 |
176 | class MinimumDOVER(nn.Module):
177 | def __init__(self):
178 | super().__init__()
179 | self.technical_backbone = VideoBackbone()
180 | self.aesthetic_backbone = convnext_3d_tiny(pretrained=True)
181 | self.technical_head = VQAHead(pre_pool=False, in_channels=768)
182 | self.aesthetic_head = VQAHead(pre_pool=False, in_channels=768)
183 |
184 |
185 | def forward(self,aesthetic_view, technical_view):
186 | self.eval()
187 | with torch.no_grad():
188 | aesthetic_score = self.aesthetic_head(self.aesthetic_backbone(aesthetic_view))
189 | technical_score = self.technical_head(self.technical_backbone(technical_view))
190 |
191 | aesthetic_score_pooled = torch.mean(aesthetic_score, (1,2,3,4))
192 | technical_score_pooled = torch.mean(technical_score, (1,2,3,4))
193 | return [aesthetic_score_pooled, technical_score_pooled]
194 |
195 |
196 |
197 | class BaseImageEvaluator(nn.Module):
198 | def __init__(
199 | self, backbone=dict(), iqa_head=dict(),
200 | ):
201 | super().__init__()
202 | self.backbone = ImageBackbone(**backbone)
203 | self.iqa_head = IQAHead(**iqa_head)
204 |
205 | def forward(self, image, inference=True, **kwargs):
206 | if inference:
207 | self.eval()
208 | with torch.no_grad():
209 | feat = self.backbone(image)
210 | score = self.iqa_head(feat)
211 | self.train()
212 | return score
213 | else:
214 | feat = self.backbone(image)
215 | score = self.iqa_head(feat)
216 | return score
217 |
218 | def forward_with_attention(self, image):
219 | self.eval()
220 | with torch.no_grad():
221 | feat, avg_attns = self.backbone(image, require_attn=True)
222 | score = self.iqa_head(feat)
223 | return score, avg_attns
224 |
--------------------------------------------------------------------------------
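A minimal sketch of running the evaluator defined above. MinimumDOVER is used here because it needs no config dict; the input shapes, random tensors, and the commented-out checkpoint path are illustrative assumptions (constructing the model also requires the DOVER Swin video backbone and, because the aesthetic branch uses pretrained=True, a ConvNeXt weight download):

    import torch
    from third_party.dover.models.evaluator import MinimumDOVER  # assumed import path

    model = MinimumDOVER().eval()
    # state = torch.load("checkpoints/DOVER.pth", map_location="cpu")  # hypothetical weight path
    # model.load_state_dict(state, strict=False)

    aesthetic_view = torch.randn(1, 3, 32, 224, 224)   # (B, C, T, H, W), resized view
    technical_view = torch.randn(1, 3, 32, 224, 224)   # (B, C, T, H, W), fragment view
    with torch.no_grad():
        aesthetic_score, technical_score = model(aesthetic_view, technical_view)
    # each pooled score has shape (B,); how the two branches are fused into a single
    # quality score is left to the caller and its config, not to this class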
/third_party/viclip/viclip_text.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import os
4 | from collections import OrderedDict
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint as checkpoint
10 | from pkg_resources import packaging
11 | from torch import nn
12 |
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | MODEL_PATH = 'https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K'
19 | _MODELS = {
20 | "ViT-L/14": os.path.join(MODEL_PATH, "vit_l14_text.pth"),
21 | }
22 |
23 |
24 | class LayerNorm(nn.LayerNorm):
25 | """Subclass torch's LayerNorm to handle fp16."""
26 |
27 | def forward(self, x: torch.Tensor):
28 | orig_type = x.dtype
29 | ret = super().forward(x.type(torch.float32))
30 | return ret.type(orig_type)
31 |
32 |
33 | class QuickGELU(nn.Module):
34 | def forward(self, x: torch.Tensor):
35 | return x * torch.sigmoid(1.702 * x)
36 |
37 |
38 | class ResidualAttentionBlock(nn.Module):
39 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
40 | super().__init__()
41 |
42 | self.attn = nn.MultiheadAttention(d_model, n_head)
43 | self.ln_1 = LayerNorm(d_model)
44 | self.mlp = nn.Sequential(OrderedDict([
45 | ("c_fc", nn.Linear(d_model, d_model * 4)),
46 | ("gelu", QuickGELU()),
47 | ("c_proj", nn.Linear(d_model * 4, d_model))
48 | ]))
49 | self.ln_2 = LayerNorm(d_model)
50 | self.attn_mask = attn_mask
51 |
52 | def attention(self, x: torch.Tensor):
53 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
54 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
55 |
56 | def forward(self, x: torch.Tensor):
57 | x = x + self.attention(self.ln_1(x))
58 | x = x + self.mlp(self.ln_2(x))
59 | return x
60 |
61 |
62 | class Transformer(nn.Module):
63 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
64 | checkpoint_num: int = 0):
65 | super().__init__()
66 | self.width = width
67 | self.layers = layers
68 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
69 |
70 | self.checkpoint_num = checkpoint_num
71 |
72 | def forward(self, x: torch.Tensor):
73 | if self.checkpoint_num > 0:
74 | segments = min(self.checkpoint_num, len(self.resblocks))
75 | return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
76 | else:
77 | return self.resblocks(x)
78 |
79 |
80 | class CLIP_TEXT(nn.Module):
81 | def __init__(
82 | self,
83 | embed_dim: int,
84 | context_length: int,
85 | vocab_size: int,
86 | transformer_width: int,
87 | transformer_heads: int,
88 | transformer_layers: int,
89 | checkpoint_num: int,
90 | tokenizer=None,
91 | ):
92 | super().__init__()
93 |
94 | self.context_length = context_length
95 | self._tokenizer = tokenizer or _Tokenizer()
96 |
97 | self.transformer = Transformer(
98 | width=transformer_width,
99 | layers=transformer_layers,
100 | heads=transformer_heads,
101 | attn_mask=self.build_attention_mask(),
102 | checkpoint_num=checkpoint_num,
103 | )
104 |
105 | self.vocab_size = vocab_size
106 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
107 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
108 | self.ln_final = LayerNorm(transformer_width)
109 |
110 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
111 |
112 | def no_weight_decay(self):
113 | return {'token_embedding', 'positional_embedding'}
114 |
115 | @functools.lru_cache(maxsize=None)
116 | def build_attention_mask(self):
117 |         # lazily create the causal attention mask for the text tokens
118 | # pytorch uses additive attention mask; fill with -inf
119 | mask = torch.empty(self.context_length, self.context_length)
120 | mask.fill_(float("-inf"))
121 | mask.triu_(1) # zero out the lower diagonal
122 | return mask
123 |
124 | def tokenize(self, texts, context_length=77, truncate=True):
125 | """
126 | Returns the tokenized representation of given input string(s)
127 | Parameters
128 | ----------
129 | texts : Union[str, List[str]]
130 | An input string or a list of input strings to tokenize
131 | context_length : int
132 | The context length to use; all CLIP models use 77 as the context length
133 | truncate: bool
134 | Whether to truncate the text in case its encoding is longer than the context length
135 | Returns
136 | -------
137 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
138 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
139 | """
140 | if isinstance(texts, str):
141 | texts = [texts]
142 |
143 | sot_token = self._tokenizer.encoder["<|startoftext|>"]
144 | eot_token = self._tokenizer.encoder["<|endoftext|>"]
145 | all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
146 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
147 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
148 | else:
149 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
150 |
151 | for i, tokens in enumerate(all_tokens):
152 | if len(tokens) > context_length:
153 | if truncate:
154 | tokens = tokens[:context_length]
155 | tokens[-1] = eot_token
156 | else:
157 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
158 | result[i, :len(tokens)] = torch.tensor(tokens)
159 |
160 | return result
161 |
162 | def forward(self, text):
163 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
164 |
165 | x = x + self.positional_embedding
166 | x = x.permute(1, 0, 2) # NLD -> LND
167 | x = self.transformer(x)
168 | x = x.permute(1, 0, 2) # LND -> NLD
169 | x = self.ln_final(x)
170 |
171 | # x.shape = [batch_size, n_ctx, transformer.width]
172 | # take features from the eot embedding (eot_token is the highest number in each sequence)
173 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
174 |
175 | return x
176 |
177 |
178 | def clip_text_b16(
179 | embed_dim=512,
180 | context_length=77,
181 | vocab_size=49408,
182 | transformer_width=512,
183 | transformer_heads=8,
184 | transformer_layers=12,
185 | ):
186 | raise NotImplementedError
187 | model = CLIP_TEXT(
188 | embed_dim,
189 | context_length,
190 | vocab_size,
191 | transformer_width,
192 | transformer_heads,
193 | transformer_layers
194 | )
195 | pretrained = _MODELS["ViT-B/16"]
196 | logger.info(f"Load pretrained weights from {pretrained}")
197 | state_dict = torch.load(pretrained, map_location='cpu')
198 | model.load_state_dict(state_dict, strict=False)
199 | return model.eval()
200 |
201 |
202 | def clip_text_l14(
203 | embed_dim=768,
204 | context_length=77,
205 | vocab_size=49408,
206 | transformer_width=768,
207 | transformer_heads=12,
208 | transformer_layers=12,
209 | checkpoint_num=0,
210 | pretrained=True,
211 | tokenizer=None,
212 | ):
213 | model = CLIP_TEXT(
214 | embed_dim,
215 | context_length,
216 | vocab_size,
217 | transformer_width,
218 | transformer_heads,
219 | transformer_layers,
220 | checkpoint_num,
221 | tokenizer,
222 | )
223 | if pretrained:
224 | if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
225 | pretrained = _MODELS[pretrained]
226 | else:
227 | pretrained = _MODELS["ViT-L/14"]
228 | logger.info(f"Load pretrained weights from {pretrained}")
229 | state_dict = torch.load(pretrained, map_location='cpu')
230 | if context_length != state_dict["positional_embedding"].size(0):
231 | # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
232 | # print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
233 | if context_length < state_dict["positional_embedding"].size(0):
234 | state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
235 | else:
236 | state_dict["positional_embedding"] = F.pad(
237 | state_dict["positional_embedding"],
238 | (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
239 | value=0,
240 | )
241 |
242 | message = model.load_state_dict(state_dict, strict=False)
243 | # print(f"Load pretrained weights from {pretrained}: {message}")
244 | return model.eval()
245 |
246 |
247 | def clip_text_l14_336(
248 | embed_dim=768,
249 | context_length=77,
250 | vocab_size=49408,
251 | transformer_width=768,
252 | transformer_heads=12,
253 | transformer_layers=12,
254 | ):
255 | raise NotImplementedError
256 | model = CLIP_TEXT(
257 | embed_dim,
258 | context_length,
259 | vocab_size,
260 | transformer_width,
261 | transformer_heads,
262 | transformer_layers
263 | )
264 | pretrained = _MODELS["ViT-L/14_336"]
265 | logger.info(f"Load pretrained weights from {pretrained}")
266 | state_dict = torch.load(pretrained, map_location='cpu')
267 | model.load_state_dict(state_dict, strict=False)
268 | return model.eval()
269 |
270 |
271 | def build_clip(config):
272 | model_cls = config.text_encoder.clip_teacher
273 | model = eval(model_cls)()
274 | return model
--------------------------------------------------------------------------------
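A minimal sketch of the text tower defined above: tokenize a prompt and check the embedding shape. pretrained=False skips weight loading, so the outputs are untrained and only the shapes are meaningful; for real use the vit_l14_text.pth checkpoint referenced in _MODELS has to be available locally. The import path and prompt are illustrative, and the sketch assumes the BPE vocabulary bundled with simple_tokenizer is present:

    import torch
    from third_party.viclip.viclip_text import clip_text_l14  # assumed import path

    text_encoder = clip_text_l14(pretrained=False)             # untrained, shapes only
    tokens = text_encoder.tokenize(["a cat playing piano"])    # (1, 77) token ids
    with torch.no_grad():
        feats = text_encoder(tokens)                           # (1, 768) text embedding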
/third_party/viclip/viclip_vision.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import logging
3 | import os
4 | from collections import OrderedDict
5 |
6 | import torch
7 | import torch.utils.checkpoint as checkpoint
8 | from einops import rearrange
9 | from timm.models.layers import DropPath
10 | from timm.models.registry import register_model
11 | from torch import nn
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
17 | """
18 | Add/Remove extra temporal_embeddings as needed.
19 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
20 |
21 | temp_embed_old: (1, num_frames_old, 1, d)
22 | temp_embed_new: (1, num_frames_new, 1, d)
23 |     add_zero: bool; if True, pad with zeros, otherwise interpolate the trained embeddings.
24 | """
25 | # TODO zero pad
26 | num_frms_new = temp_embed_new.shape[1]
27 | num_frms_old = temp_embed_old.shape[1]
28 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
29 | if num_frms_new > num_frms_old:
30 | if add_zero:
31 | temp_embed_new[
32 | :, :num_frms_old
33 | ] = temp_embed_old # untrained embeddings are zeros.
34 | else:
35 | raise NotImplementedError
36 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
37 | elif num_frms_new < num_frms_old:
38 | temp_embed_new = temp_embed_old[:, :num_frms_new]
39 | else: # =
40 | temp_embed_new = temp_embed_old
41 | return temp_embed_new
42 |
43 |
44 | MODEL_PATH = 'https://huggingface.co/OpenGVLab/VBench_Used_Models/blob/main/'
45 | _MODELS = {
46 | "ViT-L/14": os.path.join(MODEL_PATH, "ViClip-InternVid-10M-FLT.pth"),
47 | }
48 |
49 |
50 | class QuickGELU(nn.Module):
51 | def forward(self, x):
52 | return x * torch.sigmoid(1.702 * x)
53 |
54 |
55 | class ResidualAttentionBlock(nn.Module):
56 | def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
57 | super().__init__()
58 |
59 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
60 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
61 | self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
62 | self.ln_1 = nn.LayerNorm(d_model)
63 | self.mlp = nn.Sequential(OrderedDict([
64 | ("c_fc", nn.Linear(d_model, d_model * 4)),
65 | ("gelu", QuickGELU()),
66 | ("drop1", nn.Dropout(dropout)),
67 | ("c_proj", nn.Linear(d_model * 4, d_model)),
68 | ("drop2", nn.Dropout(dropout)),
69 | ]))
70 | self.ln_2 = nn.LayerNorm(d_model)
71 | self.attn_mask = attn_mask
72 |
73 | def attention(self, x):
74 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
75 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
76 |
77 | def forward(self, x):
78 | x = x + self.drop_path1(self.attention(self.ln_1(x)))
79 | x = x + self.drop_path2(self.mlp(self.ln_2(x)))
80 | return x
81 |
82 |
83 | class Transformer(nn.Module):
84 | def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
85 | super().__init__()
86 | dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
87 | self.resblocks = nn.ModuleList()
88 | for idx in range(layers):
89 | self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout))
90 | self.checkpoint_num = checkpoint_num
91 |
92 | def forward(self, x):
93 | for idx, blk in enumerate(self.resblocks):
94 | if idx < self.checkpoint_num:
95 | x = checkpoint.checkpoint(blk, x)
96 | else:
97 | x = blk(x)
98 | return x
99 |
100 |
101 | class VisionTransformer(nn.Module):
102 | def __init__(
103 | self, input_resolution, patch_size, width, layers, heads, output_dim=None,
104 | kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
105 | temp_embed=True,
106 | ):
107 | super().__init__()
108 | self.output_dim = output_dim
109 | self.conv1 = nn.Conv3d(
110 | 3, width,
111 | (kernel_size, patch_size, patch_size),
112 | (kernel_size, patch_size, patch_size),
113 | (0, 0, 0), bias=False
114 | )
115 |
116 | scale = width ** -0.5
117 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
118 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
119 | self.ln_pre = nn.LayerNorm(width)
120 | if temp_embed:
121 | self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))
122 |
123 | self.transformer = Transformer(
124 | width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
125 | dropout=dropout)
126 |
127 | self.ln_post = nn.LayerNorm(width)
128 | if output_dim is not None:
129 | self.proj = nn.Parameter(torch.empty(width, output_dim))
130 | else:
131 | self.proj = None
132 |
133 | self.dropout = nn.Dropout(dropout)
134 |
135 | def get_num_layers(self):
136 | return len(self.transformer.resblocks)
137 |
138 | @torch.jit.ignore
139 | def no_weight_decay(self):
140 | return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}
141 |
142 | def mask_tokens(self, inputs, masking_prob=0.0):
143 | B, L, _ = inputs.shape
144 |
145 |         # This is different from text, as we mask a fixed number of tokens
146 | Lm = int(masking_prob * L)
147 | masked_indices = torch.zeros(B, L)
148 | indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm]
149 | batch_indices = (
150 | torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices)
151 | )
152 | masked_indices[batch_indices, indices] = 1
153 |
154 | masked_indices = masked_indices.bool()
155 |
156 | return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1])
157 |
158 | def forward(self, x, masking_prob=0.0):
159 |         x = self.conv1(x)  # shape = [B, width, T, grid, grid]
160 | B, C, T, H, W = x.shape
161 | x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)
162 |
163 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
164 | x = x + self.positional_embedding.to(x.dtype)
165 |
166 | # temporal pos
167 | cls_tokens = x[:B, :1, :]
168 | x = x[:, 1:]
169 | x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
170 | if hasattr(self, 'temporal_positional_embedding'):
171 | if x.size(1) == 1:
172 |                 # This is a workaround for the unused-parameter issue
173 | x = x + self.temporal_positional_embedding.mean(1)
174 | else:
175 | x = x + self.temporal_positional_embedding
176 | x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)
177 |
178 | if masking_prob > 0.0:
179 | x = self.mask_tokens(x, masking_prob)
180 |
181 | x = torch.cat((cls_tokens, x), dim=1)
182 |
183 | x = self.ln_pre(x)
184 |
185 | x = x.permute(1, 0, 2) #BND -> NBD
186 | x = self.transformer(x)
187 |
188 | x = self.ln_post(x)
189 |
190 | if self.proj is not None:
191 | x = self.dropout(x[0]) @ self.proj
192 | else:
193 | x = x.permute(1, 0, 2) #NBD -> BND
194 |
195 | return x
196 |
197 |
198 | def inflate_weight(weight_2d, time_dim, center=True):
199 | logger.info(f'Init center: {center}')
200 | if center:
201 | weight_3d = torch.zeros(*weight_2d.shape)
202 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
203 | middle_idx = time_dim // 2
204 | weight_3d[:, :, middle_idx, :, :] = weight_2d
205 | else:
206 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
207 | weight_3d = weight_3d / time_dim
208 | return weight_3d
209 |
210 |
211 | def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
212 | state_dict_3d = model.state_dict()
213 | for k in state_dict.keys():
214 | if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape:
215 | if len(state_dict_3d[k].shape) <= 2:
216 | logger.info(f'Ignore: {k}')
217 | continue
218 | logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
219 | time_dim = state_dict_3d[k].shape[2]
220 | state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)
221 |
222 | pos_embed_checkpoint = state_dict['positional_embedding']
223 | embedding_size = pos_embed_checkpoint.shape[-1]
224 | num_patches = (input_resolution // patch_size) ** 2
225 | orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
226 | new_size = int(num_patches ** 0.5)
227 | if orig_size != new_size:
228 | logger.info(f'Pos_emb from {orig_size} to {new_size}')
229 | extra_tokens = pos_embed_checkpoint[:1]
230 | pos_tokens = pos_embed_checkpoint[1:]
231 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
232 | pos_tokens = torch.nn.functional.interpolate(
233 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
234 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
235 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
236 | state_dict['positional_embedding'] = new_pos_embed
237 |
238 | message = model.load_state_dict(state_dict, strict=False)
239 | logger.info(f"Load pretrained weights: {message}")
240 |
241 |
242 | @register_model
243 | def clip_joint_b16(
244 | pretrained=True, input_resolution=224, kernel_size=1,
245 | center=True, num_frames=8, drop_path=0.
246 | ):
247 | model = VisionTransformer(
248 | input_resolution=input_resolution, patch_size=16,
249 | width=768, layers=12, heads=12, output_dim=512,
250 | kernel_size=kernel_size, num_frames=num_frames,
251 | drop_path=drop_path,
252 | )
253 | raise NotImplementedError
254 | if pretrained:
255 | logger.info('load pretrained weights')
256 | state_dict = torch.load(_MODELS["ViT-B/16"], map_location='cpu')
257 | load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center)
258 | return model.eval()
259 |
260 |
261 | @register_model
262 | def clip_joint_l14(
263 | pretrained=False, input_resolution=224, kernel_size=1,
264 | center=True, num_frames=8, drop_path=0., checkpoint_num=0,
265 | dropout=0.,
266 | ):
267 | model = VisionTransformer(
268 | input_resolution=input_resolution, patch_size=14,
269 | width=1024, layers=24, heads=16, output_dim=768,
270 | kernel_size=kernel_size, num_frames=num_frames,
271 | drop_path=drop_path, checkpoint_num=checkpoint_num,
272 | dropout=dropout,
273 | )
274 | if pretrained:
275 | if isinstance(pretrained, str):
276 | model_name = pretrained
277 | else:
278 | model_name = "ViT-L/14"
279 | logger.info('load pretrained weights')
280 | state_dict = torch.load(_MODELS[model_name], map_location='cpu')
281 | load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
282 | return model.eval()
283 |
284 |
285 | @register_model
286 | def clip_joint_l14_336(
287 | pretrained=True, input_resolution=336, kernel_size=1,
288 | center=True, num_frames=8, drop_path=0.
289 | ):
290 | raise NotImplementedError
291 | model = VisionTransformer(
292 | input_resolution=input_resolution, patch_size=14,
293 | width=1024, layers=24, heads=16, output_dim=768,
294 | kernel_size=kernel_size, num_frames=num_frames,
295 | drop_path=drop_path,
296 | )
297 | if pretrained:
298 | logger.info('load pretrained weights')
299 | state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
300 | load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
301 | return model.eval()
302 |
303 |
304 | def interpolate_pos_embed_vit(state_dict, new_model):
305 | key = "vision_encoder.temporal_positional_embedding"
306 | if key in state_dict:
307 | vision_temp_embed_new = new_model.state_dict()[key]
308 | vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2) # [1, n, d] -> [1, n, 1, d]
309 | vision_temp_embed_old = state_dict[key]
310 | vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2)
311 |
312 | state_dict[key] = load_temp_embed_with_mismatch(
313 | vision_temp_embed_old, vision_temp_embed_new, add_zero=False
314 | ).squeeze(2)
315 |
316 | key = "text_encoder.positional_embedding"
317 | if key in state_dict:
318 | text_temp_embed_new = new_model.state_dict()[key]
319 | text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2) # [n, d] -> [1, n, 1, d]
320 | text_temp_embed_old = state_dict[key]
321 | text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2)
322 |
323 | state_dict[key] = load_temp_embed_with_mismatch(
324 | text_temp_embed_old, text_temp_embed_new, add_zero=False
325 | ).squeeze(2).squeeze(0)
326 | return state_dict
--------------------------------------------------------------------------------
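A small worked sketch of the centre-style kernel inflation used by load_state_dict above: a 2D conv kernel of shape (O, I, kh, kw) becomes (O, I, T, kh, kw) with the original weights placed at the middle temporal index and zeros elsewhere, so the inflated 3D conv initially reproduces the 2D response on the centre frame. Shapes and the import path are illustrative assumptions:

    import torch
    from third_party.viclip.viclip_vision import inflate_weight  # assumed import path

    w2d = torch.randn(768, 3, 16, 16)                  # e.g. a ViT patch-embedding kernel
    w3d = inflate_weight(w2d, time_dim=3, center=True)
    assert w3d.shape == (768, 3, 3, 16, 16)
    assert torch.equal(w3d[:, :, 1], w2d)              # centre index carries the 2D kernel
    assert torch.all(w3d[:, :, 0] == 0) and torch.all(w3d[:, :, 2] == 0)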
/third_party/gmflow/models/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .utils import split_feature, merge_splits
6 |
7 |
8 | def single_head_full_attention(q, k, v):
9 | # q, k, v: [B, L, C]
10 | assert q.dim() == k.dim() == v.dim() == 3
11 |
12 | scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L]
13 | attn = torch.softmax(scores, dim=2) # [B, L, L]
14 | out = torch.matmul(attn, v) # [B, L, C]
15 |
16 | return out
17 |
18 |
19 | def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
20 | shift_size_h, shift_size_w, device=torch.device('cuda')):
21 | # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
22 | # calculate attention mask for SW-MSA
23 | h, w = input_resolution
24 | img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1
25 | h_slices = (slice(0, -window_size_h),
26 | slice(-window_size_h, -shift_size_h),
27 | slice(-shift_size_h, None))
28 | w_slices = (slice(0, -window_size_w),
29 | slice(-window_size_w, -shift_size_w),
30 | slice(-shift_size_w, None))
31 | cnt = 0
32 | for h in h_slices:
33 | for w in w_slices:
34 | img_mask[:, h, w, :] = cnt
35 | cnt += 1
36 |
37 | mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
38 |
39 | mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
40 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
41 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
42 |
43 | return attn_mask
44 |
45 |
46 | def single_head_split_window_attention(q, k, v,
47 | num_splits=1,
48 | with_shift=False,
49 | h=None,
50 | w=None,
51 | attn_mask=None,
52 | ):
53 | # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
54 | # q, k, v: [B, L, C]
55 | assert q.dim() == k.dim() == v.dim() == 3
56 |
57 | assert h is not None and w is not None
58 | assert q.size(1) == h * w
59 |
60 | b, _, c = q.size()
61 |
62 | b_new = b * num_splits * num_splits
63 |
64 | window_size_h = h // num_splits
65 | window_size_w = w // num_splits
66 |
67 | q = q.view(b, h, w, c) # [B, H, W, C]
68 | k = k.view(b, h, w, c)
69 | v = v.view(b, h, w, c)
70 |
71 | scale_factor = c ** 0.5
72 |
73 | if with_shift:
74 | assert attn_mask is not None # compute once
75 | shift_size_h = window_size_h // 2
76 | shift_size_w = window_size_w // 2
77 |
78 | q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
79 | k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
80 | v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
81 |
82 | q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
83 | k = split_feature(k, num_splits=num_splits, channel_last=True)
84 | v = split_feature(v, num_splits=num_splits, channel_last=True)
85 |
86 | scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
87 | ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K]
88 |
89 | if with_shift:
90 | scores += attn_mask.repeat(b, 1, 1)
91 |
92 | attn = torch.softmax(scores, dim=-1)
93 |
94 | out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
95 |
96 | out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
97 | num_splits=num_splits, channel_last=True) # [B, H, W, C]
98 |
99 | # shift back
100 | if with_shift:
101 | out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
102 |
103 | out = out.view(b, -1, c)
104 |
105 | return out
106 |
107 |
108 | class TransformerLayer(nn.Module):
109 | def __init__(self,
110 | d_model=256,
111 | nhead=1,
112 | attention_type='swin',
113 | no_ffn=False,
114 | ffn_dim_expansion=4,
115 | with_shift=False,
116 | **kwargs,
117 | ):
118 | super(TransformerLayer, self).__init__()
119 |
120 | self.dim = d_model
121 | self.nhead = nhead
122 | self.attention_type = attention_type
123 | self.no_ffn = no_ffn
124 |
125 | self.with_shift = with_shift
126 |
127 | # multi-head attention
128 | self.q_proj = nn.Linear(d_model, d_model, bias=False)
129 | self.k_proj = nn.Linear(d_model, d_model, bias=False)
130 | self.v_proj = nn.Linear(d_model, d_model, bias=False)
131 |
132 | self.merge = nn.Linear(d_model, d_model, bias=False)
133 |
134 | self.norm1 = nn.LayerNorm(d_model)
135 |
136 | # no ffn after self-attn, with ffn after cross-attn
137 | if not self.no_ffn:
138 | in_channels = d_model * 2
139 | self.mlp = nn.Sequential(
140 | nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
141 | nn.GELU(),
142 | nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
143 | )
144 |
145 | self.norm2 = nn.LayerNorm(d_model)
146 |
147 | def forward(self, source, target,
148 | height=None,
149 | width=None,
150 | shifted_window_attn_mask=None,
151 | attn_num_splits=None,
152 | **kwargs,
153 | ):
154 | # source, target: [B, L, C]
155 | query, key, value = source, target, target
156 |
157 | # single-head attention
158 | query = self.q_proj(query) # [B, L, C]
159 | key = self.k_proj(key) # [B, L, C]
160 | value = self.v_proj(value) # [B, L, C]
161 |
162 | if self.attention_type == 'swin' and attn_num_splits > 1:
163 | if self.nhead > 1:
164 |                 # multi-head attention was observed to slow inference and increase memory consumption
165 |                 # without obvious performance gains, so that implementation was removed
166 | raise NotImplementedError
167 | else:
168 | message = single_head_split_window_attention(query, key, value,
169 | num_splits=attn_num_splits,
170 | with_shift=self.with_shift,
171 | h=height,
172 | w=width,
173 | attn_mask=shifted_window_attn_mask,
174 | )
175 | else:
176 | message = single_head_full_attention(query, key, value) # [B, L, C]
177 |
178 | message = self.merge(message) # [B, L, C]
179 | message = self.norm1(message)
180 |
181 | if not self.no_ffn:
182 | message = self.mlp(torch.cat([source, message], dim=-1))
183 | message = self.norm2(message)
184 |
185 | return source + message
186 |
187 |
188 | class TransformerBlock(nn.Module):
189 | """self attention + cross attention + FFN"""
190 |
191 | def __init__(self,
192 | d_model=256,
193 | nhead=1,
194 | attention_type='swin',
195 | ffn_dim_expansion=4,
196 | with_shift=False,
197 | **kwargs,
198 | ):
199 | super(TransformerBlock, self).__init__()
200 |
201 | self.self_attn = TransformerLayer(d_model=d_model,
202 | nhead=nhead,
203 | attention_type=attention_type,
204 | no_ffn=True,
205 | ffn_dim_expansion=ffn_dim_expansion,
206 | with_shift=with_shift,
207 | )
208 |
209 | self.cross_attn_ffn = TransformerLayer(d_model=d_model,
210 | nhead=nhead,
211 | attention_type=attention_type,
212 | ffn_dim_expansion=ffn_dim_expansion,
213 | with_shift=with_shift,
214 | )
215 |
216 | def forward(self, source, target,
217 | height=None,
218 | width=None,
219 | shifted_window_attn_mask=None,
220 | attn_num_splits=None,
221 | **kwargs,
222 | ):
223 | # source, target: [B, L, C]
224 |
225 | # self attention
226 | source = self.self_attn(source, source,
227 | height=height,
228 | width=width,
229 | shifted_window_attn_mask=shifted_window_attn_mask,
230 | attn_num_splits=attn_num_splits,
231 | )
232 |
233 | # cross attention and ffn
234 | source = self.cross_attn_ffn(source, target,
235 | height=height,
236 | width=width,
237 | shifted_window_attn_mask=shifted_window_attn_mask,
238 | attn_num_splits=attn_num_splits,
239 | )
240 |
241 | return source
242 |
243 |
244 | class FeatureTransformer(nn.Module):
245 | def __init__(self,
246 | num_layers=6,
247 | d_model=128,
248 | nhead=1,
249 | attention_type='swin',
250 | ffn_dim_expansion=4,
251 | **kwargs,
252 | ):
253 | super(FeatureTransformer, self).__init__()
254 |
255 | self.attention_type = attention_type
256 |
257 | self.d_model = d_model
258 | self.nhead = nhead
259 |
260 | self.layers = nn.ModuleList([
261 | TransformerBlock(d_model=d_model,
262 | nhead=nhead,
263 | attention_type=attention_type,
264 | ffn_dim_expansion=ffn_dim_expansion,
265 | with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
266 | )
267 | for i in range(num_layers)])
268 |
269 | for p in self.parameters():
270 | if p.dim() > 1:
271 | nn.init.xavier_uniform_(p)
272 |
273 | def forward(self, feature0, feature1,
274 | attn_num_splits=None,
275 | **kwargs,
276 | ):
277 |
278 | b, c, h, w = feature0.shape
279 | assert self.d_model == c
280 |
281 | feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
282 | feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
283 |
284 | if self.attention_type == 'swin' and attn_num_splits > 1:
285 | # global and refine use different number of splits
286 | window_size_h = h // attn_num_splits
287 | window_size_w = w // attn_num_splits
288 |
289 | # compute attn mask once
290 | shifted_window_attn_mask = generate_shift_window_attn_mask(
291 | input_resolution=(h, w),
292 | window_size_h=window_size_h,
293 | window_size_w=window_size_w,
294 | shift_size_h=window_size_h // 2,
295 | shift_size_w=window_size_w // 2,
296 | device=feature0.device,
297 | ) # [K*K, H/K*W/K, H/K*W/K]
298 | else:
299 | shifted_window_attn_mask = None
300 |
301 | # concat feature0 and feature1 in batch dimension to compute in parallel
302 | concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
303 | concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
304 |
305 | for layer in self.layers:
306 | concat0 = layer(concat0, concat1,
307 | height=h,
308 | width=w,
309 | shifted_window_attn_mask=shifted_window_attn_mask,
310 | attn_num_splits=attn_num_splits,
311 | )
312 |
313 | # update feature1
314 | concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
315 |
316 | feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
317 |
318 | # reshape back
319 | feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
320 | feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
321 |
322 | return feature0, feature1
323 |
324 |
325 | class FeatureFlowAttention(nn.Module):
326 | """
327 | flow propagation with self-attention on feature
328 | query: feature0, key: feature0, value: flow
329 | """
330 |
331 | def __init__(self, in_channels,
332 | **kwargs,
333 | ):
334 | super(FeatureFlowAttention, self).__init__()
335 |
336 | self.q_proj = nn.Linear(in_channels, in_channels)
337 | self.k_proj = nn.Linear(in_channels, in_channels)
338 |
339 | for p in self.parameters():
340 | if p.dim() > 1:
341 | nn.init.xavier_uniform_(p)
342 |
343 | def forward(self, feature0, flow,
344 | local_window_attn=False,
345 | local_window_radius=1,
346 | **kwargs,
347 | ):
348 | # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
349 | if local_window_attn:
350 | return self.forward_local_window_attn(feature0, flow,
351 | local_window_radius=local_window_radius)
352 |
353 | b, c, h, w = feature0.size()
354 |
355 | query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
356 |
357 |         # a note: the "correct" implementation would project query and key from the same
358 |         # unprojected feature, i.e. query = self.q_proj(x), key = self.k_proj(x)
359 |         # this issue was noticed while cleaning up the code
360 |         # however, it does not affect performance: both projections are linear,
361 |         # so the two matrices applied to key can be merged into a single one
362 |         # it is therefore left as is to avoid re-training all models :)
363 | query = self.q_proj(query) # [B, H*W, C]
364 | key = self.k_proj(query) # [B, H*W, C]
365 |
366 | value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
367 |
368 | scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
369 | prob = torch.softmax(scores, dim=-1)
370 |
371 | out = torch.matmul(prob, value) # [B, H*W, 2]
372 | out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
373 |
374 | return out
375 |
376 | def forward_local_window_attn(self, feature0, flow,
377 | local_window_radius=1,
378 | ):
379 | assert flow.size(1) == 2
380 | assert local_window_radius > 0
381 |
382 | b, c, h, w = feature0.size()
383 |
384 | feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
385 | ).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
386 |
387 | kernel_size = 2 * local_window_radius + 1
388 |
389 | feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
390 |
391 | feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
392 |                                    padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]
393 |
394 | feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
395 | 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
396 |
397 | flow_window = F.unfold(flow, kernel_size=kernel_size,
398 |                                padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]
399 |
400 | flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
401 | 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2]
402 |
403 | scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
404 |
405 | prob = torch.softmax(scores, dim=-1)
406 |
407 | out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
408 |
409 | return out
--------------------------------------------------------------------------------
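A minimal shape sketch for the FeatureTransformer defined above, with attn_num_splits=2 so the Swin-style shifted-window path and its precomputed attention mask are exercised. Sizes and the import path are illustrative assumptions:

    import torch
    from third_party.gmflow.models.transformer import FeatureTransformer  # assumed import path

    net = FeatureTransformer(num_layers=6, d_model=128, nhead=1, attention_type='swin')
    f0 = torch.randn(1, 128, 32, 32)             # [B, C, H, W] features of frame 0
    f1 = torch.randn(1, 128, 32, 32)             # [B, C, H, W] features of frame 1
    out0, out1 = net(f0, f1, attn_num_splits=2)  # cross-frame updated features
    assert out0.shape == f0.shape and out1.shape == f1.shape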
/third_party/dover/models/conv_backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import trunc_normal_, DropPath
5 |
6 |
7 | class GRN(nn.Module):
8 | """ GRN (Global Response Normalization) layer
9 | """
10 | def __init__(self, dim):
11 | super().__init__()
12 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
13 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
14 |
15 | def forward(self, x):
16 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
17 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
18 | return self.gamma * (x * Nx) + self.beta + x
19 |
20 | class Block(nn.Module):
21 | r""" ConvNeXt Block. There are two equivalent implementations:
22 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
23 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
24 | We use (2) as we find it slightly faster in PyTorch
25 |
26 | Args:
27 | dim (int): Number of input channels.
28 | drop_path (float): Stochastic depth rate. Default: 0.0
29 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
30 | """
31 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
32 | super().__init__()
33 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
34 | self.norm = LayerNorm(dim, eps=1e-6)
35 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
36 | self.act = nn.GELU()
37 | self.pwconv2 = nn.Linear(4 * dim, dim)
38 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
39 | requires_grad=True) if layer_scale_init_value > 0 else None
40 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
41 |
42 | def forward(self, x):
43 | input = x
44 | x = self.dwconv(x)
45 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
46 | x = self.norm(x)
47 | x = self.pwconv1(x)
48 | x = self.act(x)
49 | x = self.pwconv2(x)
50 | if self.gamma is not None:
51 | x = self.gamma * x
52 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
53 |
54 | x = input + self.drop_path(x)
55 | return x
56 |
57 | class ConvNeXt(nn.Module):
58 | r""" ConvNeXt
59 | A PyTorch impl of : `A ConvNet for the 2020s` -
60 | https://arxiv.org/pdf/2201.03545.pdf
61 | Args:
62 | in_chans (int): Number of input image channels. Default: 3
63 | num_classes (int): Number of classes for classification head. Default: 1000
64 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
65 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
66 | drop_path_rate (float): Stochastic depth rate. Default: 0.
67 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
68 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
69 | """
70 | def __init__(self, in_chans=3, num_classes=1000,
71 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
72 | layer_scale_init_value=1e-6, head_init_scale=1.,
73 | ):
74 | super().__init__()
75 |
76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
77 | stem = nn.Sequential(
78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
80 | )
81 | self.downsample_layers.append(stem)
82 | for i in range(3):
83 | downsample_layer = nn.Sequential(
84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
86 | )
87 | self.downsample_layers.append(downsample_layer)
88 |
89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
90 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
91 | cur = 0
92 | for i in range(4):
93 | stage = nn.Sequential(
94 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
95 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
96 | )
97 | self.stages.append(stage)
98 | cur += depths[i]
99 |
100 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
101 | self.head = nn.Linear(dims[-1], num_classes)
102 |
103 | self.apply(self._init_weights)
104 | self.head.weight.data.mul_(head_init_scale)
105 | self.head.bias.data.mul_(head_init_scale)
106 |
107 | def _init_weights(self, m):
108 | if isinstance(m, (nn.Conv2d, nn.Linear)):
109 | trunc_normal_(m.weight, std=.02)
110 | nn.init.constant_(m.bias, 0)
111 |
112 | def forward_features(self, x):
113 | for i in range(4):
114 | x = self.downsample_layers[i](x)
115 | x = self.stages[i](x)
116 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
117 |
118 | def forward(self, x):
119 | x = self.forward_features(x)
120 | x = self.head(x)
121 | return x
122 |
123 | class LayerNorm(nn.Module):
124 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
125 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
126 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
127 | with shape (batch_size, channels, height, width).
128 | """
129 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
130 | super().__init__()
131 | self.weight = nn.Parameter(torch.ones(normalized_shape))
132 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
133 | self.eps = eps
134 | self.data_format = data_format
135 | if self.data_format not in ["channels_last", "channels_first"]:
136 | raise NotImplementedError
137 | self.normalized_shape = (normalized_shape, )
138 |
139 | def forward(self, x):
140 | if self.data_format == "channels_last":
141 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
142 | elif self.data_format == "channels_first":
143 | u = x.mean(1, keepdim=True)
144 | s = (x - u).pow(2).mean(1, keepdim=True)
145 | x = (x - u) / torch.sqrt(s + self.eps)
146 | if len(x.shape) == 4:
147 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
148 | elif len(x.shape) == 5:
149 | x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
150 | return x
151 |
152 |
153 | class Block3D(nn.Module):
154 | r""" ConvNeXt Block. There are two equivalent implementations:
155 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
156 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
157 | We use (2) as we find it slightly faster in PyTorch
158 |
159 | Args:
160 | dim (int): Number of input channels.
161 | drop_path (float): Stochastic depth rate. Default: 0.0
162 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
163 | """
164 | def __init__(self, dim, drop_path=0., inflate_len=3, layer_scale_init_value=1e-6):
165 | super().__init__()
166 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv
167 | self.norm = LayerNorm(dim, eps=1e-6)
168 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
169 | self.act = nn.GELU()
170 | self.pwconv2 = nn.Linear(4 * dim, dim)
171 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
172 | requires_grad=True) if layer_scale_init_value > 0 else None
173 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
174 |
175 | def forward(self, x):
176 | input = x
177 | x = self.dwconv(x)
178 |         x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C)
179 | x = self.norm(x)
180 | x = self.pwconv1(x)
181 | x = self.act(x)
182 | x = self.pwconv2(x)
183 | if self.gamma is not None:
184 | x = self.gamma * x
185 |         x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W)
186 |
187 | x = input + self.drop_path(x)
188 | return x
189 |
190 | class BlockV2(nn.Module):
191 | """ ConvNeXtV2 Block.
192 |
193 | Args:
194 | dim (int): Number of input channels.
195 | drop_path (float): Stochastic depth rate. Default: 0.0
196 | """
197 | def __init__(self, dim, drop_path=0.):
198 | super().__init__()
199 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
200 | self.norm = LayerNorm(dim, eps=1e-6)
201 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
202 | self.act = nn.GELU()
203 | self.grn = GRN(4 * dim)
204 | self.pwconv2 = nn.Linear(4 * dim, dim)
205 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
206 |
207 | def forward(self, x):
208 | input = x
209 | x = self.dwconv(x)
210 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
211 | x = self.norm(x)
212 | x = self.pwconv1(x)
213 | x = self.act(x)
214 | x = self.grn(x)
215 | x = self.pwconv2(x)
216 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
217 |
218 | x = input + self.drop_path(x)
219 | return x
220 |
221 | class BlockV23D(nn.Module):
222 | """ ConvNeXtV2 Block.
223 |
224 | Args:
225 | dim (int): Number of input channels.
226 | drop_path (float): Stochastic depth rate. Default: 0.0
227 | """
228 | def __init__(self, dim, drop_path=0., inflate_len=3,):
229 | super().__init__()
230 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv
231 | self.norm = LayerNorm(dim, eps=1e-6)
232 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
233 | self.act = nn.GELU()
234 | self.grn = GRN(4 * dim)
235 | self.pwconv2 = nn.Linear(4 * dim, dim)
236 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
237 |
238 | def forward(self, x):
239 | input = x
240 | x = self.dwconv(x)
241 |         x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C)
242 | x = self.norm(x)
243 | x = self.pwconv1(x)
244 | x = self.act(x)
245 | x = self.grn(x)
246 | x = self.pwconv2(x)
247 |         x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W)
248 |
249 | x = input + self.drop_path(x)
250 | return x
251 |
252 | class ConvNeXtV2(nn.Module):
253 | """ ConvNeXt V2
254 |
255 | Args:
256 | in_chans (int): Number of input image channels. Default: 3
257 | num_classes (int): Number of classes for classification head. Default: 1000
258 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
259 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
260 | drop_path_rate (float): Stochastic depth rate. Default: 0.
261 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
262 | """
263 | def __init__(self, in_chans=3, num_classes=1000,
264 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
265 | drop_path_rate=0., head_init_scale=1.
266 | ):
267 | super().__init__()
268 | self.depths = depths
269 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
270 | stem = nn.Sequential(
271 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
272 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
273 | )
274 | self.downsample_layers.append(stem)
275 | for i in range(3):
276 | downsample_layer = nn.Sequential(
277 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
278 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
279 | )
280 | self.downsample_layers.append(downsample_layer)
281 |
282 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
283 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
284 | cur = 0
285 | for i in range(4):
286 | stage = nn.Sequential(
287 | *[BlockV2(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
288 | )
289 | self.stages.append(stage)
290 | cur += depths[i]
291 |
292 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
293 | self.head = nn.Linear(dims[-1], num_classes)
294 |
295 | self.apply(self._init_weights)
296 | self.head.weight.data.mul_(head_init_scale)
297 | self.head.bias.data.mul_(head_init_scale)
298 |
299 | def _init_weights(self, m):
300 | if isinstance(m, (nn.Conv2d, nn.Linear)):
301 | trunc_normal_(m.weight, std=.02)
302 | nn.init.constant_(m.bias, 0)
303 |
304 | def forward_features(self, x):
305 | for i in range(4):
306 | x = self.downsample_layers[i](x)
307 | x = self.stages[i](x)
308 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
309 |
310 | def forward(self, x):
311 | x = self.forward_features(x)
312 | x = self.head(x)
313 | return x
314 |
315 | def convnextv2_atto(**kwargs):
316 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
317 | return model
318 |
319 | def convnextv2_femto(**kwargs):
320 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
321 | return model
322 |
323 | def convnext_pico(**kwargs):
324 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
325 | return model
326 |
327 | def convnextv2_nano(**kwargs):
328 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
329 | return model
330 |
331 | def convnextv2_tiny(**kwargs):
332 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
333 | return model
334 |
335 | def convnextv2_base(**kwargs):
336 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
337 | return model
338 |
339 | def convnextv2_large(**kwargs):
340 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
341 | return model
342 |
343 | def convnextv2_huge(**kwargs):
344 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
345 | return model
346 |
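    | # Editor's note: a minimal usage sketch, assuming only torch is installed; the tensor
    | # shapes and num_classes below are illustrative and not values used anywhere in DOVER.
    | def _example_convnextv2_image_forward():  # hypothetical helper, not called by the library
    |     import torch
    |     model = convnextv2_atto(num_classes=10)
    |     x = torch.randn(2, 3, 224, 224)   # (N, C, H, W) image batch
    |     logits = model(x)                 # -> (2, 10) classification logits
    |     return logits
    |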
347 | class ConvNeXt3D(nn.Module):
348 |     r""" ConvNeXt inflated to 3D for video (N, C, T, H, W) input; forward() returns spatial features.
349 |     Based on `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
350 |     Args:
351 |         in_chans (int): Number of input channels. Default: 3
352 |         num_classes (int): Kept for API compatibility; no classification head is built. Default: 1000
353 |         inflate_strategy (str): Per-block temporal inflation lengths, cycled over the blocks of each stage. Default: '131'
354 |         depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
355 |         dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
356 |         drop_path_rate (float): Stochastic depth rate. Default: 0.
357 |         layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
358 |         head_init_scale (float): Unused; kept for API compatibility. Default: 1.
359 |     """
360 | def __init__(self, in_chans=3, num_classes=1000,
361 | inflate_strategy='131',
362 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
363 | layer_scale_init_value=1e-6, head_init_scale=1.,
364 | ):
365 | super().__init__()
366 |
367 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
368 | stem = nn.Sequential(
369 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)),
370 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
371 | )
372 | self.downsample_layers.append(stem)
373 | for i in range(3):
374 | downsample_layer = nn.Sequential(
375 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
376 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)),
377 | )
378 | self.downsample_layers.append(downsample_layer)
379 |
380 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
381 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
382 | cur = 0
383 | for i in range(4):
384 | stage = nn.Sequential(
385 | *[Block3D(dim=dims[i], inflate_len=int(inflate_strategy[j%len(inflate_strategy)]),
386 | drop_path=dp_rates[cur + j],
387 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
388 | )
389 | self.stages.append(stage)
390 | cur += depths[i]
391 |
392 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
393 |
394 | self.apply(self._init_weights)
395 |
396 |     def inflate_weights(self, s_state_dict):
397 |         """Inflate a 2D ConvNeXt state dict into this 3D model (I3D-style inflation)."""
398 |         t_state_dict = self.state_dict()
399 |         for key in t_state_dict.keys():
400 |             if key not in s_state_dict:
401 |                 continue  # keys absent from the 2D checkpoint keep their random init
402 |             if t_state_dict[key].shape != s_state_dict[key].shape:
403 |                 # repeat the 2D kernel over time, dividing by t to preserve the response magnitude
404 |                 t = t_state_dict[key].shape[2]
405 |                 s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1, 1, t, 1, 1) / t
406 |         self.load_state_dict(s_state_dict, strict=False)
407 |
408 | def _init_weights(self, m):
409 | if isinstance(m, (nn.Conv3d, nn.Linear)):
410 | trunc_normal_(m.weight, std=.02)
411 | nn.init.constant_(m.bias, 0)
412 |
413 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1):
414 | if multi:
415 | xs = []
416 | for i in range(4):
417 | x = self.downsample_layers[i](x)
418 | x = self.stages[i](x)
419 | if multi:
420 | xs.append(x)
421 | if return_spatial:
422 | if multi:
423 | shape = xs[-1].shape[2:]
424 |                 return torch.cat([F.interpolate(x, size=shape, mode="trilinear") for x in xs[:-1]], 1)
425 |             elif layer > -1:
426 |                 return xs[layer]  # note: requires multi=True so that xs is populated
427 | else:
428 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
429 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C)
430 |
431 | def forward(self, x, multi=False, layer=-1):
432 | x = self.forward_features(x, True, multi=multi, layer=layer)
433 | return x
434 |
435 |
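    | # Editor's note: a hedged sketch of driving ConvNeXt3D (shapes are illustrative). Unlike
    | # the 2D classes above, forward() returns a spatial feature map rather than logits,
    | # because it calls forward_features with return_spatial=True and no head is built.
    | def _example_convnext3d_features():  # hypothetical helper, not called by the library
    |     import torch
    |     model = convnext_3d_tiny(pretrained=False)   # factory defined later in this module
    |     clip = torch.randn(1, 3, 16, 224, 224)       # (N, C, T, H, W) video clip
    |     feats = model(clip)                          # -> roughly (1, 768, 8, 7, 7)
    |     return feats
    |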
436 | class ConvNeXtV23D(nn.Module):
437 |     """ ConvNeXt V2 inflated to 3D for video (N, C, T, H, W) input; forward() returns spatial features.
438 |     Args:
439 |         in_chans (int): Number of input channels. Default: 3
440 |         num_classes (int): Number of classes for classification head. Default: 1000
441 |         inflate_strategy (str): Per-block temporal inflation lengths, cycled over the blocks of each stage. Default: '131'
442 |         depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
443 |         dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
444 |         drop_path_rate (float): Stochastic depth rate. Default: 0.
445 |         head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
446 |     """
447 | def __init__(self, in_chans=3, num_classes=1000,
448 | inflate_strategy='131',
449 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
450 | drop_path_rate=0., head_init_scale=1.
451 | ):
452 | super().__init__()
453 | self.depths = depths
454 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
455 | stem = nn.Sequential(
456 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)),
457 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
458 | )
459 | self.downsample_layers.append(stem)
460 | for i in range(3):
461 | downsample_layer = nn.Sequential(
462 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
463 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)),
464 | )
465 | self.downsample_layers.append(downsample_layer)
466 |
467 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
468 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
469 | cur = 0
470 | for i in range(4):
471 | stage = nn.Sequential(
472 | *[BlockV23D(dim=dims[i], drop_path=dp_rates[cur + j],
473 | inflate_len=int(inflate_strategy[j%len(inflate_strategy)]),
474 | ) for j in range(depths[i])]
475 | )
476 | self.stages.append(stage)
477 | cur += depths[i]
478 |
479 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
480 | self.head = nn.Linear(dims[-1], num_classes)
481 |
482 | self.apply(self._init_weights)
483 | self.head.weight.data.mul_(head_init_scale)
484 | self.head.bias.data.mul_(head_init_scale)
485 |
486 |     def inflate_weights(self, pretrained_path):
487 |         """Load a 2D ConvNeXt-V2 checkpoint from disk and inflate it into this 3D model
488 |         (I3D-style inflation; same rule as ConvNeXt3D.inflate_weights)."""
489 |         t_state_dict = self.state_dict()
490 |         s_state_dict = torch.load(pretrained_path, map_location="cpu")["model"]
491 |         for key in t_state_dict.keys():
492 |             if key not in s_state_dict:
493 |                 continue  # keys absent from the 2D checkpoint keep their random init
494 |             if t_state_dict[key].shape != s_state_dict[key].shape:
495 |                 # repeat the 2D kernel over time, dividing by t to preserve the response magnitude
496 |                 t = t_state_dict[key].shape[2]
497 |                 s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1, 1, t, 1, 1) / t
498 |         self.load_state_dict(s_state_dict, strict=False)
499 |
500 | def _init_weights(self, m):
501 | if isinstance(m, (nn.Conv3d, nn.Linear)):
502 | trunc_normal_(m.weight, std=.02)
503 | nn.init.constant_(m.bias, 0)
504 |
505 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1):
506 | if multi:
507 | xs = []
508 | for i in range(4):
509 | x = self.downsample_layers[i](x)
510 | x = self.stages[i](x)
511 | if multi:
512 | xs.append(x)
513 | if return_spatial:
514 | if multi:
515 | shape = xs[-1].shape[2:]
516 |                 return torch.cat([F.interpolate(x, size=shape, mode="trilinear") for x in xs[:-1]], 1)
517 |             elif layer > -1:
518 |                 return xs[layer]  # note: requires multi=True so that xs is populated
519 | else:
520 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
521 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C)
522 |
523 | def forward(self, x, multi=False, layer=-1):
524 | x = self.forward_features(x, True, multi=multi, layer=layer)
525 | return x
526 |
527 |
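    | # Editor's note: a small self-contained check (assuming only torch) of why inflate_weights
    | # divides the time-repeated 2D kernel by t -- for input that is constant over time, the
    | # inflated 3D convolution reproduces the original 2D response.
    | def _inflation_sanity_check():  # hypothetical helper, not called by the library
    |     import torch
    |     import torch.nn.functional as F
    |     w2d = torch.randn(8, 3, 4, 4)                     # (out, in, kh, kw) 2D kernel
    |     t = 2
    |     w3d = w2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t  # same rule as inflate_weights
    |     img = torch.randn(1, 3, 8, 8)
    |     clip = img.unsqueeze(2).repeat(1, 1, t, 1, 1)     # temporally constant clip
    |     y2d = F.conv2d(img, w2d, stride=4)
    |     y3d = F.conv3d(clip, w3d, stride=(t, 4, 4))
    |     assert torch.allclose(y2d, y3d.squeeze(2), atol=1e-5)
    |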
528 | model_urls = {
529 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
530 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
531 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
532 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
533 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
534 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
535 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
536 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
537 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
538 | }
539 |
540 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
541 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
542 | if pretrained:
543 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
544 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
545 | model.load_state_dict(checkpoint["model"])
546 | return model
547 |
548 | def convnext_small(pretrained=False, in_22k=False, **kwargs):
549 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
550 | if pretrained:
551 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
552 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
553 | model.load_state_dict(checkpoint["model"])
554 | return model
555 |
556 | def convnext_base(pretrained=False, in_22k=False, **kwargs):
557 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
558 | if pretrained:
559 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
560 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
561 | model.load_state_dict(checkpoint["model"])
562 | return model
563 |
564 |
565 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
566 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
567 | if pretrained:
568 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
569 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
570 | model.load_state_dict(checkpoint["model"])
571 | return model
572 |
573 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
574 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
575 | if pretrained:
576 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
577 | url = model_urls['convnext_xlarge_22k']
578 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
579 | model.load_state_dict(checkpoint["model"])
580 |
581 | return model
582 |
583 | def convnext_3d_tiny(pretrained=False, in_22k=False, **kwargs):
584 | # print("Using Imagenet 22K pretrain", in_22k)
585 | model = ConvNeXt3D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
586 | if pretrained:
587 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
588 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
589 | model.inflate_weights(checkpoint["model"])
590 | return model
591 |
592 | def convnext_3d_small(pretrained=False, in_22k=False, **kwargs):
593 | model = ConvNeXt3D(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
594 | if pretrained:
595 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
596 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
597 | model.inflate_weights(checkpoint["model"])
598 |
599 | return model
600 |
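    | # Editor's note: a hedged sketch of the pretrained path. torch.hub downloads the 2D
    | # checkpoint (network access required), inflate_weights spreads it over the temporal
    | # axis, and forward_features(..., return_spatial=False) gives a pooled clip embedding.
    | def _example_convnext3d_pretrained_embedding():  # hypothetical helper, not called by the library
    |     import torch
    |     model = convnext_3d_tiny(pretrained=True)
    |     clip = torch.randn(1, 3, 16, 224, 224)
    |     pooled = model.forward_features(clip, return_spatial=False)  # -> (1, 768)
    |     return pooled
    |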
601 | def convnextv2_3d_atto(**kwargs):
602 | model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
603 |
604 | return model
605 |
606 | def convnextv2_3d_femto(pretrained="../pretrained/convnextv2_femto_1k_224_ema.pt", **kwargs):
607 | model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
608 | #model.inflate_weights(pretrained)
609 | return model
610 |
611 | def convnextv2_3d_pico(pretrained="../pretrained/convnextv2_pico_1k_224_ema.pt", **kwargs):
612 | model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
613 | #model.inflate_weights(pretrained)
614 | return model
615 |
616 | def convnextv2_3d_nano(pretrained="../pretrained/convnextv2_nano_1k_224_ema.pt", **kwargs):
617 | model = ConvNeXtV23D(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
618 | #model.inflate_weights(pretrained)
619 | return model
620 |
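    | # Editor's note: the inflate_weights calls in the V2 3D factories above are commented out
    | # upstream; this hedged sketch shows how a locally downloaded 2D ConvNeXt-V2 checkpoint
    | # could be inflated manually. The path is hypothetical -- point it at a real .pt file that
    | # stores its weights under a "model" key, as the official releases do.
    | def _example_inflate_convnextv2_3d(ckpt_path="checkpoints/convnextv2_nano_1k_224_ema.pt"):
    |     model = convnextv2_3d_nano()
    |     model.inflate_weights(ckpt_path)  # loads the file and repeats 2D kernels over time
    |     return model
    |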
621 | def convnextv2_3d_tiny(**kwargs):
622 |     model = ConvNeXtV23D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
623 |     return model
624 |
625 | def convnextv2_3d_base(**kwargs):
626 |     model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
627 |     return model
628 |
629 | def convnextv2_3d_large(**kwargs):
630 |     model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
631 |     return model
632 |
633 | def convnextv2_3d_huge(**kwargs):
634 |     model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
635 |     return model
--------------------------------------------------------------------------------