├── checkpoints └── .keep ├── third_party ├── gmflow │ ├── models │ │ ├── __init__.py │ │ ├── position.py │ │ ├── utils.py │ │ ├── geometry.py │ │ ├── trident_conv.py │ │ ├── matching.py │ │ ├── backbone.py │ │ ├── gmflow.py │ │ └── transformer.py │ ├── config.yaml │ └── utils.py ├── dover │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── head.py │ │ ├── evaluator.py │ │ └── conv_backbone.py │ ├── config.yaml │ └── dataset.py └── viclip │ ├── __init__.py │ ├── image_transform.py │ ├── simple_tokenizer.py │ ├── viclip.py │ ├── viclip_text.py │ └── viclip_vision.py ├── asset └── taxonomy-repo.png ├── requirements.txt ├── doc ├── install.sh ├── leaderboard.md └── README.md ├── LICENSE ├── v2vbench ├── dover_score.py ├── clip_consistency.py ├── dino_consistency.py ├── dover_utils.py ├── aesthetic_score.py ├── clip_text_alignment.py ├── pick_score.py ├── viclip_text_alignment.py ├── dino_image_alignment.py ├── utils.py ├── base_evaluator.py ├── motion_alignment.py └── __init__.py └── .gitignore /checkpoints/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/gmflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gmflow import GMFlow -------------------------------------------------------------------------------- /asset/taxonomy-repo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenhao728/awesome-diffusion-v2v/HEAD/asset/taxonomy-repo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord 2 | torch~=1.13 3 | torchvision 4 | transformers>=4.32.1 5 | scipy 6 | numpy 7 | tqdm 8 | timm 9 | einops 10 | pandas 11 | simple-aesthetics-predictor -------------------------------------------------------------------------------- /third_party/dover/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 20:47:36 5 | @Desc : 6 | @Ref : https://github.com/VQAssessment/DOVER 7 | ''' 8 | -------------------------------------------------------------------------------- /third_party/gmflow/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | num_scales: 1 3 | upsample_factor: 8 4 | feature_channels: 128 5 | attention_type: swin 6 | num_transformer_layers: 6 7 | ffn_dim_expansion: 4 8 | num_head: 1 9 | 10 | data: 11 | dims: 12 | - 3 13 | - 512 14 | - 512 -------------------------------------------------------------------------------- /third_party/viclip/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 20:39:45 5 | @Desc : 6 | @Ref : 7 | https://github.com/OpenGVLab/InternVideo/tree/main 8 | https://github.com/Vchitect/VBench/tree/master 9 | ''' 10 | from .image_transform import clip_image_transform 11 | from .simple_tokenizer import SimpleTokenizer 12 | from .viclip import ViCLIP 13 | -------------------------------------------------------------------------------- /third_party/viclip/image_transform.py: -------------------------------------------------------------------------------- 1 | 2 | from 
torchvision.transforms import ( 3 | CenterCrop, 4 | Compose, 5 | InterpolationMode, 6 | Normalize, 7 | Resize, 8 | ToTensor, 9 | ) 10 | 11 | 12 | def clip_image_transform(n_px=224): 13 | return Compose([ 14 | Resize(n_px, interpolation=InterpolationMode.BICUBIC), 15 | CenterCrop(n_px), 16 | ToTensor(), 17 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 18 | ]) -------------------------------------------------------------------------------- /third_party/dover/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_backbone import convnext_3d_small, convnext_3d_tiny 2 | from .evaluator import DOVER, BaseEvaluator, BaseImageEvaluator 3 | from .head import IQAHead, VARHead, VQAHead 4 | from .swin_backbone import SwinTransformer2D as IQABackbone 5 | from .swin_backbone import SwinTransformer3D as VQABackbone 6 | from .swin_backbone import swin_3d_small, swin_3d_tiny 7 | 8 | __all__ = [ 9 | "VQABackbone", 10 | "IQABackbone", 11 | "VQAHead", 12 | "IQAHead", 13 | "VARHead", 14 | "BaseEvaluator", 15 | "BaseImageEvaluator", 16 | "DOVER", 17 | ] 18 | -------------------------------------------------------------------------------- /third_party/dover/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | technical: 3 | fragments_h: 7 4 | fragments_w: 7 5 | fsize_h: 32 6 | fsize_w: 32 7 | aligned: 32 8 | clip_len: 32 9 | frame_interval: 2 10 | num_clips: 3 11 | aesthetic: 12 | size_h: 224 13 | size_w: 224 14 | clip_len: 32 15 | frame_interval: 2 16 | t_frag: 32 17 | num_clips: 1 18 | 19 | model: 20 | backbone: 21 | technical: 22 | type: swin_tiny_grpb 23 | checkpoint: true 24 | pretrained: 25 | aesthetic: 26 | type: conv_tiny 27 | backbone_preserve_keys: technical,aesthetic 28 | divide_head: true 29 | vqa_head: 30 | in_channels: 768 31 | hidden_channels: 64 -------------------------------------------------------------------------------- /doc/install.sh: -------------------------------------------------------------------------------- 1 | # DOVER 2 | # install git lfs to pull the checkpoints from huggingface 3 | git lfs install 4 | git clone https://huggingface.co/teowu/DOVER checkpoints/ 5 | 6 | # ViCLIP 7 | # tokenizers 8 | wget https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz \ 9 | -P checkpoints/ViCLIP 10 | # model weights 11 | wget https://huggingface.co/OpenGVLab/VBench_Used_Models/blob/main/ViClip-InternVid-10M-FLT.pth \ 12 | -P checkpoints/ViCLIP 13 | 14 | # GMFlow 15 | # download the pretrained model from google drive 16 | gdown 1d5C5cgHIxWGsFR1vYs5XrQbbUiZl9TX2 -O checkpoints/ 17 | # unzip the model and move it to the correct directory 18 | unzip -n checkpoints/pretrained.zip -d checkpoints/ 19 | mv checkpoints/pretrained checkpoints/gmflow 20 | rm checkpoints/pretrained.zip -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Wenhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, 
subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/gmflow/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | 10 | def __init__(self, dims, mode='sintel', padding_factor=8): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor 13 | pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor 14 | if mode == 'sintel': 15 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 16 | else: 17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 18 | 19 | def pad(self, *inputs): 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self, x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 25 | return x[..., c[0]:c[1], c[2]:c[3]] 26 | 27 | 28 | def write_flow(flow: torch.Tensor, filename: Path) -> None: 29 | """Write optical flow to file 30 | 31 | Args: 32 | flow (torch.Tensor): Shape (T, 2, H, W) 33 | filename (Path): File to write optical flow 34 | """ 35 | filename.parent.mkdir(parents=True, exist_ok=True) 36 | flow = flow.cpu() 37 | torch.save(flow, filename) -------------------------------------------------------------------------------- /third_party/gmflow/models/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 
13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos -------------------------------------------------------------------------------- /v2vbench/dover_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 16:31:28 5 | @Desc : 6 | @Ref : https://github.com/VQAssessment/DOVER/tree/master 7 | ''' 8 | import logging 9 | from pathlib import Path 10 | 11 | import torch 12 | from omegaconf import OmegaConf 13 | 14 | from third_party.dover.models import DOVER 15 | 16 | from .base_evaluator import BaseEvaluator 17 | from .dover_utils import DoverPreprocessor, fuse_results 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class DoverScore(BaseEvaluator): 22 | pretrained_config_file = (Path(__file__) / '../../third_party/dover/config.yaml').resolve().absolute() 23 | pretrained_checkpoint = (Path(__file__) / '../../checkpoints/DOVER/DOVER.pth').resolve().absolute() 24 | 25 | def __init__( 26 | self, 27 | index_file: Path, 28 | edit_video_dir: Path, 29 | reference_video_dir: Path, 30 | edit_prompt: str, 31 | device: torch.device, 32 | pretrained_config_file: str = None, 33 | pretrained_checkpoint: str = None, 34 | ): 35 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 36 | 37 | pretrained_config_file = pretrained_config_file or self.pretrained_config_file 38 | pretrained_checkpoint = pretrained_checkpoint or self.pretrained_checkpoint 39 | 40 | logger.debug(f"Loding model {pretrained_checkpoint}") 41 | config = OmegaConf.to_container(OmegaConf.load(pretrained_config_file)) 42 | self.preprocessor = DoverPreprocessor(config["data"]) 43 | self.model = DOVER(**config['model']) 44 | self.model.to(self.device) 45 | self.model.load_state_dict(torch.load(pretrained_checkpoint, map_location=self.device)) 46 | self.model.eval() 47 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 48 | 49 | def range(self): 50 | return 0, 1 51 | 52 | def preprocess(self, sample): 53 | video = sample['edit_video'] 54 | 55 | views = self.preprocessor(video) 56 | for 
k, v in views.items(): 57 | views[k] = v.to(self.device) 58 | return views 59 | 60 | @torch.no_grad() 61 | def evaluate(self, views) -> float: 62 | results = [r.mean().item() for r in self.model(views)] 63 | # score 64 | scores = fuse_results(results) 65 | return scores -------------------------------------------------------------------------------- /v2vbench/clip_consistency.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/07 16:06:35 5 | @Desc : 6 | @Ref : 7 | https://github.com/openai/CLIP 8 | https://github.com/mlfoundations/open_clip 9 | https://huggingface.co/docs/transformers/model_doc/clip#clip 10 | ''' 11 | import logging 12 | from pathlib import Path 13 | 14 | import torch 15 | from transformers import CLIPImageProcessor, CLIPVisionModel 16 | 17 | from .base_evaluator import BaseEvaluator 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class ClipConsistency(BaseEvaluator): 22 | pretrained_model_name = 'openai/clip-vit-large-patch14' 23 | 24 | def __init__( 25 | self, 26 | index_file: Path, 27 | edit_video_dir: Path, 28 | reference_video_dir: Path, 29 | edit_prompt: str, 30 | device: torch.device, 31 | pretrained_model_name: str = None, 32 | ): 33 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name 35 | logger.debug(f"Loding model {pretrained_model_name}") 36 | self.preprocessor = CLIPImageProcessor.from_pretrained(pretrained_model_name) 37 | self.model = CLIPVisionModel.from_pretrained(pretrained_model_name) 38 | self.model.to(self.device) 39 | self.model.eval() 40 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 41 | 42 | def range(self): 43 | return 0, 1 44 | 45 | def preprocess(self, sample): 46 | video = sample['edit_video'] 47 | frames = [] 48 | for i, frame in enumerate(video): 49 | frames.append(self.preprocessor(frame, return_tensors='pt').pixel_values) 50 | return frames 51 | 52 | @torch.no_grad() 53 | def evaluate(self, frames) -> float: 54 | similarity = [] 55 | former_feature = None 56 | for i, frame in enumerate(frames): 57 | frame = frame.to(self.device) 58 | feature: torch.Tensor = self.model(pixel_values=frame).pooler_output 59 | feature = feature / torch.norm(feature, dim=-1, keepdim=True) 60 | 61 | if i > 0: 62 | sim = max(0, (feature @ former_feature.T).cpu().squeeze().item()) 63 | similarity.append(sim) 64 | former_feature = feature 65 | return sum(similarity) / len(similarity) -------------------------------------------------------------------------------- /v2vbench/dino_consistency.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/07 16:06:35 5 | @Desc : 6 | @Ref : 7 | https://github.com/facebookresearch/dinov2 8 | https://huggingface.co/docs/transformers/model_doc/dinov2#dinov2 9 | ''' 10 | import logging 11 | from pathlib import Path 12 | 13 | import torch 14 | from transformers import BitImageProcessor, Dinov2Model 15 | 16 | from .base_evaluator import BaseEvaluator 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | class DinoConsistency(BaseEvaluator): 21 | pretrained_model_name = 'facebook/dinov2-base' 22 | 23 | def __init__( 24 | self, 25 | index_file: Path, 26 | edit_video_dir: Path, 27 | reference_video_dir: Path, 28 | edit_prompt: str, 29 | device: torch.device, 30 | 
pretrained_model_name: str = None, 31 | ): 32 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 33 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name 34 | logger.debug(f"Loding model {pretrained_model_name}") 35 | self.preprocessor = BitImageProcessor.from_pretrained(pretrained_model_name) 36 | self.model = Dinov2Model.from_pretrained(pretrained_model_name) 37 | self.model.to(self.device) 38 | self.model.eval() 39 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 40 | 41 | def range(self): 42 | return 0, 1 43 | 44 | def preprocess(self, sample): 45 | video = sample['edit_video'] 46 | frames = [] 47 | for i, frame in enumerate(video): 48 | frames.append(self.preprocessor(frame, return_tensors='pt').pixel_values) 49 | return frames 50 | 51 | @torch.no_grad() 52 | def evaluate(self, frames) -> float: 53 | similarity = [] 54 | former_feature = None 55 | for i, frame in enumerate(frames): 56 | frame = frame.to(self.device) 57 | # pooled output (first image token) 58 | feature = self.model(pixel_values=frame).pooler_output 59 | feature: torch.Tensor = feature / torch.norm(feature, dim=-1, keepdim=True) 60 | 61 | if i > 0: 62 | sim = max(0, (feature @ former_feature.T).cpu().squeeze().item()) 63 | similarity.append(sim) 64 | former_feature = feature 65 | return sum(similarity) / len(similarity) -------------------------------------------------------------------------------- /v2vbench/dover_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 16:38:16 5 | @Desc : 6 | @Ref : https://github.com/VQAssessment/DOVER/tree/master 7 | ''' 8 | import logging 9 | from typing import Dict, List 10 | 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | 15 | from third_party.dover.dataset import ( 16 | UnifiedFrameSampler, 17 | spatial_temporal_view_decomposition, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class DoverPreprocessor: 23 | mean = torch.FloatTensor([123.675, 116.28, 103.53]) 24 | std = torch.FloatTensor([58.395, 57.12, 57.375]) 25 | 26 | def __init__( 27 | self, 28 | sample_types: Dict[str, Dict[str, int]], 29 | ): 30 | self.sample_types = sample_types 31 | self.temporal_samplers = {} 32 | for stype, sopt in sample_types.items(): 33 | if "t_frag" not in sopt: 34 | # resized temporal sampling for TQE in DOVER 35 | self.temporal_samplers[stype] = UnifiedFrameSampler( 36 | sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"] 37 | ) 38 | else: 39 | # temporal sampling for AQE in DOVER 40 | self.temporal_samplers[stype] = UnifiedFrameSampler( 41 | sopt["clip_len"] // sopt["t_frag"], 42 | sopt["t_frag"], 43 | sopt["frame_interval"], 44 | sopt["num_clips"], 45 | ) 46 | 47 | def __call__(self, frames: List[Image.Image]): 48 | views, _ = spatial_temporal_view_decomposition( 49 | frames, self.sample_types, self.temporal_samplers 50 | ) 51 | 52 | for k, v in views.items(): 53 | v: torch.Tensor 54 | num_clips = self.sample_types[k].get("num_clips", 1) 55 | views[k] = ((v.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2).reshape( 56 | v.shape[0], num_clips, -1, *v.shape[2:]).transpose(0, 1) 57 | return views 58 | 59 | 60 | def fuse_results(results: list): 61 | logger.debug(f'Before fuse: {results}') 62 | means = [0.1107, 0.08285] 63 | stds = [0.07355, 0.03774] 64 | weights = [0.6104, 0.3896] 65 | x = (results[0] - means[0]) / stds[0] * weights[0] + (results[1] + 
means[1]) / stds[1] * weights[1] 66 | return 1 / (1 + np.exp(-x)) # sigmoid -------------------------------------------------------------------------------- /v2vbench/aesthetic_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 15:54:42 5 | @Desc : 6 | @Ref : https://github.com/christophschuhmann/improved-aesthetic-predictor 7 | ''' 8 | import logging 9 | from pathlib import Path 10 | 11 | import torch 12 | from aesthetics_predictor import AestheticsPredictorV1, AestheticsPredictorV2Linear 13 | from transformers import CLIPProcessor 14 | 15 | from .base_evaluator import BaseEvaluator 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | class AestheticScore(BaseEvaluator): 20 | # pretrained_model_name = 'shunk031/aesthetics-predictor-v1-vit-large-patch14' # v1 21 | pretrained_model_name = 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' # v2 22 | 23 | def __init__( 24 | self, 25 | index_file: Path, 26 | edit_video_dir: Path, 27 | reference_video_dir: Path, 28 | edit_prompt: str, 29 | device: torch.device, 30 | pretrained_model_name: str = None, 31 | ): 32 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 33 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name 34 | logger.debug(f"Loading model {pretrained_model_name}") 35 | self.preprocessor = CLIPProcessor.from_pretrained(pretrained_model_name) 36 | # self.model = AestheticsPredictorV1.from_pretrained(pretrained_model_name) 37 | self.model = AestheticsPredictorV2Linear.from_pretrained(pretrained_model_name) 38 | self.model.to(self.device) 39 | self.model.eval() 40 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 41 | 42 | def range(self): 43 | return 0, 10 44 | 45 | def preprocess(self, sample): 46 | video = sample['edit_video'] 47 | frames = [] 48 | for i, frame in enumerate(video): 49 | frames.append(self.preprocessor(images=frame, return_tensors='pt').pixel_values) 50 | 51 | return frames 52 | 53 | @torch.no_grad() 54 | def evaluate(self, frames) -> float: 55 | score = [] 56 | for i, frame in enumerate(frames): 57 | if i == 0: 58 | logger.debug(f"Input shape: {frame.shape}") 59 | frame = frame.to(self.device) 60 | prediction = self.model(pixel_values=frame).logits 61 | if i == 0: 62 | logger.debug(f"Output shape: {prediction.shape}") 63 | score.append(prediction.squeeze().cpu().item()) 64 | return sum(score) / len(score) -------------------------------------------------------------------------------- /v2vbench/clip_text_alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 18:58:20 5 | @Desc : 6 | @Ref : 7 | https://github.com/openai/CLIP 8 | https://github.com/mlfoundations/open_clip 9 | https://huggingface.co/docs/transformers/model_doc/clip#clip 10 | ''' 11 | import logging 12 | from pathlib import Path 13 | 14 | import torch 15 | from transformers import CLIPModel, CLIPProcessor 16 | 17 | from .base_evaluator import BaseEvaluator 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class ClipTextAlignment(BaseEvaluator): 22 | pretrained_model_name = 'openai/clip-vit-large-patch14' 23 | 24 | def __init__( 25 | self, 26 | index_file: Path, 27 | edit_video_dir: Path, 28 | reference_video_dir: Path, 29 | edit_prompt: str, 30 | device: torch.device, 31 | pretrained_model_name: str = None, 32 | ): 33 | 
super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name 35 | logger.debug(f"Loding model {pretrained_model_name}") 36 | self.preprocessor = CLIPProcessor.from_pretrained(pretrained_model_name) 37 | self.model = CLIPModel.from_pretrained(pretrained_model_name) 38 | self.model.to(self.device) 39 | self.model.eval() 40 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 41 | 42 | def range(self): 43 | return 0, 100 44 | 45 | def preprocess(self, sample): 46 | text = sample['edit_prompt'] 47 | video = sample['edit_video'] 48 | 49 | text_inputs = self.preprocessor( 50 | text=text, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device) 51 | 52 | image_inputs = [] 53 | for frame in video: 54 | image_inputs.append(self.preprocessor( 55 | images=frame, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device)) 56 | 57 | return text_inputs, image_inputs 58 | 59 | @torch.no_grad() 60 | def evaluate(self, args) -> float: 61 | text_inputs, image_inputs = args 62 | text_embs = self.model.get_text_features(**text_inputs) 63 | text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True) 64 | 65 | scores = [] 66 | for image_input in image_inputs: 67 | image_embs = self.model.get_image_features(**image_input) 68 | image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True) 69 | score = (self.model.logit_scale.exp() * (text_embs @ image_embs.T)).cpu().squeeze().item() 70 | scores.append(score) 71 | 72 | return sum(scores) / len(scores) -------------------------------------------------------------------------------- /v2vbench/pick_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 16:07:53 5 | @Desc : 6 | @Ref : https://github.com/yuvalkirstain/PickScore 7 | ''' 8 | import logging 9 | from pathlib import Path 10 | 11 | import torch 12 | from transformers import CLIPModel, CLIPProcessor 13 | 14 | from .base_evaluator import BaseEvaluator 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | class PickScore(BaseEvaluator): 19 | pretrained_processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 20 | pretrained_model_name = 'yuvalkirstain/PickScore_v1' 21 | 22 | def __init__( 23 | self, 24 | index_file: Path, 25 | edit_video_dir: Path, 26 | reference_video_dir: Path, 27 | edit_prompt: str, 28 | device: torch.device, 29 | pretrained_processor_name: str = None, 30 | pretrained_model_name: str = None, 31 | ): 32 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 33 | pretrained_processor_name = pretrained_processor_name or self.pretrained_processor_name 34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name 35 | 36 | logger.debug(f"Loding model {pretrained_model_name}") 37 | self.preprocessor = CLIPProcessor.from_pretrained(pretrained_processor_name) 38 | self.model = CLIPModel.from_pretrained(pretrained_model_name) 39 | self.model.to(self.device) 40 | self.model.eval() 41 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 42 | 43 | def range(self): 44 | return 0, 100 45 | 46 | def preprocess(self, sample): 47 | text = sample['edit_prompt'] 48 | video = sample['edit_video'] 49 | 50 | text_inputs = self.preprocessor( 51 | text=text, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device) 52 | 53 | image_inputs = [] 54 | for 
frame in video: 55 | image_inputs.append(self.preprocessor( 56 | images=frame, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device)) 57 | 58 | return text_inputs, image_inputs 59 | 60 | @torch.no_grad() 61 | def evaluate(self, args) -> float: 62 | text_inputs, image_inputs = args 63 | text_embs = self.model.get_text_features(**text_inputs) 64 | text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True) 65 | 66 | scores = [] 67 | for image_input in image_inputs: 68 | image_embs = self.model.get_image_features(**image_input) 69 | image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True) 70 | score = (self.model.logit_scale.exp() * (text_embs @ image_embs.T)).cpu().squeeze().item() 71 | scores.append(score) 72 | 73 | return sum(scores) / len(scores) -------------------------------------------------------------------------------- /v2vbench/viclip_text_alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 20:58:36 5 | @Desc : 6 | @Ref : 7 | ''' 8 | import logging 9 | from pathlib import Path 10 | 11 | import torch 12 | 13 | from third_party.viclip import SimpleTokenizer, ViCLIP, clip_image_transform 14 | 15 | from .base_evaluator import BaseEvaluator 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | class ViclipTextAlignment(BaseEvaluator): 20 | pretrained_tokenizer = ( 21 | Path(__file__) / '../../checkpoints/ViCLIP/bpe_simple_vocab_16e6.txt.gz').resolve().absolute() 22 | pretrained_checkpoint = ( 23 | Path(__file__) / '../../checkpoints/ViCLIP/ViClip-InternVid-10M-FLT.pth').resolve().absolute() 24 | 25 | def __init__( 26 | self, 27 | index_file: Path, 28 | edit_video_dir: Path, 29 | reference_video_dir: Path, 30 | edit_prompt: str, 31 | device: torch.device, 32 | pretrained_tokenizer: str = None, 33 | pretrained_checkpoint: str = None, 34 | stride: int = 3, # accept 8 frames as input 35 | ): 36 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device, stride=stride) 37 | pretrained_tokenizer = pretrained_tokenizer or self.pretrained_tokenizer 38 | pretrained_checkpoint = pretrained_checkpoint or self.pretrained_checkpoint 39 | 40 | logger.debug(f"Loding model {pretrained_checkpoint}") 41 | tokenizer = SimpleTokenizer(bpe_path=pretrained_tokenizer) 42 | self.model = ViCLIP(tokenizer=tokenizer, pretrain=pretrained_checkpoint) 43 | self.model.to(self.device) 44 | self.model.eval() 45 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 46 | 47 | self.image_transform = clip_image_transform(224) 48 | 49 | def range(self): 50 | return 0, 1 51 | 52 | def preprocess(self, sample): 53 | text = sample['edit_prompt'] 54 | video = sample['edit_video'] 55 | 56 | text_inputs = text 57 | 58 | frames = [] 59 | for frame in video: 60 | frames.append(self.image_transform(frame)) 61 | video_inputs = torch.stack(frames).to(self.device)[None] # (1, T, C, H, W) 62 | logger.debug(f"video_inputs shape: {video_inputs.shape}") 63 | 64 | return text_inputs, video_inputs 65 | 66 | @torch.no_grad() 67 | def evaluate(self, args) -> float: 68 | text_inputs, video_inputs = args 69 | 70 | text_embs: torch.Tensor = self.model.encode_text(text_inputs).float() 71 | text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True) 72 | logger.debug(f"text_embs shape: {text_embs.shape}") 73 | 74 | video_embs: torch.Tensor = self.model.encode_vision(video_inputs, test=True).float() 75 | video_embs = video_embs / 
torch.norm(video_embs, dim=-1, keepdim=True) 76 | logger.debug(f"video_embs shape: {video_embs.shape}") 77 | 78 | score = (text_embs @ video_embs.T).cpu().squeeze().item() 79 | 80 | return score -------------------------------------------------------------------------------- /v2vbench/dino_image_alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/07 16:06:35 5 | @Desc : 6 | @Ref : 7 | https://github.com/facebookresearch/dinov2 8 | https://huggingface.co/docs/transformers/model_doc/dinov2#dinov2 9 | ''' 10 | import logging 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import torch 15 | from transformers import BitImageProcessor, Dinov2Model 16 | 17 | from .base_evaluator import BaseEvaluator 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class DinoImageAlignment(BaseEvaluator): 22 | pretrained_model_name = 'facebook/dinov2-base' 23 | 24 | def __init__( 25 | self, 26 | index_file: Path, 27 | edit_video_dir: Path, 28 | reference_video_dir: Path, 29 | edit_prompt: str, 30 | device: torch.device, 31 | pretrained_model_name: str = None, 32 | ): 33 | super().__init__(index_file, edit_video_dir, None, edit_prompt, device) 34 | pretrained_model_name = pretrained_model_name or self.pretrained_model_name 35 | logger.debug(f"Loding model {pretrained_model_name}") 36 | self.preprocessor = BitImageProcessor.from_pretrained(pretrained_model_name) 37 | self.model = Dinov2Model.from_pretrained(pretrained_model_name) 38 | self.model.to(self.device) 39 | self.model.eval() 40 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 41 | 42 | def range(self): 43 | return 0, 1 44 | 45 | def preprocess(self, sample): 46 | video = sample['edit_video'] 47 | frames = [] 48 | for i, frame in enumerate(video): 49 | frames.append(self.preprocessor(frame, return_tensors='pt').pixel_values) 50 | reference_image = self.preprocessor(sample['reference_image'], return_tensors='pt').pixel_values 51 | return frames, reference_image 52 | 53 | @torch.no_grad() 54 | def evaluate(self, args) -> float: 55 | frames, reference_image = args 56 | similarity = [] 57 | conformity = [] 58 | 59 | reference_image = reference_image.to(self.device) 60 | reference_feature = self.model(pixel_values=reference_image).pooler_output 61 | reference_feature: torch.Tensor = reference_feature / torch.norm(reference_feature, dim=-1, keepdim=True) 62 | 63 | former_feature = None 64 | for i, frame in enumerate(frames): 65 | frame = frame.to(self.device) 66 | # pooled output (first image token) 67 | feature = self.model(pixel_values=frame).pooler_output 68 | feature: torch.Tensor = feature / torch.norm(feature, dim=-1, keepdim=True) 69 | 70 | if i > 0: 71 | sim = max(0, (feature @ former_feature.T).cpu().squeeze().item()) 72 | similarity.append(sim) 73 | former_feature = feature 74 | 75 | conformity.append(max(0, (feature @ reference_feature.T).cpu().squeeze().item())) 76 | 77 | max_weight = 0.4 78 | mean_weight = 0.3 79 | min_weight = 0.3 80 | return float( 81 | max_weight * np.max(conformity) + 82 | mean_weight * np.mean(similarity) + 83 | min_weight * np.min(similarity) 84 | ) -------------------------------------------------------------------------------- /third_party/gmflow/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .position import PositionEmbeddingSine 3 | 4 | 5 | def split_feature(feature, 6 | num_splits=2, 
7 | channel_last=False, 8 | ): 9 | if channel_last: # [B, H, W, C] 10 | b, h, w, c = feature.size() 11 | assert h % num_splits == 0 and w % num_splits == 0 12 | 13 | b_new = b * num_splits * num_splits 14 | h_new = h // num_splits 15 | w_new = w // num_splits 16 | 17 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c 18 | ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] 19 | else: # [B, C, H, W] 20 | b, c, h, w = feature.size() 21 | assert h % num_splits == 0 and w % num_splits == 0 22 | 23 | b_new = b * num_splits * num_splits 24 | h_new = h // num_splits 25 | w_new = w // num_splits 26 | 27 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits 28 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] 29 | 30 | return feature 31 | 32 | 33 | def merge_splits(splits, 34 | num_splits=2, 35 | channel_last=False, 36 | ): 37 | if channel_last: # [B*K*K, H/K, W/K, C] 38 | b, h, w, c = splits.size() 39 | new_b = b // num_splits // num_splits 40 | 41 | splits = splits.view(new_b, num_splits, num_splits, h, w, c) 42 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( 43 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] 44 | else: # [B*K*K, C, H/K, W/K] 45 | b, c, h, w = splits.size() 46 | new_b = b // num_splits // num_splits 47 | 48 | splits = splits.view(new_b, num_splits, num_splits, c, h, w) 49 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( 50 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] 51 | 52 | return merge 53 | 54 | 55 | def normalize_img(img0, img1): 56 | # loaded images are in [0, 255] 57 | # normalize by ImageNet mean and std 58 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) 59 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) 60 | img0 = (img0 / 255. - mean) / std 61 | img1 = (img1 / 255. 
- mean) / std 62 | 63 | return img0, img1 64 | 65 | 66 | def feature_add_position(feature0, feature1, attn_splits, feature_channels): 67 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 68 | 69 | if attn_splits > 1: # add position in splited window 70 | feature0_splits = split_feature(feature0, num_splits=attn_splits) 71 | feature1_splits = split_feature(feature1, num_splits=attn_splits) 72 | 73 | position = pos_enc(feature0_splits) 74 | 75 | feature0_splits = feature0_splits + position 76 | feature1_splits = feature1_splits + position 77 | 78 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits) 79 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits) 80 | else: 81 | position = pos_enc(feature0) 82 | 83 | feature0 = feature0 + position 84 | feature1 = feature1 + position 85 | 86 | return feature0, feature1 -------------------------------------------------------------------------------- /v2vbench/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/07 14:49:35 5 | @Desc : 6 | @Ref : 7 | ''' 8 | import logging 9 | from pathlib import Path 10 | from typing import List 11 | 12 | import decord 13 | from PIL import Image, ImageSequence 14 | 15 | decord.bridge.set_bridge('native') 16 | logger = logging.getLogger(__name__) 17 | _supported_image_suffix = ['.jpg', '.jpeg', '.png'] 18 | _supported_video_suffix = ['.mp4', '.gif'] 19 | 20 | 21 | def _load_video_from_image_dir(video_dir: Path) -> List[Image.Image]: 22 | logger.debug(f'Loading video from image directory: {video_dir}') 23 | frames = [] 24 | for file in sorted(video_dir.iterdir()): 25 | if file.suffix not in _supported_image_suffix: 26 | logger.debug(f'Skipping file: {file}') 27 | continue 28 | frame = Image.open(file).convert('RGB') 29 | frames.append(frame) 30 | if not frames: 31 | raise FileNotFoundError(f'No image found in {video_dir}, supported image suffix: {_supported_image_suffix}') 32 | 33 | return frames 34 | 35 | 36 | def _load_video_from_video_file(video_file: Path) -> List[Image.Image]: 37 | logger.debug(f'Loading video from video file: {video_file}') 38 | if video_file.suffix == '.mp4': 39 | video_reader = decord.VideoReader(str(video_file), num_threads=1) 40 | frames = [] 41 | for i in range(len(video_reader)): 42 | frames.append(Image.fromarray(video_reader[i].asnumpy())) 43 | return frames 44 | 45 | elif video_file.suffix == '.gif': 46 | frames = [] 47 | for f in ImageSequence.Iterator(Image.open(video_file)): 48 | frame = f.convert('RGB') 49 | frames.append(frame) 50 | return frames 51 | 52 | else: 53 | raise NotImplementedError( 54 | f'Unsupported video file: {video_file}, supported suffix: {_supported_video_suffix}') 55 | 56 | 57 | def load_video(video_file_or_dir: Path, start_frame: int = 0, stride: int = 1) -> List[Image.Image]: 58 | """ 59 | Args: 60 | video_file_or_dir (Path): path to video file or directory containing images 61 | start_frame (int): start frame index 62 | stride (int): stride for frame sampling 63 | Returns: 64 | List[Image.Image]: list of frames, RGB 65 | """ 66 | if not video_file_or_dir.exists(): 67 | logger.debug(f'Video file or directory does not exist: {video_file_or_dir}, trying to find alternative') 68 | for suffix in _supported_video_suffix: 69 | if (video_file_or_dir.with_suffix(suffix)).exists(): 70 | video_file_or_dir = video_file_or_dir.with_suffix(suffix) 71 | logger.debug(f'Found video file: {video_file_or_dir}') 72 
| break 73 | else: 74 | raise FileNotFoundError(f'Reference video: {video_file_or_dir} does not exist') 75 | 76 | if video_file_or_dir.is_dir(): 77 | buffer = _load_video_from_image_dir(video_file_or_dir) 78 | elif video_file_or_dir.is_file(): 79 | buffer = _load_video_from_video_file(video_file_or_dir) 80 | else: 81 | # should not reach here 82 | raise NotImplementedError(f'{video_file_or_dir} is not a valid file or directory') 83 | 84 | video = buffer[start_frame::stride] 85 | logger.debug(f'Raw video frames: {len(buffer)}, sampled video frames: {len(video)}') 86 | logger.debug(f'Frame size: {video[0].size}') 87 | return video -------------------------------------------------------------------------------- /third_party/gmflow/models/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def coords_grid(b, h, w, homogeneous=False, device=None): 6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 7 | 8 | stacks = [x, y] 9 | 10 | if homogeneous: 11 | ones = torch.ones_like(x) # [H, W] 12 | stacks.append(ones) 13 | 14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 15 | 16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 17 | 18 | if device is not None: 19 | grid = grid.to(device) 20 | 21 | return grid 22 | 23 | 24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 25 | assert device is not None 26 | 27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 28 | torch.linspace(h_min, h_max, len_h, device=device)], 29 | ) 30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 31 | 32 | return grid 33 | 34 | 35 | def normalize_coords(coords, h, w): 36 | # coords: [B, H, W, 2] 37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) 38 | return (coords - c) / c # [-1, 1] 39 | 40 | 41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 42 | # img: [B, C, H, W] 43 | # sample_coords: [B, 2, H, W] in image scale 44 | if sample_coords.size(1) != 2: # [B, H, W, 2] 45 | sample_coords = sample_coords.permute(0, 3, 1, 2) 46 | 47 | b, _, h, w = sample_coords.shape 48 | 49 | # Normalize to [-1, 1] 50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 52 | 53 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 54 | 55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 56 | 57 | if return_mask: 58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 59 | 60 | return img, mask 61 | 62 | return img 63 | 64 | 65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 66 | b, c, h, w = feature.size() 67 | assert flow.size(1) == 2 68 | 69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 70 | 71 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 72 | return_mask=mask) 73 | 74 | 75 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 76 | alpha=0.01, 77 | beta=0.5 78 | ): 79 | # fwd_flow, bwd_flow: [B, 2, H, W] 80 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 81 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 82 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 83 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 84 | 85 | warped_bwd_flow = 
flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 86 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 87 | 88 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 89 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 90 | 91 | threshold = alpha * flow_mag + beta 92 | 93 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 94 | bwd_occ = (diff_bwd > threshold).float() 95 | 96 | return fwd_occ, bwd_occ -------------------------------------------------------------------------------- /third_party/gmflow/models/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter( 53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 54 | ) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 57 | else: 58 | self.bias = None 59 | 60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0) 63 | 64 | def forward(self, inputs): 65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 66 | assert len(inputs) == num_branch 67 | 68 | if self.training or self.test_branch_idx == -1: 69 | outputs = [ 70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 71 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 72 | ] 73 | else: 74 | outputs = [ 75 | F.conv2d( 76 | inputs[0], 77 | self.weight, 78 | self.bias, 79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 81 | self.dilation, 82 | self.groups, 83 | ) 84 | ] 85 | 86 | if self.norm is not None: 87 | outputs = [self.norm(x) for x in outputs] 88 | if self.activation is not None: 89 | outputs = [self.activation(x) for x in outputs] 90 | return outputs 
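if __name__ == "__main__":
    # Minimal usage sketch of MultiScaleTridentConv: two branches share one 3x3 kernel,
    # applied at strides 1 and 2. The channel count, resolution, strides and paddings here
    # are arbitrary illustrative assumptions, not GMFlow's actual configuration.
    conv = MultiScaleTridentConv(
        in_channels=128, out_channels=128, kernel_size=3,
        strides=[1, 2], paddings=[1, 1], num_branch=2,
    )
    feats = [torch.randn(1, 128, 64, 64), torch.randn(1, 128, 64, 64)]
    # Branch 0 keeps the input resolution, branch 1 halves it with the same weights.
    out_full, out_half = conv(feats)
    print(out_full.shape, out_half.shape)  # torch.Size([1, 128, 64, 64]) torch.Size([1, 128, 32, 32])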
-------------------------------------------------------------------------------- /third_party/dover/models/head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class VQAHead(nn.Module): 5 | """MLP Regression Head for VQA. 6 | Args: 7 | in_channels: input channels for MLP 8 | hidden_channels: hidden channels for MLP 9 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5) 10 | pre_pool: whether pre-pool the features or not (True for Aesthetic Attributes, False for Technical Attributes) 11 | """ 12 | 13 | def __init__( 14 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, **kwargs 15 | ): 16 | super().__init__() 17 | self.dropout_ratio = dropout_ratio 18 | self.in_channels = in_channels 19 | self.hidden_channels = hidden_channels 20 | self.pre_pool = pre_pool 21 | if self.dropout_ratio != 0: 22 | self.dropout = nn.Dropout(p=self.dropout_ratio) 23 | else: 24 | self.dropout = None 25 | self.fc_hid = nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1)) 26 | self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1)) 27 | self.gelu = nn.GELU() 28 | 29 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 30 | 31 | def forward(self, x, rois=None): 32 | if self.pre_pool: 33 | x = self.avg_pool(x) 34 | x = self.dropout(x) 35 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x)))) 36 | return qlt_score 37 | 38 | 39 | 40 | 41 | 42 | class VARHead(nn.Module): 43 | """MLP Regression Head for Video Action Recognition. 44 | Args: 45 | in_channels: input channels for MLP 46 | hidden_channels: hidden channels for MLP 47 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5) 48 | """ 49 | 50 | def __init__(self, in_channels=768, out_channels=400, dropout_ratio=0.5, **kwargs): 51 | super().__init__() 52 | self.dropout_ratio = dropout_ratio 53 | self.in_channels = in_channels 54 | self.out_channels = out_channels 55 | if self.dropout_ratio != 0: 56 | self.dropout = nn.Dropout(p=self.dropout_ratio) 57 | else: 58 | self.dropout = None 59 | self.fc = nn.Conv3d(self.in_channels, self.out_channels, (1, 1, 1)) 60 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 61 | 62 | def forward(self, x, rois=None): 63 | x = self.dropout(x) 64 | x = self.avg_pool(x) 65 | out = self.fc(x) 66 | return out 67 | 68 | 69 | class IQAHead(nn.Module): 70 | """MLP Regression Head for IQA. 
71 | Args: 72 | in_channels: input channels for MLP 73 | hidden_channels: hidden channels for MLP 74 | dropout_ratio: the dropout ratio for features before the MLP (default 0.5) 75 | """ 76 | 77 | def __init__( 78 | self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, **kwargs 79 | ): 80 | super().__init__() 81 | self.dropout_ratio = dropout_ratio 82 | self.in_channels = in_channels 83 | self.hidden_channels = hidden_channels 84 | if self.dropout_ratio != 0: 85 | self.dropout = nn.Dropout(p=self.dropout_ratio) 86 | else: 87 | self.dropout = None 88 | self.fc_hid = nn.Linear(self.in_channels, self.hidden_channels) 89 | self.fc_last = nn.Linear(self.hidden_channels, 1) 90 | self.gelu = nn.GELU() 91 | 92 | def forward(self, x): 93 | x = self.dropout(x) 94 | qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x)))) 95 | return qlt_score 96 | -------------------------------------------------------------------------------- /third_party/gmflow/models/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .geometry import coords_grid, generate_window_grid, normalize_coords 5 | 6 | 7 | def global_correlation_softmax(feature0, feature1, 8 | pred_bidir_flow=False, 9 | ): 10 | # global correlation 11 | b, c, h, w = feature0.shape 12 | feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] 13 | feature1 = feature1.view(b, c, -1) # [B, C, H*W] 14 | 15 | correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] 16 | 17 | # flow from softmax 18 | init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] 19 | grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 20 | 21 | correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] 22 | 23 | if pred_bidir_flow: 24 | correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] 25 | init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] 26 | grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] 27 | b = b * 2 28 | 29 | prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] 30 | 31 | correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 32 | 33 | # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow 34 | flow = correspondence - init_grid 35 | 36 | return flow, prob 37 | 38 | 39 | def local_correlation_softmax(feature0, feature1, local_radius, 40 | padding_mode='zeros', 41 | ): 42 | b, c, h, w = feature0.size() 43 | coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] 44 | coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 45 | 46 | local_h = 2 * local_radius + 1 47 | local_w = 2 * local_radius + 1 48 | 49 | window_grid = generate_window_grid(-local_radius, local_radius, 50 | -local_radius, local_radius, 51 | local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] 52 | window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] 53 | sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] 54 | 55 | sample_coords_softmax = sample_coords 56 | 57 | # exclude coords that are out of image space 58 | valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] 59 | valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] 60 | 61 | valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values 
when softmax 62 | 63 | # normalize coordinates to [-1, 1] 64 | sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] 65 | window_feature = F.grid_sample(feature1, sample_coords_norm, 66 | padding_mode=padding_mode, align_corners=True 67 | ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] 68 | feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] 69 | 70 | corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] 71 | 72 | # mask invalid locations 73 | corr[~valid] = -1e9 74 | 75 | prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] 76 | 77 | correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( 78 | b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 79 | 80 | flow = correspondence - coords_init 81 | match_prob = prob 82 | 83 | return flow, match_prob -------------------------------------------------------------------------------- /doc/leaderboard.md: -------------------------------------------------------------------------------- 1 | # 📈 Leaderboard 2 | The best results within each category are italicized, and the globally best metrics are underlined*. 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 |
| Category | Method | Frames Quality⬆ | Semantic Consistency⬆ | Object Consistency⬆ | Video Quality⬆ | Frames Text Alignment⬆ | Frames Pick Score⬆ | Video Text Alignment⬆ | Motion Alignment⬆ |
|---|---|---|---|---|---|---|---|---|---|
| Network and Training Paradigm | Tune-A-Video | 5.001 | 0.934 | 0.917 | 0.527 | 27.513 | 20.701 | 0.254 | -5.599 |
| Network and Training Paradigm | SimDA | 4.988 | 0.940 | 0.929 | 0.569 | 26.773 | 20.512 | 0.248 | -4.756 |
| Network and Training Paradigm | VidToMe | 4.988 | 0.949 | 0.945 | 0.656 | 26.813 | 20.546 | 0.240 | -3.203 |
| Network and Training Paradigm | VideoComposer | 4.429 | 0.914 | 0.905 | 0.370 | 28.001 | 20.272 | 0.262 | -8.095 |
| Network and Training Paradigm | MotionDirector | 4.984 | 0.940 | 0.951 | 0.617 | 27.845 | 20.923 | 0.262 | -3.088 |
| Attention Feature Injection | Video-P2P | 4.907 | 0.943 | 0.926 | 0.471 | 23.550 | 19.751 | 0.193 | -5.974 |
| Attention Feature Injection | Vid2Vid-Zero | 5.103 | 0.919 | 0.912 | 0.638 | 28.789 | 20.950 | 0.270 | -4.175 |
| Attention Feature Injection | Fate-Zero | 5.036 | 0.951 | 0.952* | 0.704 | 25.065 | 20.707 | 0.225 | -1.439* |
| Attention Feature Injection | TokenFlow | 5.068 | 0.947 | 0.943 | 0.715 | 27.522 | 20.757 | 0.254 | -1.572 |
| Attention Feature Injection | FLATTEN | 4.965 | 0.943 | 0.949 | 0.645 | 27.156 | 20.745 | 0.251 | -1.446 |
| Attention Feature Injection | FRESCO | 5.127 | 0.908 | 0.896 | 0.689 | 25.639 | 20.239 | 0.223 | -5.241 |
| Diffusion Latent Manipulation | Text2Video-Zero | 5.097 | 0.899 | 0.894 | 0.613 | 29.124* | 20.568 | 0.265 | -17.226 |
| Diffusion Latent Manipulation | Pix2Video | 5.075 | 0.946 | 0.944 | 0.638 | 28.731 | 21.054* | 0.271* | -2.889 |
| Diffusion Latent Manipulation | ControlVideo | 5.404* | 0.959* | 0.948 | 0.674 | 28.551 | 20.961 | 0.261 | -9.396 |
| Diffusion Latent Manipulation | Rerender | 5.002 | 0.872 | 0.863 | 0.724* | 27.379 | 20.460 | 0.261 | -4.959 |
| Diffusion Latent Manipulation | RAVE | 5.077 | 0.926 | 0.936 | 0.664 | 28.190 | 20.865 | 0.255 | -2.398 |
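The columns above roughly correspond to the evaluator modules in `v2vbench` (aesthetic score for frame quality, DOVER for video quality, CLIP/DINO feature similarity for semantic/object consistency, CLIP text alignment, PickScore, ViCLIP video-text alignment, and optical-flow based motion alignment). As a minimal sketch of how a single per-clip score is produced, the snippet below runs one evaluator in single-video mode; the clip path and prompt are placeholders, not files shipped with this repository:

```python
from pathlib import Path

import torch

from v2vbench.clip_consistency import ClipConsistency

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# index_file=None puts BaseEvaluator into single-video mode; the edited clip may be
# an .mp4/.gif file or a directory of frames.
evaluator = ClipConsistency(
    index_file=None,
    edit_video_dir=Path('results/example_edit.mp4'),  # placeholder path
    reference_video_dir=None,
    edit_prompt='a white swan swimming in the river',  # placeholder prompt
    device=device,
)

results = evaluator()        # {'video_id': [...], 'edit_id': [...], 'score': [...]}
print(results['score'][0])   # frame-to-frame CLIP similarity, bounded in [0, 1]
```

A table entry is then, presumably, an aggregate of such per-clip scores over the benchmark's videos and edit prompts.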
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # General 165 | .DS_Store 166 | .AppleDouble 167 | .LSOverride 168 | 169 | # Icon must end with two \r 170 | Icon 171 | 172 | # Thumbnails 173 | ._* 174 | 175 | # Files that might appear in the root of a volume 176 | .DocumentRevisions-V100 177 | .fseventsd 178 | .Spotlight-V100 179 | .TemporaryItems 180 | .Trashes 181 | .VolumeIcon.icns 182 | .com.apple.timemachine.donotpresent 183 | 184 | # Directories potentially created on remote AFP share 185 | .AppleDB 186 | .AppleDesktop 187 | Network Trash Folder 188 | Temporary Items 189 | .apdisk 190 | -------------------------------------------------------------------------------- /v2vbench/base_evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/26 13:48:08 5 | @Desc : 6 | @Ref : 7 | ''' 8 | import logging 9 | from abc import ABC, abstractmethod 10 | from collections import defaultdict 11 | from pathlib import Path 12 | from typing import Tuple, Union 13 | 14 | import torch 15 | from omegaconf import OmegaConf 16 | from tqdm.auto import tqdm 17 | 18 | from .utils import load_video 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class BaseEvaluator(ABC): 24 | def __init__( 25 | self, 26 | index_file: Path, 27 | edit_video_dir: Path, 28 | reference_video_dir: Path, 29 | edit_prompt: str, 30 | device: torch.device, 31 | stride: int = 1, 32 | ): 33 | self.index_file = index_file 34 | self.edit_video_dir = edit_video_dir 35 | self.reference_video_dir = reference_video_dir 36 | self.edit_prompt = edit_prompt 37 | self.device = device 38 | # stride for loading video frames, default is 1 - load all frames 39 | # if stride is 2, load every other frame, if stride is 3, load every third frame, etc. 
40 | self.stride = stride 41 | 42 | @property 43 | @abstractmethod 44 | def range(self) -> Tuple[Union[float, int]]: 45 | """ 46 | Returns: 47 | Tuple[Union[float, int]]: the worse and best value of the metric 48 | """ 49 | raise NotImplementedError 50 | 51 | @property 52 | def direction(self) -> str: 53 | return 'maximize' if self.range[0] < self.range[1] else 'minimize' 54 | 55 | def _sample_data_iter(self): 56 | # for single video, the video_dir is the file path to the video 57 | if self.reference_video_dir is not None: 58 | reference_video_path = self.reference_video_dir 59 | reference_video = load_video(reference_video_path, stride=self.stride) 60 | else: 61 | reference_video_path = None 62 | reference_video = None 63 | edit_video = load_video(self.edit_video_dir, stride=self.stride) 64 | yield { 65 | # for single video, the reference video may not be available 66 | # use the edit_video name as video_id 67 | 'video_id': self.edit_video_dir.with_suffix('').name, 68 | # 'prompt': sample['prompt'], 69 | 'reference_video': reference_video, 70 | 'reference_video_path': reference_video_path, 71 | 'edit_id': 0, 72 | 'edit_prompt': self.edit_prompt, 73 | 'edit_video': edit_video, 74 | } 75 | 76 | def _batch_data_iter(self): 77 | config: OmegaConf = OmegaConf.load(self.index_file) 78 | for sample in config['data']: 79 | # reference video is optional 80 | reference_video_path = None 81 | reference_video = None 82 | 83 | if self.reference_video_dir is not None: 84 | reference_video_path = self.reference_video_dir / sample['video_id'] 85 | reference_video = load_video(reference_video_path, stride=self.stride) 86 | 87 | # iterate over edit 88 | for edit_id, edit in enumerate(sample['edit']): 89 | edit_video_path = self.edit_video_dir / sample['video_id'] / str(edit_id) 90 | edit_video = load_video(edit_video_path, stride=self.stride) 91 | yield { 92 | 'video_id': sample['video_id'], 93 | # 'prompt': sample['prompt'], 94 | 'reference_video': reference_video, 95 | 'reference_video_path': reference_video_path if reference_video is not None else None, 96 | 'edit_id': edit_id, 97 | 'edit_prompt': edit['prompt'], 98 | 'edit_video': edit_video, 99 | } 100 | 101 | def _data_iter(self): 102 | if self.index_file is not None: 103 | return self._batch_data_iter() 104 | else: 105 | return self._sample_data_iter() 106 | 107 | def preprocess(self, sample): 108 | return sample 109 | 110 | @abstractmethod 111 | def evaluate(self, sample) -> float: 112 | raise NotImplementedError 113 | 114 | def __call__(self): 115 | results = defaultdict(list) 116 | for sample in tqdm(self._data_iter()): 117 | score = self.evaluate(self.preprocess(sample)) 118 | 119 | # results['method'].append(self.method) 120 | results['video_id'].append(sample['video_id']) 121 | results['edit_id'].append(sample['edit_id']) 122 | results['score'].append(score) 123 | return results -------------------------------------------------------------------------------- /third_party/gmflow/models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .trident_conv import MultiScaleTridentConv 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, 8 | ): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation, stride=stride, bias=False) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 14 | 
dilation=dilation, padding=dilation, bias=False) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | self.norm1 = norm_layer(planes) 18 | self.norm2 = norm_layer(planes) 19 | if not stride == 1 or in_planes != planes: 20 | self.norm3 = norm_layer(planes) 21 | 22 | if stride == 1 and in_planes == planes: 23 | self.downsample = None 24 | else: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 27 | 28 | def forward(self, x): 29 | y = x 30 | y = self.relu(self.norm1(self.conv1(y))) 31 | y = self.relu(self.norm2(self.conv2(y))) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return self.relu(x + y) 37 | 38 | 39 | class CNNEncoder(nn.Module): 40 | def __init__(self, output_dim=128, 41 | norm_layer=nn.InstanceNorm2d, 42 | num_output_scales=1, 43 | **kwargs, 44 | ): 45 | super(CNNEncoder, self).__init__() 46 | self.num_branch = num_output_scales 47 | 48 | feature_dims = [64, 96, 128] 49 | 50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 51 | self.norm1 = norm_layer(feature_dims[0]) 52 | self.relu1 = nn.ReLU(inplace=True) 53 | 54 | self.in_planes = feature_dims[0] 55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 57 | 58 | # highest resolution 1/4 or 1/8 59 | stride = 2 if num_output_scales == 1 else 1 60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride, 61 | norm_layer=norm_layer, 62 | ) # 1/4 or 1/8 63 | 64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 65 | 66 | if self.num_branch > 1: 67 | if self.num_branch == 4: 68 | strides = (1, 2, 4, 8) 69 | elif self.num_branch == 3: 70 | strides = (1, 2, 4) 71 | elif self.num_branch == 2: 72 | strides = (1, 2) 73 | else: 74 | raise ValueError 75 | 76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, 77 | kernel_size=3, 78 | strides=strides, 79 | paddings=1, 80 | num_branch=self.num_branch, 81 | ) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 87 | if m.weight is not None: 88 | nn.init.constant_(m.weight, 1) 89 | if m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) 94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) 95 | 96 | layers = (layer1, layer2) 97 | 98 | self.in_planes = dim 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.norm1(x) 104 | x = self.relu1(x) 105 | 106 | x = self.layer1(x) # 1/2 107 | x = self.layer2(x) # 1/4 108 | x = self.layer3(x) # 1/8 or 1/4 109 | 110 | x = self.conv2(x) 111 | 112 | if self.num_branch > 1: 113 | out = self.trident_conv([x] * self.num_branch) # high to low res 114 | else: 115 | out = [x] 116 | 117 | return out -------------------------------------------------------------------------------- /v2vbench/motion_alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/27 12:30:53 5 | @Desc : 6 | @Ref : 7 | ''' 8 | 
import logging 9 | from pathlib import Path 10 | from typing import List 11 | 12 | import numpy as np 13 | import torch 14 | from omegaconf import OmegaConf 15 | 16 | from third_party.gmflow.models import GMFlow 17 | from third_party.gmflow.utils import InputPadder, write_flow 18 | 19 | from .base_evaluator import BaseEvaluator 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | class MotionAlignment(BaseEvaluator): 24 | pretrained_config_file = ( 25 | Path(__file__) / '../../third_party/gmflow/config.yaml').resolve().absolute() 26 | pretrained_checkpoint = ( 27 | Path(__file__) / '../../checkpoints/gmflow/gmflow_sintel-0c07dcb3.pth').resolve().absolute() 28 | 29 | def __init__( 30 | self, 31 | index_file: Path, 32 | edit_video_dir: Path, 33 | reference_video_dir: Path, 34 | edit_prompt: str, 35 | device: torch.device, 36 | pretrained_config_file: str = None, 37 | pretrained_checkpoint: str = None, 38 | cache_flow: bool = False, 39 | ): 40 | super().__init__(index_file, edit_video_dir, reference_video_dir, edit_prompt, device) 41 | pretrained_config_file = pretrained_config_file or self.pretrained_config_file 42 | pretrained_checkpoint = pretrained_checkpoint or self.pretrained_checkpoint 43 | 44 | logger.debug(f"Loding model {pretrained_checkpoint}") 45 | config = OmegaConf.to_container(OmegaConf.load(pretrained_config_file)) 46 | self.padder = InputPadder(**config['data']) 47 | self.model = GMFlow(**config['model']) 48 | self.model.to(self.device) 49 | self.model.load_state_dict(torch.load(pretrained_checkpoint, map_location=self.device)['model']) 50 | self.model.eval() 51 | logger.debug(f"Model {self.model.__class__.__name__} loaded") 52 | 53 | self.cache_flow = cache_flow 54 | 55 | def range(self): 56 | return -1 * np.inf, 0 57 | 58 | def get_flow_file(self, reference_video_path: Path): 59 | if reference_video_path is None: 60 | return Path('') 61 | 62 | video_id = reference_video_path.stem 63 | flow_dir = reference_video_path.parent.parent / 'flow' 64 | return flow_dir / f'{video_id}.pt' 65 | 66 | def preprocess(self, sample): 67 | # load edit frames 68 | edit_frames = [] 69 | for i, frame in enumerate(sample['edit_video']): 70 | frame = torch.from_numpy(np.array(frame).astype(np.uint8)).permute(2, 0, 1).float().to(self.device) 71 | frame = self.padder.pad(frame)[0][None] # (B, C, H, W) 72 | edit_frames.append(frame) 73 | 74 | reference_frames = [] 75 | reference_flow = None 76 | reference_flow_file = self.get_flow_file(sample['reference_video_path']) 77 | 78 | # try to load cached flow, only for batch evaluation 79 | try: 80 | logger.debug(f"Loading cached flow from {reference_flow_file}") 81 | reference_flow = torch.load(reference_flow_file, map_location=self.device) 82 | # cached flow not found, compute flow 83 | except FileNotFoundError or AttributeError: 84 | reference_flow_file = None 85 | logger.debug(f"Flow file not found, loading reference frames.") 86 | for i, frame in enumerate(sample['reference_video']): 87 | frame = torch.from_numpy(np.array(frame).astype(np.uint8)).permute(2, 0, 1).float().to(self.device) 88 | frame = self.padder.pad(frame)[0][None] # (B, C, H, W) 89 | reference_frames.append(frame) 90 | 91 | return edit_frames, reference_frames, reference_flow_file, reference_flow 92 | 93 | @torch.no_grad() 94 | def extract_flow(self, frames: List[torch.Tensor]): 95 | flows = [] 96 | 97 | for i, frame in enumerate(frames): 98 | if i > 0: 99 | flow_pr = self.model( 100 | prev_frame, frame, 101 | attn_splits_list=[2], 102 | corr_radius_list=[-1], 103 | 
prop_radius_list=[-1], 104 | )['flow_preds'][-1][0] # (2, H, W) 105 | flows.append(flow_pr) 106 | if i == 1: 107 | logger.debug(f"Flow shape: {flow_pr.shape}") 108 | prev_frame = frame 109 | 110 | return torch.stack(flows, dim=0) # (T-1, 2, H, W) 111 | 112 | @torch.no_grad() 113 | def evaluate(self, args) -> float: 114 | edit_frames, reference_frames, reference_flow_file, reference_flow = args 115 | edit_flow = self.extract_flow(edit_frames) 116 | 117 | # calculate flow for reference video if not cached 118 | if reference_flow_file is None: 119 | reference_flow = self.extract_flow(reference_frames) 120 | if self.cache_flow: 121 | write_flow(reference_flow, reference_flow_file) 122 | logger.debug(f"Flow saved to {reference_flow_file}") 123 | 124 | # calculate alignment score: EPE 125 | score = torch.sum((edit_flow - reference_flow) ** 2, dim=1).sqrt().mean().cpu().item() 126 | return - score 127 | -------------------------------------------------------------------------------- /third_party/viclip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | import subprocess 5 | from functools import lru_cache 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | # def default_bpe(): 11 | # tokenizer_file = os.path.join('.checkpoints', "ViCLIP/bpe_simple_vocab_16e6.txt.gz") 12 | # if not os.path.exists(tokenizer_file): 13 | # print(f'Downloading ViCLIP tokenizer to {tokenizer_file}') 14 | # wget_command = ['wget', 'https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz', '-P', os.path.dirname(tokenizer_file)] 15 | # subprocess.run(wget_command) 16 | # return tokenizer_file 17 | 18 | 19 | @lru_cache() 20 | def bytes_to_unicode(): 21 | """ 22 | Returns list of utf-8 byte and a corresponding list of unicode strings. 23 | The reversible bpe codes work on unicode strings. 24 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 25 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 26 | This is a signficant percentage of your normal, say, 32K bpe vocab. 27 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 28 | And avoids mapping to whitespace/control characters the bpe code barfs on. 29 | """ 30 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 31 | cs = bs[:] 32 | n = 0 33 | for b in range(2**8): 34 | if b not in bs: 35 | bs.append(b) 36 | cs.append(2**8+n) 37 | n += 1 38 | cs = [chr(n) for n in cs] 39 | return dict(zip(bs, cs)) 40 | 41 | 42 | def get_pairs(word): 43 | """Return set of symbol pairs in a word. 44 | Word is represented as tuple of symbols (symbols being variable-length strings). 
45 | """ 46 | pairs = set() 47 | prev_char = word[0] 48 | for char in word[1:]: 49 | pairs.add((prev_char, char)) 50 | prev_char = char 51 | return pairs 52 | 53 | 54 | def basic_clean(text): 55 | text = ftfy.fix_text(text) 56 | text = html.unescape(html.unescape(text)) 57 | return text.strip() 58 | 59 | 60 | def whitespace_clean(text): 61 | text = re.sub(r'\s+', ' ', text) 62 | text = text.strip() 63 | return text 64 | 65 | 66 | class SimpleTokenizer(object): 67 | def __init__(self, bpe_path: str): 68 | self.byte_encoder = bytes_to_unicode() 69 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 70 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 71 | merges = merges[1:49152-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + ( token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token+'' 92 | 93 | while True: 94 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | # 📈 V2VBench 2 | 3 | V2VBench is a comprehensive benchmark designed to evaluate video editing methods. 
It consists of: 4 | - 50 standardized videos across 5 categories, and 5 | - 3 editing prompts per video, encompassing 4 editing tasks: [Huggingface Datasets](https://huggingface.co/datasets/Wenhao-Sun/V2VBench) 6 | - 8 evaluation metrics to assess the quality of edited videos: [Evaluation Metrics](doc/README.md) 7 | 8 | For detailed information, please refer to the accompanying paper. 9 | 10 | 11 | ## Installation 12 | Clone the repository: 13 | ```bash 14 | git clone https://github.com/wenhao728/awesome-diffusion-v2v.git 15 | ``` 16 | 17 | (Optional) We recommend using a virtual environment to manage the dependencies. You can refer [virtualenv](https://virtualenv.pypa.io/en/latest/user_guide.html) or [conda](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#activating-an-environment) to create a virtual environment. 18 | 19 | Install the required packages: 20 | ```bash 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | Download the pre-trained models: 25 | ```bash 26 | sh doc/install.sh 27 | ``` 28 | The `./checkpoints` directory should contain the following files: 29 | ```diff 30 | checkpoints 31 | + ├── DOVER 32 | + │   └── DOVER.pth 33 | + ├── gmflow 34 | + │   └── gmflow_sintel-0c07dcb3.pth 35 | + └── ViCLIP 36 | +    ├── bpe_simple_vocab_16e6.txt.gz 37 | +    └── ViClip-InternVid-10M-FLT.pth 38 | ``` 39 | 40 | ## Run Evaluation 41 | ### Video format 42 | Videos can be provided either as "video files" (e.g. `video_name.mp4`, `video_name.gif`) or "directories containing video frame images" with filenames indicating the frame number (e.g. `video_name/01.jpg`, `video_name/02.jpg`, ...). 43 | 44 | ### Single Video Evaluation 45 | Run the following snippet to evaluate one edited video: 46 | ```python 47 | import logging 48 | 49 | from v2vbench import EvaluatorWrapper 50 | 51 | logging.basicConfig( 52 | level=logging.INFO, # Change it to logging.DEBUG if you want to troubleshoot 53 | format="%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)d)", 54 | handlers=[logging.StreamHandler(sys.stdout),] 55 | ) 56 | 57 | evaluation = EvaluatorWrapper(metrics='all') # 'all' for all metrics 58 | # print(EvaluatorWrapper.all_metrics) # list all available metrics 59 | 60 | results = evaluation.evaluate( 61 | edit_video='/path/to/edited_video', 62 | reference_video='/path/to/source_video', 63 | edit_prompt='', 64 | output_dir='/path/to/save/results', 65 | ) 66 | ``` 67 | 68 | ### Batch Evaluation 69 | Check the [Prepare Data](#optional-prepare-data) section to prepare the data for batch evaluation. 
70 | Then, simply run the following snippet to evaluate a batch of edited videos: 71 | ```python 72 | import logging 73 | 74 | from v2vbench import EvaluatorWrapper 75 | 76 | logging.basicConfig( 77 | level=logging.INFO, # Change it to logging.DEBUG if you want to troubleshoot 78 | format="%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)d)", 79 | handlers=[logging.StreamHandler(sys.stdout),] 80 | ) 81 | 82 | evaluation = EvaluatorWrapper(metrics='all') # 'all' for all metrics 83 | # print(EvaluatorWrapper.all_metrics) # list all available metrics 84 | 85 | results = evaluation.evaluate( 86 | edit_video='/path/to/edited_videos', 87 | reference_video='/path/to/source_videos', 88 | index_file='/path/to/config.yaml', 89 | output_dir='/path/to/save/results', 90 | # it is recommended to cache the flow of source videos for motion_alignment to avoid redundant computation 91 | evaluator_kwargs={'motion_alignment': {'cache_flow': True}}, 92 | ) 93 | ``` 94 | 95 | ## (Optional) Prepare Data 96 | To evaluate a batch of editing results, the data should be organized in specified formats. 97 | You can download the V2VBench dataset from [Huggingface Datasets](https://huggingface.co/datasets/Wenhao-Sun/V2VBench) or use your customized data with the following format. 98 | 99 | ### Configuration 100 | The configuration file is a YAML file. And each video should be one entry of the `data` list. The following is an example of one entry: 101 | ```yaml 102 | video_id: hike # source video id 103 | edit: 104 | - prompt: a superman with a backpack hikes through the desert # edit prompt 0 105 | - prompt: a man with a backpack hikes through the dolomites, pixel art # edit prompt 1 106 | # ... more edit prompts 107 | ``` 108 | ### Source Videos 109 | The source videos for reference should be named according to the `video_id` specified in the configuration file. 110 | 111 | ### Edited Videos 112 | For evaluation, the edited videos should be placed in the directory named according to the `video_id` specified in the configuration file. Each edited video should be named according to the edit prompt index. 113 | 114 | 115 | ## Shoutouts 116 | This repository is inspired by the following open-source projects: 117 | - [transformers](https://github.com/huggingface/transformers) by [Huggingface](https://github.com/huggingface) 118 | - [LAION Aesthetic Predictor](https://github.com/LAION-AI/aesthetic-predictor) (and its [out-of-the-box implementation](https://github.com/shunk031/simple-aesthetics-predictor)) 119 | - [CLIP](https://github.com/openai/CLIP) 120 | - [DINO-v2](https://github.com/facebookresearch/dinov2) 121 | - [DOVER](https://github.com/VQAssessment/DOVER) 122 | - [GMFlow](https://github.com/haofeixu/gmflow) 123 | - [ViCLIP](https://github.com/OpenGVLab/InternVideo) 124 | - [PickScore](https://github.com/yuvalkirstain/PickScore) 125 | - [VBench](https://github.com/Vchitect/VBench) 126 | 127 | We extend our gratitude to the authors for their exceptional work and open-source contributions. 
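As a complement to the Prepare Data section above, the snippet below sketches how batch evaluation resolves files on disk from the configuration entries (mirroring `BaseEvaluator._batch_data_iter` in `v2vbench/base_evaluator.py`). The root directories shown are placeholders for illustration only.

```python
# Minimal sketch: how batch evaluation maps config entries to files on disk.
# Source video:  <reference_video_dir>/<video_id>        (video file or frame directory)
# Edited video:  <edit_video_dir>/<video_id>/<edit_id>   (indexed by edit prompt order)
from pathlib import Path

from omegaconf import OmegaConf

reference_video_dir = Path('/path/to/source_videos')  # placeholder paths
edit_video_dir = Path('/path/to/edited_videos')
config = OmegaConf.load('/path/to/config.yaml')

for sample in config['data']:
    print('reference:', reference_video_dir / sample['video_id'])
    for edit_id, edit in enumerate(sample['edit']):
        print(f"edit {edit_id}:", edit_video_dir / sample['video_id'] / str(edit_id), '<-', edit['prompt'])
```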
-------------------------------------------------------------------------------- /third_party/gmflow/models/gmflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .backbone import CNNEncoder 6 | from .transformer import FeatureTransformer, FeatureFlowAttention 7 | from .matching import global_correlation_softmax, local_correlation_softmax 8 | from .geometry import flow_warp 9 | from .utils import normalize_img, feature_add_position 10 | 11 | 12 | class GMFlow(nn.Module): 13 | def __init__(self, 14 | num_scales=1, 15 | upsample_factor=8, 16 | feature_channels=128, 17 | attention_type='swin', 18 | num_transformer_layers=6, 19 | ffn_dim_expansion=4, 20 | num_head=1, 21 | **kwargs, 22 | ): 23 | super(GMFlow, self).__init__() 24 | 25 | self.num_scales = num_scales 26 | self.feature_channels = feature_channels 27 | self.upsample_factor = upsample_factor 28 | self.attention_type = attention_type 29 | self.num_transformer_layers = num_transformer_layers 30 | 31 | # CNN backbone 32 | self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) 33 | 34 | # Transformer 35 | self.transformer = FeatureTransformer(num_layers=num_transformer_layers, 36 | d_model=feature_channels, 37 | nhead=num_head, 38 | attention_type=attention_type, 39 | ffn_dim_expansion=ffn_dim_expansion, 40 | ) 41 | 42 | # flow propagation with self-attn 43 | self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels) 44 | 45 | # convex upsampling: concat feature0 and flow as input 46 | self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) 49 | 50 | def extract_feature(self, img0, img1): 51 | concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] 52 | features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low 53 | 54 | # reverse: resolution from low to high 55 | features = features[::-1] 56 | 57 | feature0, feature1 = [], [] 58 | 59 | for i in range(len(features)): 60 | feature = features[i] 61 | chunks = torch.chunk(feature, 2, 0) # tuple 62 | feature0.append(chunks[0]) 63 | feature1.append(chunks[1]) 64 | 65 | return feature0, feature1 66 | 67 | def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, 68 | ): 69 | if bilinear: 70 | up_flow = F.interpolate(flow, scale_factor=upsample_factor, 71 | mode='bilinear', align_corners=True) * upsample_factor 72 | 73 | else: 74 | # convex upsampling 75 | concat = torch.cat((flow, feature), dim=1) 76 | 77 | mask = self.upsampler(concat) 78 | b, flow_channel, h, w = flow.shape 79 | mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W] 80 | mask = torch.softmax(mask, dim=2) 81 | 82 | up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) 83 | up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] 84 | 85 | up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] 86 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] 87 | up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h, 88 | self.upsample_factor * w) # [B, 2, K*H, K*W] 89 | 90 | return up_flow 91 | 92 | def forward(self, img0, img1, 93 | attn_splits_list=None, 94 | corr_radius_list=None, 95 | prop_radius_list=None, 96 | pred_bidir_flow=False, 97 | **kwargs, 98 | ): 99 | 100 | results_dict = {} 101 | 
flow_preds = [] 102 | 103 | img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] 104 | 105 | # resolution low to high 106 | feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features 107 | 108 | flow = None 109 | 110 | assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales 111 | 112 | for scale_idx in range(self.num_scales): 113 | feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] 114 | 115 | if pred_bidir_flow and scale_idx > 0: 116 | # predicting bidirectional flow with refinement 117 | feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) 118 | 119 | upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) 120 | 121 | if scale_idx > 0: 122 | flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 123 | 124 | if flow is not None: 125 | flow = flow.detach() 126 | feature1 = flow_warp(feature1, flow) # [B, C, H, W] 127 | 128 | attn_splits = attn_splits_list[scale_idx] 129 | corr_radius = corr_radius_list[scale_idx] 130 | prop_radius = prop_radius_list[scale_idx] 131 | 132 | # add position to features 133 | feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) 134 | 135 | # Transformer 136 | feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits) 137 | 138 | # correlation and softmax 139 | if corr_radius == -1: # global matching 140 | flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] 141 | else: # local matching 142 | flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] 143 | 144 | # flow or residual flow 145 | flow = flow + flow_pred if flow is not None else flow_pred 146 | 147 | # upsample to the original resolution for supervison 148 | if self.training: # only need to upsample intermediate flow predictions at training time 149 | flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor) 150 | flow_preds.append(flow_bilinear) 151 | 152 | # flow propagation with self-attn 153 | if pred_bidir_flow and scale_idx == 0: 154 | feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation 155 | flow = self.feature_flow_attn(feature0, flow.detach(), 156 | local_window_attn=prop_radius > 0, 157 | local_window_radius=prop_radius) 158 | 159 | # bilinear upsampling at training time except the last one 160 | if self.training and scale_idx < self.num_scales - 1: 161 | flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor) 162 | flow_preds.append(flow_up) 163 | 164 | if scale_idx == self.num_scales - 1: 165 | flow_up = self.upsample_flow(flow, feature0) 166 | flow_preds.append(flow_up) 167 | 168 | results_dict.update({'flow_preds': flow_preds}) 169 | 170 | return results_dict -------------------------------------------------------------------------------- /v2vbench/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @Created : 2024/05/07 13:48:51 5 | @Desc : 6 | @Ref : 7 | ''' 8 | import gc 9 | import importlib 10 | import logging 11 | import sys 12 | import warnings 13 | from pathlib import Path 14 | from typing import Dict, List, Optional, Union 15 | 16 | import pandas as pd 17 | import torch 18 | from tqdm.auto import tqdm 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | cur_dir = 
Path('..').resolve().absolute() 23 | if str(cur_dir) not in sys.path: 24 | sys.path.append(str(cur_dir)) 25 | 26 | 27 | class EvaluatorWrapper: 28 | _quality_metrics = ['aesthetic_score', 'clip_consistency', 'dino_consistency', 'dover_score'] 29 | _alignment_metrics = ['clip_text_alignment', 'pick_score', 'viclip_text_alignment', 'motion_alignment'] 30 | all_metrics = _quality_metrics + _alignment_metrics 31 | 32 | def __init__( 33 | self, 34 | device: Optional[torch.device] = None, 35 | metrics: Union[str, List[str]] = 'all', 36 | ): 37 | self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | self.metrics = metrics 39 | 40 | def _check_args(self, edit_video, reference_video, index_file, edit_prompt): 41 | if isinstance(self.metrics, str): 42 | if self.metrics == 'all': 43 | self.metrics = self.all_metrics 44 | else: 45 | self.metrics = self.metrics.split(',') 46 | not_supported_metrics = set(self.metrics) - set(self.all_metrics) 47 | self.metrics = list(set(self.metrics) - not_supported_metrics) 48 | if not_supported_metrics: 49 | warnings.warn(f'Unsupported metrics: {not_supported_metrics}') 50 | if self.metrics: 51 | logger.info(f'Using metrics: {self.metrics}') 52 | else: 53 | raise ValueError('No supported metrics provided') 54 | 55 | if not edit_video.exists(): 56 | raise FileNotFoundError(f'Edit video: {edit_video} does not exist') 57 | if reference_video and not reference_video.exists(): 58 | raise FileNotFoundError(f'Reference video: {reference_video} does not exist') 59 | if index_file: 60 | if not index_file.exists(): 61 | raise FileNotFoundError(f'Index file: {index_file} does not exist') 62 | else: 63 | if not edit_prompt: 64 | raise ValueError( 65 | 'Edit prompt (for single video evaluation) OR index file (for batch evaluation) is required') 66 | 67 | def merge_results(self, results: Dict[str, Dict[str, List]]) -> pd.DataFrame: 68 | merged_results_df = None 69 | for metric, result in results.items(): 70 | result_df = pd.DataFrame(result) 71 | result_df.rename(columns={'score': metric}, inplace=True) 72 | if merged_results_df is None: 73 | merged_results_df = result_df 74 | else: 75 | merged_results_df = pd.merge(merged_results_df, result_df, on=['video_id', 'edit_id']) 76 | return merged_results_df 77 | 78 | def evaluate( 79 | self, 80 | edit_video: Path, 81 | reference_video: Optional[Path] = None, 82 | index_file: Optional[Path] = None, 83 | edit_prompt: Optional[str] = None, 84 | output_dir: Optional[Path] = None, 85 | save_results: bool = True, 86 | save_summary: bool = True, 87 | evaluator_kwargs: Optional[Dict[str, Dict]] = None, 88 | ) -> Dict[str, Dict[str, List]]: 89 | """Evaluate the editing video(s) using the specified metrics 90 | 91 | Args: 92 | edit_video (Path): Path to the editing video (single video) or directory containing editing videos (batch). 93 | reference_video (Optional[Path], optional): Path to the reference video (single video) or directory 94 | containing reference videos (batch). Defaults to None. 95 | index_file (Optional[Path], optional): Index file containing metadata for batch evaluation. Defaults to 96 | None. 97 | edit_prompt (Optional[str], optional): Edit prompt for single video evaluation. Defaults to None. 98 | output_dir (Optional[Path], optional): Directory to save the results and summary files. Defaults to None. 99 | save_results (bool, optional): Save the results in a CSV file. Only applicable when output_dir is provided. 100 | Defaults to True. 
101 | save_summary (bool, optional): Save the summary in a CSV file. Only applicable when output_dir is provided. 102 | Defaults to True. 103 | evaluator_kwargs (Optional[Dict[str, Dict]], optional): Additional keyword arguments for the evaluators. 104 | Defaults to None. 105 | 106 | Returns: 107 | Dict[str, Dict[str, List]]: Results for each metric for each sample, formatted as below: 108 | { 109 | 'metric1_name': { 110 | 'video_id': [video_id1, video_id2, ...], 111 | 'edit_id: [edit_id1, edit_id2, ...], 112 | 'score': [score1, score2, ...] 113 | }, 114 | ... 115 | } 116 | """ 117 | reference_video = Path(reference_video).resolve() if reference_video else None 118 | edit_video = Path(edit_video).resolve() 119 | index_file = Path(index_file).resolve() if index_file else None 120 | edit_prompt = edit_prompt 121 | self._check_args(edit_video, reference_video, index_file, edit_prompt) 122 | 123 | # evaluate results for each sample 124 | results = {} 125 | for metric in tqdm(self.metrics): 126 | try: 127 | module = importlib.import_module(f'.{metric}', package='v2vbench') 128 | evaluator_cls_name = ''.join([w[0].upper() + w[1:] for w in metric.split('_')]) 129 | evaluator_cls = getattr(module, evaluator_cls_name) 130 | kwargs = evaluator_kwargs[metric] if evaluator_kwargs and metric in evaluator_kwargs else {} 131 | 132 | evaluator = evaluator_cls( 133 | index_file=index_file, 134 | edit_video_dir=edit_video, 135 | reference_video_dir=reference_video, 136 | edit_prompt=edit_prompt, 137 | device=self.device, 138 | **kwargs, 139 | ) 140 | except Exception as e: 141 | raise NotImplementedError(f'Failed to initalize {metric} metrics: {e}') 142 | 143 | result = evaluator() 144 | results[metric] = result 145 | 146 | # release memory after evaluating each metric 147 | del evaluator 148 | gc.collect() 149 | if self.device.type == 'cuda': 150 | torch.cuda.empty_cache() 151 | 152 | # merge the results into a single dataframe 153 | merged_results_df = self.merge_results(results) 154 | # aggregate into a single-row summary 155 | summary_df = merged_results_df[self.metrics].mean(axis=0).to_frame().T 156 | logger.info(reference_video.name) 157 | logger.info(summary_df) 158 | 159 | if output_dir is not None and (save_results or save_summary): 160 | output_dir = Path(output_dir).resolve() 161 | output_dir.mkdir(parents=True, exist_ok=True) 162 | 163 | if save_results: 164 | merged_results_df.to_csv(output_dir / 'results.csv', index=False) 165 | logger.info(f'Results saved at: {output_dir / "results.csv"}') 166 | if save_summary: 167 | summary_df.to_csv(output_dir / 'summary.csv', index=False) 168 | logger.info(f'Summary saved at: {output_dir / "summary.csv"}') 169 | 170 | return results -------------------------------------------------------------------------------- /third_party/viclip/viclip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | from einops import rearrange 6 | from torch import nn 7 | import math 8 | 9 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 10 | from .viclip_vision import clip_joint_l14 11 | from .viclip_text import clip_text_l14 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class ViCLIP(nn.Module): 17 | """docstring for ViCLIP""" 18 | 19 | def __init__(self, tokenizer=None, pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"), freeze_text=True): 20 | super(ViCLIP, self).__init__() 21 | if tokenizer: 22 | self.tokenizer = 
tokenizer 23 | else: 24 | self.tokenizer = _Tokenizer() 25 | self.max_txt_l = 32 26 | 27 | self.vision_encoder_name = 'vit_l14' 28 | 29 | self.vision_encoder_pretrained = False 30 | self.inputs_image_res = 224 31 | self.vision_encoder_kernel_size = 1 32 | self.vision_encoder_center = True 33 | self.video_input_num_frames = 8 34 | self.vision_encoder_drop_path_rate = 0.1 35 | self.vision_encoder_checkpoint_num = 24 36 | self.is_pretrain = pretrain 37 | self.vision_width = 1024 38 | self.text_width = 768 39 | self.embed_dim = 768 40 | self.masking_prob = 0.9 41 | 42 | self.text_encoder_name = 'vit_l14' 43 | self.text_encoder_pretrained = False #'bert-base-uncased' 44 | self.text_encoder_d_model = 768 45 | 46 | self.text_encoder_vocab_size = 49408 47 | 48 | 49 | # create modules. 50 | self.vision_encoder = self.build_vision_encoder() 51 | self.text_encoder = self.build_text_encoder() 52 | 53 | self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0) 54 | self.temp_min = 1 / 100.0 55 | 56 | if pretrain: 57 | logger.info(f"Load pretrained weights from {pretrain}") 58 | state_dict = torch.load(pretrain, map_location='cpu')['model'] 59 | self.load_state_dict(state_dict) 60 | 61 | # Freeze weights 62 | if freeze_text: 63 | self.freeze_text() 64 | 65 | 66 | 67 | def freeze_text(self): 68 | """freeze text encoder""" 69 | for p in self.text_encoder.parameters(): 70 | p.requires_grad = False 71 | 72 | def no_weight_decay(self): 73 | ret = {"temp"} 74 | ret.update( 75 | {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()} 76 | ) 77 | ret.update( 78 | {"text_encoder." + k for k in self.text_encoder.no_weight_decay()} 79 | ) 80 | 81 | return ret 82 | 83 | def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False): 84 | """forward and calculate loss. 85 | 86 | Args: 87 | image (torch.Tensor): The input images. Shape: [B,T,C,H,W]. 88 | text (dict): TODO 89 | idx (torch.Tensor): TODO 90 | 91 | Returns: TODO 92 | 93 | """ 94 | self.clip_contrastive_temperature() 95 | 96 | vision_embeds = self.encode_vision(image) 97 | text_embeds = self.encode_text(raw_text) 98 | if return_sims: 99 | sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \ 100 | torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1) 101 | return sims 102 | 103 | # calculate loss 104 | 105 | ## VTC loss 106 | loss_vtc = self.clip_loss.vtc_loss( 107 | vision_embeds, text_embeds, idx, self.temp, all_gather=True 108 | ) 109 | 110 | return dict( 111 | loss_vtc=loss_vtc, 112 | ) 113 | 114 | def encode_vision(self, image, test=False): 115 | """encode image / videos as features. 116 | 117 | Args: 118 | image (torch.Tensor): The input images. 119 | test (bool): Whether testing. 120 | 121 | Returns: tuple. 122 | - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C]. 123 | - pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C]. 124 | 125 | """ 126 | if image.ndim == 5: 127 | image = image.permute(0, 2, 1, 3, 4).contiguous() 128 | else: 129 | image = image.unsqueeze(2) 130 | 131 | if not test and self.masking_prob > 0.0: 132 | return self.vision_encoder( 133 | image, masking_prob=self.masking_prob 134 | ) 135 | 136 | return self.vision_encoder(image) 137 | 138 | def encode_text(self, text): 139 | """encode text. 140 | Args: 141 | text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys: 142 | - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L]. 
143 | - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token. 144 | - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__". 145 | Returns: tuple. 146 | - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C]. 147 | - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C]. 148 | 149 | """ 150 | device = next(self.text_encoder.parameters()).device 151 | text = self.text_encoder.tokenize( 152 | text, context_length=self.max_txt_l 153 | ).to(device) 154 | text_embeds = self.text_encoder(text) 155 | return text_embeds 156 | 157 | @torch.no_grad() 158 | def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5): 159 | """Seems only used during pre-training""" 160 | self.temp.clamp_(min=self.temp_min) 161 | 162 | def build_vision_encoder(self): 163 | """build vision encoder 164 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`. 165 | 166 | """ 167 | encoder_name = self.vision_encoder_name 168 | if encoder_name != "vit_l14": 169 | raise ValueError(f"Not implemented: {encoder_name}") 170 | vision_encoder = clip_joint_l14( 171 | pretrained=self.vision_encoder_pretrained, 172 | input_resolution=self.inputs_image_res, 173 | kernel_size=self.vision_encoder_kernel_size, 174 | center=self.vision_encoder_center, 175 | num_frames=self.video_input_num_frames, 176 | drop_path=self.vision_encoder_drop_path_rate, 177 | checkpoint_num=self.vision_encoder_checkpoint_num, 178 | ) 179 | return vision_encoder 180 | 181 | def build_text_encoder(self): 182 | """build text_encoder and possiblly video-to-text multimodal fusion encoder. 183 | Returns: nn.Module. The text encoder 184 | 185 | """ 186 | encoder_name = self.text_encoder_name 187 | if encoder_name != "vit_l14": 188 | raise ValueError(f"Not implemented: {encoder_name}") 189 | text_encoder = clip_text_l14( 190 | pretrained=self.text_encoder_pretrained, 191 | embed_dim=self.text_encoder_d_model, 192 | context_length=self.max_txt_l, 193 | vocab_size=self.text_encoder_vocab_size, 194 | checkpoint_num=0, 195 | tokenizer=self.tokenizer, 196 | ) 197 | 198 | return text_encoder 199 | 200 | def get_text_encoder(self): 201 | """get text encoder, used for text and cross-modal encoding""" 202 | encoder = self.text_encoder 203 | return encoder.bert if hasattr(encoder, "bert") else encoder 204 | 205 | def get_text_features(self, input_text, tokenizer, text_feature_dict={}): 206 | if input_text in text_feature_dict: 207 | return text_feature_dict[input_text] 208 | text_template= f"{input_text}" 209 | with torch.no_grad(): 210 | # text_token = tokenizer.encode(text_template).cuda() 211 | text_features = self.encode_text(text_template).float() 212 | text_features /= text_features.norm(dim=-1, keepdim=True) 213 | text_feature_dict[input_text] = text_features 214 | return text_features 215 | 216 | def get_vid_features(self, input_frames): 217 | with torch.no_grad(): 218 | clip_feat = self.encode_vision(input_frames,test=True).float() 219 | clip_feat /= clip_feat.norm(dim=-1, keepdim=True) 220 | return clip_feat 221 | 222 | def get_predict_label(self, clip_feature, text_feats_tensor, top=5): 223 | label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1) 224 | top_probs, top_labels = label_probs.cpu().topk(top, dim=-1) 225 | return top_probs, top_labels -------------------------------------------------------------------------------- /third_party/dover/dataset.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | 8 | random.seed(42) 9 | 10 | 11 | def get_spatial_fragments( 12 | video, 13 | fragments_h=7, 14 | fragments_w=7, 15 | fsize_h=32, 16 | fsize_w=32, 17 | aligned=32, 18 | nfrags=1, 19 | random=False, 20 | random_upsample=False, 21 | fallback_type="upsample", 22 | upsample=-1, 23 | **kwargs, 24 | ): 25 | if upsample > 0: 26 | old_h, old_w = video.shape[-2], video.shape[-1] 27 | if old_h >= old_w: 28 | w = upsample 29 | h = int(upsample * old_h / old_w) 30 | else: 31 | h = upsample 32 | w = int(upsample * old_w / old_h) 33 | 34 | video = get_resized_video(video, h, w) 35 | size_h = fragments_h * fsize_h 36 | size_w = fragments_w * fsize_w 37 | ## video: [C,T,H,W] 38 | ## situation for images 39 | if video.shape[1] == 1: 40 | aligned = 1 41 | 42 | dur_t, res_h, res_w = video.shape[-3:] 43 | ratio = min(res_h / size_h, res_w / size_w) 44 | if fallback_type == "upsample" and ratio < 1: 45 | 46 | ovideo = video 47 | video = torch.nn.functional.interpolate( 48 | video / 255.0, scale_factor=1 / ratio, mode="bilinear" 49 | ) 50 | video = (video * 255.0).type_as(ovideo) 51 | 52 | if random_upsample: 53 | 54 | randratio = random.random() * 0.5 + 1 55 | video = torch.nn.functional.interpolate( 56 | video / 255.0, scale_factor=randratio, mode="bilinear" 57 | ) 58 | video = (video * 255.0).type_as(ovideo) 59 | 60 | assert dur_t % aligned == 0, "Please provide match vclip and align index" 61 | size = size_h, size_w 62 | 63 | ## make sure that sampling will not run out of the picture 64 | hgrids = torch.LongTensor( 65 | [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] 66 | ) 67 | wgrids = torch.LongTensor( 68 | [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] 69 | ) 70 | hlength, wlength = res_h // fragments_h, res_w // fragments_w 71 | 72 | if random: 73 | print("This part is deprecated. 
Please remind that.") 74 | if res_h > fsize_h: 75 | rnd_h = torch.randint( 76 | res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 77 | ) 78 | else: 79 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 80 | if res_w > fsize_w: 81 | rnd_w = torch.randint( 82 | res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 83 | ) 84 | else: 85 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 86 | else: 87 | if hlength > fsize_h: 88 | rnd_h = torch.randint( 89 | hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 90 | ) 91 | else: 92 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 93 | if wlength > fsize_w: 94 | rnd_w = torch.randint( 95 | wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 96 | ) 97 | else: 98 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 99 | 100 | target_video = torch.zeros(video.shape[:-2] + size).to(video.device) 101 | # target_videos = [] 102 | 103 | for i, hs in enumerate(hgrids): 104 | for j, ws in enumerate(wgrids): 105 | for t in range(dur_t // aligned): 106 | t_s, t_e = t * aligned, (t + 1) * aligned 107 | h_s, h_e = i * fsize_h, (i + 1) * fsize_h 108 | w_s, w_e = j * fsize_w, (j + 1) * fsize_w 109 | if random: 110 | h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h 111 | w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w 112 | else: 113 | h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h 114 | w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w 115 | target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ 116 | :, t_s:t_e, h_so:h_eo, w_so:w_eo 117 | ] 118 | # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) 119 | # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) 120 | # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments 121 | return target_video 122 | 123 | 124 | @lru_cache 125 | def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False): 126 | if random_crop: 127 | return torchvision.transforms.RandomResizedCrop( 128 | (size_h, size_w), scale=(0.40, 1.0) 129 | ) 130 | if target_ratio > 1: 131 | size_h = int(target_ratio * size_w) 132 | assert size_h > size_w 133 | elif target_ratio < 1: 134 | size_w = int(size_h / target_ratio) 135 | assert size_w > size_h 136 | return torchvision.transforms.Resize((size_h, size_w)) 137 | 138 | 139 | def get_resized_video( 140 | video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs, 141 | ): 142 | video = video.permute(1, 0, 2, 3) 143 | resize_opt = get_resize_function( 144 | size_h, size_w, video.shape[-2] / video.shape[-1] if arp else 1, random_crop 145 | ) 146 | video = resize_opt(video).permute(1, 0, 2, 3) 147 | return video 148 | 149 | 150 | def get_single_view( 151 | video, sample_type="aesthetic", **kwargs, 152 | ): 153 | if sample_type.startswith("aesthetic"): 154 | video = get_resized_video(video, **kwargs) 155 | elif sample_type.startswith("technical"): 156 | video = get_spatial_fragments(video, **kwargs) 157 | elif sample_type == "original": 158 | return video 159 | 160 | return video 161 | 162 | 163 | def spatial_temporal_view_decomposition( 164 | frames: list, sample_types, samplers, is_train=False, augment=False, 165 | ): 166 | video = {} 167 | 168 | ### Avoid duplicated video decoding!!! Important!!!! 
169 | all_frame_inds = [] 170 | frame_inds = {} 171 | for stype in samplers: 172 | frame_inds[stype] = samplers[stype](len(frames), is_train) 173 | all_frame_inds.append(frame_inds[stype]) 174 | 175 | ### Each frame is only decoded one time!!! 176 | all_frame_inds = np.concatenate(all_frame_inds, 0) 177 | frame_dict = {idx: frames[idx] for idx in np.unique(all_frame_inds)} 178 | 179 | for stype in samplers: 180 | imgs = [frame_dict[idx] for idx in frame_inds[stype]] 181 | imgs = [torch.from_numpy(np.array(img).astype(np.uint8)).float() for img in imgs] # List of [H,W,C] 182 | video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2) 183 | 184 | sampled_video = {} 185 | for stype, sopt in sample_types.items(): 186 | sampled_video[stype] = get_single_view(video[stype], stype, **sopt) 187 | return sampled_video, frame_inds 188 | 189 | 190 | 191 | class UnifiedFrameSampler: 192 | def __init__( 193 | self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0, 194 | ): 195 | 196 | self.fragments_t = fragments_t 197 | self.fsize_t = fsize_t 198 | self.size_t = fragments_t * fsize_t 199 | self.frame_interval = frame_interval 200 | self.num_clips = num_clips 201 | self.drop_rate = drop_rate 202 | 203 | def get_frame_indices(self, num_frames, train=False): 204 | 205 | tgrids = np.array( 206 | [num_frames // self.fragments_t * i for i in range(self.fragments_t)], 207 | dtype=np.int32, 208 | ) 209 | tlength = num_frames // self.fragments_t 210 | 211 | if tlength > self.fsize_t * self.frame_interval: 212 | rnd_t = np.random.randint( 213 | 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids) 214 | ) 215 | else: 216 | rnd_t = np.zeros(len(tgrids), dtype=np.int32) 217 | 218 | ranges_t = ( 219 | np.arange(self.fsize_t)[None, :] * self.frame_interval 220 | + rnd_t[:, None] 221 | + tgrids[:, None] 222 | ) 223 | 224 | drop = random.sample( 225 | list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate) 226 | ) 227 | dropped_ranges_t = [] 228 | for i, rt in enumerate(ranges_t): 229 | if i not in drop: 230 | dropped_ranges_t.append(rt) 231 | return np.concatenate(dropped_ranges_t) 232 | 233 | def __call__(self, total_frames, train=False, start_index=0): 234 | frame_inds = [] 235 | 236 | for i in range(self.num_clips): 237 | frame_inds += [self.get_frame_indices(total_frames)] 238 | 239 | frame_inds = np.concatenate(frame_inds) 240 | frame_inds = np.mod(frame_inds + start_index, total_frames) 241 | return frame_inds.astype(np.int32) -------------------------------------------------------------------------------- /third_party/dover/models/evaluator.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .conv_backbone import convnext_3d_small, convnext_3d_tiny, convnextv2_3d_pico, convnextv2_3d_femto 7 | from .head import IQAHead, VARHead, VQAHead 8 | from .swin_backbone import SwinTransformer2D as ImageBackbone 9 | from .swin_backbone import SwinTransformer3D as VideoBackbone 10 | from .swin_backbone import swin_3d_small, swin_3d_tiny 11 | 12 | 13 | class BaseEvaluator(nn.Module): 14 | def __init__( 15 | self, backbone=dict(), vqa_head=dict(), 16 | ): 17 | super().__init__() 18 | self.backbone = VideoBackbone(**backbone) 19 | self.vqa_head = VQAHead(**vqa_head) 20 | 21 | def forward(self, vclip, inference=True, **kwargs): 22 | if inference: 23 | self.eval() 24 | with torch.no_grad(): 25 | feat = self.backbone(vclip) 26 | score = self.vqa_head(feat) 27 | 
self.train() 28 | return score 29 | else: 30 | feat = self.backbone(vclip) 31 | score = self.vqa_head(feat) 32 | return score 33 | 34 | def forward_with_attention(self, vclip): 35 | self.eval() 36 | with torch.no_grad(): 37 | feat, avg_attns = self.backbone(vclip, require_attn=True) 38 | score = self.vqa_head(feat) 39 | return score, avg_attns 40 | 41 | 42 | class DOVER(nn.Module): 43 | def __init__( 44 | self, 45 | backbone_size="divided", 46 | backbone_preserve_keys="fragments,resize", 47 | multi=False, 48 | layer=-1, 49 | backbone=dict( 50 | resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)} 51 | ), 52 | divide_head=False, 53 | vqa_head=dict(in_channels=768), 54 | var=False, 55 | ): 56 | self.backbone_preserve_keys = backbone_preserve_keys.split(",") 57 | self.multi = multi 58 | self.layer = layer 59 | super().__init__() 60 | for key, hypers in backbone.items(): 61 | # print(backbone_size) 62 | if key not in self.backbone_preserve_keys: 63 | continue 64 | if backbone_size == "divided": 65 | t_backbone_size = hypers["type"] 66 | else: 67 | t_backbone_size = backbone_size 68 | if t_backbone_size == "swin_tiny": 69 | b = swin_3d_tiny(**backbone[key]) 70 | elif t_backbone_size == "swin_tiny_grpb": 71 | # to reproduce fast-vqa 72 | b = VideoBackbone() 73 | elif t_backbone_size == "swin_tiny_grpb_m": 74 | # to reproduce fast-vqa-m 75 | b = VideoBackbone(window_size=(4, 4, 4), frag_biases=[0, 0, 0, 0]) 76 | elif t_backbone_size == "swin_small": 77 | b = swin_3d_small(**backbone[key]) 78 | elif t_backbone_size == "conv_tiny": 79 | b = convnext_3d_tiny(pretrained=True) 80 | elif t_backbone_size == "conv_small": 81 | b = convnext_3d_small(pretrained=True) 82 | elif t_backbone_size == "conv_femto": 83 | b = convnextv2_3d_femto(pretrained=True) 84 | elif t_backbone_size == "conv_pico": 85 | b = convnextv2_3d_pico(pretrained=True) 86 | elif t_backbone_size == "xclip": 87 | raise NotImplementedError 88 | # b = build_x_clip_model(**backbone[key]) 89 | else: 90 | raise NotImplementedError 91 | # print("Setting backbone:", key + "_backbone") 92 | setattr(self, key + "_backbone", b) 93 | if divide_head: 94 | for key in backbone: 95 | pre_pool = False #if key == "technical" else True 96 | if key not in self.backbone_preserve_keys: 97 | continue 98 | b = VQAHead(pre_pool=pre_pool, **vqa_head) 99 | # print("Setting head:", key + "_head") 100 | setattr(self, key + "_head", b) 101 | else: 102 | if var: 103 | self.vqa_head = VARHead(**vqa_head) 104 | # print(b) 105 | else: 106 | self.vqa_head = VQAHead(**vqa_head) 107 | 108 | def forward( 109 | self, 110 | vclips, 111 | inference=True, 112 | return_pooled_feats=False, 113 | return_raw_feats=False, 114 | reduce_scores=False, 115 | pooled=False, 116 | **kwargs 117 | ): 118 | assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return" 119 | if inference: 120 | self.eval() 121 | with torch.no_grad(): 122 | scores = [] 123 | feats = {} 124 | for key in vclips: 125 | feat = getattr(self, key.split("_")[0] + "_backbone")( 126 | vclips[key], multi=self.multi, layer=self.layer, **kwargs 127 | ) 128 | if hasattr(self, key.split("_")[0] + "_head"): 129 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)] 130 | else: 131 | scores += [getattr(self, "vqa_head")(feat)] 132 | if return_pooled_feats: 133 | feats[key] = feat 134 | if return_raw_feats: 135 | feats[key] = feat 136 | if reduce_scores: 137 | if len(scores) > 1: 138 | scores = reduce(lambda x, y: x + y, scores) 139 | else: 140 | scores 
= scores[0] 141 | if pooled: 142 | scores = torch.mean(scores, (1, 2, 3, 4)) 143 | self.train() 144 | if return_pooled_feats or return_raw_feats: 145 | return scores, feats 146 | return scores 147 | else: 148 | self.train() 149 | scores = [] 150 | feats = {} 151 | for key in vclips: 152 | feat = getattr(self, key.split("_")[0] + "_backbone")( 153 | vclips[key], multi=self.multi, layer=self.layer, **kwargs 154 | ) 155 | if hasattr(self, key.split("_")[0] + "_head"): 156 | scores += [getattr(self, key.split("_")[0] + "_head")(feat)] 157 | else: 158 | scores += [getattr(self, "vqa_head")(feat)] 159 | if return_pooled_feats: 160 | feats[key] = feat.mean((-3, -2, -1)) 161 | if reduce_scores: 162 | if len(scores) > 1: 163 | scores = reduce(lambda x, y: x + y, scores) 164 | else: 165 | scores = scores[0] 166 | if pooled: 167 | # print(scores.shape) 168 | scores = torch.mean(scores, (1, 2, 3, 4)) 169 | # print(scores.shape) 170 | 171 | if return_pooled_feats: 172 | return scores, feats 173 | return scores 174 | 175 | 176 | class MinimumDOVER(nn.Module): 177 | def __init__(self): 178 | super().__init__() 179 | self.technical_backbone = VideoBackbone() 180 | self.aesthetic_backbone = convnext_3d_tiny(pretrained=True) 181 | self.technical_head = VQAHead(pre_pool=False, in_channels=768) 182 | self.aesthetic_head = VQAHead(pre_pool=False, in_channels=768) 183 | 184 | 185 | def forward(self,aesthetic_view, technical_view): 186 | self.eval() 187 | with torch.no_grad(): 188 | aesthetic_score = self.aesthetic_head(self.aesthetic_backbone(aesthetic_view)) 189 | technical_score = self.technical_head(self.technical_backbone(technical_view)) 190 | 191 | aesthetic_score_pooled = torch.mean(aesthetic_score, (1,2,3,4)) 192 | technical_score_pooled = torch.mean(technical_score, (1,2,3,4)) 193 | return [aesthetic_score_pooled, technical_score_pooled] 194 | 195 | 196 | 197 | class BaseImageEvaluator(nn.Module): 198 | def __init__( 199 | self, backbone=dict(), iqa_head=dict(), 200 | ): 201 | super().__init__() 202 | self.backbone = ImageBackbone(**backbone) 203 | self.iqa_head = IQAHead(**iqa_head) 204 | 205 | def forward(self, image, inference=True, **kwargs): 206 | if inference: 207 | self.eval() 208 | with torch.no_grad(): 209 | feat = self.backbone(image) 210 | score = self.iqa_head(feat) 211 | self.train() 212 | return score 213 | else: 214 | feat = self.backbone(image) 215 | score = self.iqa_head(feat) 216 | return score 217 | 218 | def forward_with_attention(self, image): 219 | self.eval() 220 | with torch.no_grad(): 221 | feat, avg_attns = self.backbone(image, require_attn=True) 222 | score = self.iqa_head(feat) 223 | return score, avg_attns 224 | -------------------------------------------------------------------------------- /third_party/viclip/viclip_text.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import os 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as checkpoint 10 | from pkg_resources import packaging 11 | from torch import nn 12 | 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | MODEL_PATH = 'https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K' 19 | _MODELS = { 20 | "ViT-L/14": os.path.join(MODEL_PATH, "vit_l14_text.pth"), 21 | } 22 | 23 | 24 | class LayerNorm(nn.LayerNorm): 25 | """Subclass torch's LayerNorm to handle 
fp16.""" 26 | 27 | def forward(self, x: torch.Tensor): 28 | orig_type = x.dtype 29 | ret = super().forward(x.type(torch.float32)) 30 | return ret.type(orig_type) 31 | 32 | 33 | class QuickGELU(nn.Module): 34 | def forward(self, x: torch.Tensor): 35 | return x * torch.sigmoid(1.702 * x) 36 | 37 | 38 | class ResidualAttentionBlock(nn.Module): 39 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 40 | super().__init__() 41 | 42 | self.attn = nn.MultiheadAttention(d_model, n_head) 43 | self.ln_1 = LayerNorm(d_model) 44 | self.mlp = nn.Sequential(OrderedDict([ 45 | ("c_fc", nn.Linear(d_model, d_model * 4)), 46 | ("gelu", QuickGELU()), 47 | ("c_proj", nn.Linear(d_model * 4, d_model)) 48 | ])) 49 | self.ln_2 = LayerNorm(d_model) 50 | self.attn_mask = attn_mask 51 | 52 | def attention(self, x: torch.Tensor): 53 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 54 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 55 | 56 | def forward(self, x: torch.Tensor): 57 | x = x + self.attention(self.ln_1(x)) 58 | x = x + self.mlp(self.ln_2(x)) 59 | return x 60 | 61 | 62 | class Transformer(nn.Module): 63 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, 64 | checkpoint_num: int = 0): 65 | super().__init__() 66 | self.width = width 67 | self.layers = layers 68 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 69 | 70 | self.checkpoint_num = checkpoint_num 71 | 72 | def forward(self, x: torch.Tensor): 73 | if self.checkpoint_num > 0: 74 | segments = min(self.checkpoint_num, len(self.resblocks)) 75 | return checkpoint.checkpoint_sequential(self.resblocks, segments, x) 76 | else: 77 | return self.resblocks(x) 78 | 79 | 80 | class CLIP_TEXT(nn.Module): 81 | def __init__( 82 | self, 83 | embed_dim: int, 84 | context_length: int, 85 | vocab_size: int, 86 | transformer_width: int, 87 | transformer_heads: int, 88 | transformer_layers: int, 89 | checkpoint_num: int, 90 | tokenizer=None, 91 | ): 92 | super().__init__() 93 | 94 | self.context_length = context_length 95 | self._tokenizer = tokenizer or _Tokenizer() 96 | 97 | self.transformer = Transformer( 98 | width=transformer_width, 99 | layers=transformer_layers, 100 | heads=transformer_heads, 101 | attn_mask=self.build_attention_mask(), 102 | checkpoint_num=checkpoint_num, 103 | ) 104 | 105 | self.vocab_size = vocab_size 106 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 107 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 108 | self.ln_final = LayerNorm(transformer_width) 109 | 110 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 111 | 112 | def no_weight_decay(self): 113 | return {'token_embedding', 'positional_embedding'} 114 | 115 | @functools.lru_cache(maxsize=None) 116 | def build_attention_mask(self): 117 | # lazily create causal attention mask, with full attention between the vision tokens 118 | # pytorch uses additive attention mask; fill with -inf 119 | mask = torch.empty(self.context_length, self.context_length) 120 | mask.fill_(float("-inf")) 121 | mask.triu_(1) # zero out the lower diagonal 122 | return mask 123 | 124 | def tokenize(self, texts, context_length=77, truncate=True): 125 | """ 126 | Returns the tokenized representation of given input string(s) 127 | Parameters 128 | ---------- 129 | texts : Union[str, List[str]] 130 | An input 
string or a list of input strings to tokenize 131 | context_length : int 132 | The context length to use; all CLIP models use 77 as the context length 133 | truncate: bool 134 | Whether to truncate the text in case its encoding is longer than the context length 135 | Returns 136 | ------- 137 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 138 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 139 | """ 140 | if isinstance(texts, str): 141 | texts = [texts] 142 | 143 | sot_token = self._tokenizer.encoder["<|startoftext|>"] 144 | eot_token = self._tokenizer.encoder["<|endoftext|>"] 145 | all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts] 146 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 147 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 148 | else: 149 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 150 | 151 | for i, tokens in enumerate(all_tokens): 152 | if len(tokens) > context_length: 153 | if truncate: 154 | tokens = tokens[:context_length] 155 | tokens[-1] = eot_token 156 | else: 157 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 158 | result[i, :len(tokens)] = torch.tensor(tokens) 159 | 160 | return result 161 | 162 | def forward(self, text): 163 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 164 | 165 | x = x + self.positional_embedding 166 | x = x.permute(1, 0, 2) # NLD -> LND 167 | x = self.transformer(x) 168 | x = x.permute(1, 0, 2) # LND -> NLD 169 | x = self.ln_final(x) 170 | 171 | # x.shape = [batch_size, n_ctx, transformer.width] 172 | # take features from the eot embedding (eot_token is the highest number in each sequence) 173 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 174 | 175 | return x 176 | 177 | 178 | def clip_text_b16( 179 | embed_dim=512, 180 | context_length=77, 181 | vocab_size=49408, 182 | transformer_width=512, 183 | transformer_heads=8, 184 | transformer_layers=12, 185 | ): 186 | raise NotImplementedError 187 | model = CLIP_TEXT( 188 | embed_dim, 189 | context_length, 190 | vocab_size, 191 | transformer_width, 192 | transformer_heads, 193 | transformer_layers 194 | ) 195 | pretrained = _MODELS["ViT-B/16"] 196 | logger.info(f"Load pretrained weights from {pretrained}") 197 | state_dict = torch.load(pretrained, map_location='cpu') 198 | model.load_state_dict(state_dict, strict=False) 199 | return model.eval() 200 | 201 | 202 | def clip_text_l14( 203 | embed_dim=768, 204 | context_length=77, 205 | vocab_size=49408, 206 | transformer_width=768, 207 | transformer_heads=12, 208 | transformer_layers=12, 209 | checkpoint_num=0, 210 | pretrained=True, 211 | tokenizer=None, 212 | ): 213 | model = CLIP_TEXT( 214 | embed_dim, 215 | context_length, 216 | vocab_size, 217 | transformer_width, 218 | transformer_heads, 219 | transformer_layers, 220 | checkpoint_num, 221 | tokenizer, 222 | ) 223 | if pretrained: 224 | if isinstance(pretrained, str) and pretrained != "bert-base-uncased": 225 | pretrained = _MODELS[pretrained] 226 | else: 227 | pretrained = _MODELS["ViT-L/14"] 228 | logger.info(f"Load pretrained weights from {pretrained}") 229 | state_dict = torch.load(pretrained, map_location='cpu') 230 | if context_length != state_dict["positional_embedding"].size(0): 231 | # assert context_length < state_dict["positional_embedding"].size(0), 
"Cannot increase context length." 232 | # print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}") 233 | if context_length < state_dict["positional_embedding"].size(0): 234 | state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length] 235 | else: 236 | state_dict["positional_embedding"] = F.pad( 237 | state_dict["positional_embedding"], 238 | (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)), 239 | value=0, 240 | ) 241 | 242 | message = model.load_state_dict(state_dict, strict=False) 243 | # print(f"Load pretrained weights from {pretrained}: {message}") 244 | return model.eval() 245 | 246 | 247 | def clip_text_l14_336( 248 | embed_dim=768, 249 | context_length=77, 250 | vocab_size=49408, 251 | transformer_width=768, 252 | transformer_heads=12, 253 | transformer_layers=12, 254 | ): 255 | raise NotImplementedError 256 | model = CLIP_TEXT( 257 | embed_dim, 258 | context_length, 259 | vocab_size, 260 | transformer_width, 261 | transformer_heads, 262 | transformer_layers 263 | ) 264 | pretrained = _MODELS["ViT-L/14_336"] 265 | logger.info(f"Load pretrained weights from {pretrained}") 266 | state_dict = torch.load(pretrained, map_location='cpu') 267 | model.load_state_dict(state_dict, strict=False) 268 | return model.eval() 269 | 270 | 271 | def build_clip(config): 272 | model_cls = config.text_encoder.clip_teacher 273 | model = eval(model_cls)() 274 | return model -------------------------------------------------------------------------------- /third_party/viclip/viclip_vision.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import logging 3 | import os 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.utils.checkpoint as checkpoint 8 | from einops import rearrange 9 | from timm.models.layers import DropPath 10 | from timm.models.registry import register_model 11 | from torch import nn 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 17 | """ 18 | Add/Remove extra temporal_embeddings as needed. 19 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 20 | 21 | temp_embed_old: (1, num_frames_old, 1, d) 22 | temp_embed_new: (1, num_frames_new, 1, d) 23 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 24 | """ 25 | # TODO zero pad 26 | num_frms_new = temp_embed_new.shape[1] 27 | num_frms_old = temp_embed_old.shape[1] 28 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 29 | if num_frms_new > num_frms_old: 30 | if add_zero: 31 | temp_embed_new[ 32 | :, :num_frms_old 33 | ] = temp_embed_old # untrained embeddings are zeros. 
34 | else: 35 | raise NotImplementedError 36 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 37 | elif num_frms_new < num_frms_old: 38 | temp_embed_new = temp_embed_old[:, :num_frms_new] 39 | else: # = 40 | temp_embed_new = temp_embed_old 41 | return temp_embed_new 42 | 43 | 44 | MODEL_PATH = 'https://huggingface.co/OpenGVLab/VBench_Used_Models/blob/main/' 45 | _MODELS = { 46 | "ViT-L/14": os.path.join(MODEL_PATH, "ViClip-InternVid-10M-FLT.pth"), 47 | } 48 | 49 | 50 | class QuickGELU(nn.Module): 51 | def forward(self, x): 52 | return x * torch.sigmoid(1.702 * x) 53 | 54 | 55 | class ResidualAttentionBlock(nn.Module): 56 | def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.): 57 | super().__init__() 58 | 59 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 60 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 61 | self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout) 62 | self.ln_1 = nn.LayerNorm(d_model) 63 | self.mlp = nn.Sequential(OrderedDict([ 64 | ("c_fc", nn.Linear(d_model, d_model * 4)), 65 | ("gelu", QuickGELU()), 66 | ("drop1", nn.Dropout(dropout)), 67 | ("c_proj", nn.Linear(d_model * 4, d_model)), 68 | ("drop2", nn.Dropout(dropout)), 69 | ])) 70 | self.ln_2 = nn.LayerNorm(d_model) 71 | self.attn_mask = attn_mask 72 | 73 | def attention(self, x): 74 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 75 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 76 | 77 | def forward(self, x): 78 | x = x + self.drop_path1(self.attention(self.ln_1(x))) 79 | x = x + self.drop_path2(self.mlp(self.ln_2(x))) 80 | return x 81 | 82 | 83 | class Transformer(nn.Module): 84 | def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.): 85 | super().__init__() 86 | dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] 87 | self.resblocks = nn.ModuleList() 88 | for idx in range(layers): 89 | self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout)) 90 | self.checkpoint_num = checkpoint_num 91 | 92 | def forward(self, x): 93 | for idx, blk in enumerate(self.resblocks): 94 | if idx < self.checkpoint_num: 95 | x = checkpoint.checkpoint(blk, x) 96 | else: 97 | x = blk(x) 98 | return x 99 | 100 | 101 | class VisionTransformer(nn.Module): 102 | def __init__( 103 | self, input_resolution, patch_size, width, layers, heads, output_dim=None, 104 | kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0., 105 | temp_embed=True, 106 | ): 107 | super().__init__() 108 | self.output_dim = output_dim 109 | self.conv1 = nn.Conv3d( 110 | 3, width, 111 | (kernel_size, patch_size, patch_size), 112 | (kernel_size, patch_size, patch_size), 113 | (0, 0, 0), bias=False 114 | ) 115 | 116 | scale = width ** -0.5 117 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 118 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 119 | self.ln_pre = nn.LayerNorm(width) 120 | if temp_embed: 121 | self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width)) 122 | 123 | self.transformer = Transformer( 124 | width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num, 125 | dropout=dropout) 126 | 127 | self.ln_post = nn.LayerNorm(width) 128 | if output_dim is not None: 129 | self.proj = nn.Parameter(torch.empty(width, output_dim)) 130 | 
else: 131 | self.proj = None 132 | 133 | self.dropout = nn.Dropout(dropout) 134 | 135 | def get_num_layers(self): 136 | return len(self.transformer.resblocks) 137 | 138 | @torch.jit.ignore 139 | def no_weight_decay(self): 140 | return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'} 141 | 142 | def mask_tokens(self, inputs, masking_prob=0.0): 143 | B, L, _ = inputs.shape 144 | 145 | # This is different from text as we are masking a fix number of tokens 146 | Lm = int(masking_prob * L) 147 | masked_indices = torch.zeros(B, L) 148 | indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm] 149 | batch_indices = ( 150 | torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices) 151 | ) 152 | masked_indices[batch_indices, indices] = 1 153 | 154 | masked_indices = masked_indices.bool() 155 | 156 | return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1]) 157 | 158 | def forward(self, x, masking_prob=0.0): 159 | x = self.conv1(x) # shape = [*, width, grid, grid] 160 | B, C, T, H, W = x.shape 161 | x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C) 162 | 163 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 164 | x = x + self.positional_embedding.to(x.dtype) 165 | 166 | # temporal pos 167 | cls_tokens = x[:B, :1, :] 168 | x = x[:, 1:] 169 | x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T) 170 | if hasattr(self, 'temporal_positional_embedding'): 171 | if x.size(1) == 1: 172 | # This is a workaround for unused parameter issue 173 | x = x + self.temporal_positional_embedding.mean(1) 174 | else: 175 | x = x + self.temporal_positional_embedding 176 | x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T) 177 | 178 | if masking_prob > 0.0: 179 | x = self.mask_tokens(x, masking_prob) 180 | 181 | x = torch.cat((cls_tokens, x), dim=1) 182 | 183 | x = self.ln_pre(x) 184 | 185 | x = x.permute(1, 0, 2) #BND -> NBD 186 | x = self.transformer(x) 187 | 188 | x = self.ln_post(x) 189 | 190 | if self.proj is not None: 191 | x = self.dropout(x[0]) @ self.proj 192 | else: 193 | x = x.permute(1, 0, 2) #NBD -> BND 194 | 195 | return x 196 | 197 | 198 | def inflate_weight(weight_2d, time_dim, center=True): 199 | logger.info(f'Init center: {center}') 200 | if center: 201 | weight_3d = torch.zeros(*weight_2d.shape) 202 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 203 | middle_idx = time_dim // 2 204 | weight_3d[:, :, middle_idx, :, :] = weight_2d 205 | else: 206 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 207 | weight_3d = weight_3d / time_dim 208 | return weight_3d 209 | 210 | 211 | def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True): 212 | state_dict_3d = model.state_dict() 213 | for k in state_dict.keys(): 214 | if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape: 215 | if len(state_dict_3d[k].shape) <= 2: 216 | logger.info(f'Ignore: {k}') 217 | continue 218 | logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}') 219 | time_dim = state_dict_3d[k].shape[2] 220 | state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center) 221 | 222 | pos_embed_checkpoint = state_dict['positional_embedding'] 223 | embedding_size = pos_embed_checkpoint.shape[-1] 224 | num_patches = (input_resolution // patch_size) ** 2 225 | orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5) 226 | new_size = 
int(num_patches ** 0.5) 227 | if orig_size != new_size: 228 | logger.info(f'Pos_emb from {orig_size} to {new_size}') 229 | extra_tokens = pos_embed_checkpoint[:1] 230 | pos_tokens = pos_embed_checkpoint[1:] 231 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 232 | pos_tokens = torch.nn.functional.interpolate( 233 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 234 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2) 235 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0) 236 | state_dict['positional_embedding'] = new_pos_embed 237 | 238 | message = model.load_state_dict(state_dict, strict=False) 239 | logger.info(f"Load pretrained weights: {message}") 240 | 241 | 242 | @register_model 243 | def clip_joint_b16( 244 | pretrained=True, input_resolution=224, kernel_size=1, 245 | center=True, num_frames=8, drop_path=0. 246 | ): 247 | model = VisionTransformer( 248 | input_resolution=input_resolution, patch_size=16, 249 | width=768, layers=12, heads=12, output_dim=512, 250 | kernel_size=kernel_size, num_frames=num_frames, 251 | drop_path=drop_path, 252 | ) 253 | raise NotImplementedError 254 | if pretrained: 255 | logger.info('load pretrained weights') 256 | state_dict = torch.load(_MODELS["ViT-B/16"], map_location='cpu') 257 | load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center) 258 | return model.eval() 259 | 260 | 261 | @register_model 262 | def clip_joint_l14( 263 | pretrained=False, input_resolution=224, kernel_size=1, 264 | center=True, num_frames=8, drop_path=0., checkpoint_num=0, 265 | dropout=0., 266 | ): 267 | model = VisionTransformer( 268 | input_resolution=input_resolution, patch_size=14, 269 | width=1024, layers=24, heads=16, output_dim=768, 270 | kernel_size=kernel_size, num_frames=num_frames, 271 | drop_path=drop_path, checkpoint_num=checkpoint_num, 272 | dropout=dropout, 273 | ) 274 | if pretrained: 275 | if isinstance(pretrained, str): 276 | model_name = pretrained 277 | else: 278 | model_name = "ViT-L/14" 279 | logger.info('load pretrained weights') 280 | state_dict = torch.load(_MODELS[model_name], map_location='cpu') 281 | load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center) 282 | return model.eval() 283 | 284 | 285 | @register_model 286 | def clip_joint_l14_336( 287 | pretrained=True, input_resolution=336, kernel_size=1, 288 | center=True, num_frames=8, drop_path=0. 
289 | ): 290 | raise NotImplementedError 291 | model = VisionTransformer( 292 | input_resolution=input_resolution, patch_size=14, 293 | width=1024, layers=24, heads=16, output_dim=768, 294 | kernel_size=kernel_size, num_frames=num_frames, 295 | drop_path=drop_path, 296 | ) 297 | if pretrained: 298 | logger.info('load pretrained weights') 299 | state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu') 300 | load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center) 301 | return model.eval() 302 | 303 | 304 | def interpolate_pos_embed_vit(state_dict, new_model): 305 | key = "vision_encoder.temporal_positional_embedding" 306 | if key in state_dict: 307 | vision_temp_embed_new = new_model.state_dict()[key] 308 | vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2) # [1, n, d] -> [1, n, 1, d] 309 | vision_temp_embed_old = state_dict[key] 310 | vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2) 311 | 312 | state_dict[key] = load_temp_embed_with_mismatch( 313 | vision_temp_embed_old, vision_temp_embed_new, add_zero=False 314 | ).squeeze(2) 315 | 316 | key = "text_encoder.positional_embedding" 317 | if key in state_dict: 318 | text_temp_embed_new = new_model.state_dict()[key] 319 | text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2) # [n, d] -> [1, n, 1, d] 320 | text_temp_embed_old = state_dict[key] 321 | text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2) 322 | 323 | state_dict[key] = load_temp_embed_with_mismatch( 324 | text_temp_embed_old, text_temp_embed_new, add_zero=False 325 | ).squeeze(2).squeeze(0) 326 | return state_dict -------------------------------------------------------------------------------- /third_party/gmflow/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .utils import split_feature, merge_splits 6 | 7 | 8 | def single_head_full_attention(q, k, v): 9 | # q, k, v: [B, L, C] 10 | assert q.dim() == k.dim() == v.dim() == 3 11 | 12 | scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] 13 | attn = torch.softmax(scores, dim=2) # [B, L, L] 14 | out = torch.matmul(attn, v) # [B, L, C] 15 | 16 | return out 17 | 18 | 19 | def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, 20 | shift_size_h, shift_size_w, device=torch.device('cuda')): 21 | # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 22 | # calculate attention mask for SW-MSA 23 | h, w = input_resolution 24 | img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 25 | h_slices = (slice(0, -window_size_h), 26 | slice(-window_size_h, -shift_size_h), 27 | slice(-shift_size_h, None)) 28 | w_slices = (slice(0, -window_size_w), 29 | slice(-window_size_w, -shift_size_w), 30 | slice(-shift_size_w, None)) 31 | cnt = 0 32 | for h in h_slices: 33 | for w in w_slices: 34 | img_mask[:, h, w, :] = cnt 35 | cnt += 1 36 | 37 | mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) 38 | 39 | mask_windows = mask_windows.view(-1, window_size_h * window_size_w) 40 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 41 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 42 | 43 | return attn_mask 44 | 45 | 46 | def single_head_split_window_attention(q, k, v, 47 | num_splits=1, 48 | 
with_shift=False, 49 | h=None, 50 | w=None, 51 | attn_mask=None, 52 | ): 53 | # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 54 | # q, k, v: [B, L, C] 55 | assert q.dim() == k.dim() == v.dim() == 3 56 | 57 | assert h is not None and w is not None 58 | assert q.size(1) == h * w 59 | 60 | b, _, c = q.size() 61 | 62 | b_new = b * num_splits * num_splits 63 | 64 | window_size_h = h // num_splits 65 | window_size_w = w // num_splits 66 | 67 | q = q.view(b, h, w, c) # [B, H, W, C] 68 | k = k.view(b, h, w, c) 69 | v = v.view(b, h, w, c) 70 | 71 | scale_factor = c ** 0.5 72 | 73 | if with_shift: 74 | assert attn_mask is not None # compute once 75 | shift_size_h = window_size_h // 2 76 | shift_size_w = window_size_w // 2 77 | 78 | q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) 79 | k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) 80 | v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) 81 | 82 | q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] 83 | k = split_feature(k, num_splits=num_splits, channel_last=True) 84 | v = split_feature(v, num_splits=num_splits, channel_last=True) 85 | 86 | scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) 87 | ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] 88 | 89 | if with_shift: 90 | scores += attn_mask.repeat(b, 1, 1) 91 | 92 | attn = torch.softmax(scores, dim=-1) 93 | 94 | out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] 95 | 96 | out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), 97 | num_splits=num_splits, channel_last=True) # [B, H, W, C] 98 | 99 | # shift back 100 | if with_shift: 101 | out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) 102 | 103 | out = out.view(b, -1, c) 104 | 105 | return out 106 | 107 | 108 | class TransformerLayer(nn.Module): 109 | def __init__(self, 110 | d_model=256, 111 | nhead=1, 112 | attention_type='swin', 113 | no_ffn=False, 114 | ffn_dim_expansion=4, 115 | with_shift=False, 116 | **kwargs, 117 | ): 118 | super(TransformerLayer, self).__init__() 119 | 120 | self.dim = d_model 121 | self.nhead = nhead 122 | self.attention_type = attention_type 123 | self.no_ffn = no_ffn 124 | 125 | self.with_shift = with_shift 126 | 127 | # multi-head attention 128 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 129 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 130 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 131 | 132 | self.merge = nn.Linear(d_model, d_model, bias=False) 133 | 134 | self.norm1 = nn.LayerNorm(d_model) 135 | 136 | # no ffn after self-attn, with ffn after cross-attn 137 | if not self.no_ffn: 138 | in_channels = d_model * 2 139 | self.mlp = nn.Sequential( 140 | nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), 141 | nn.GELU(), 142 | nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), 143 | ) 144 | 145 | self.norm2 = nn.LayerNorm(d_model) 146 | 147 | def forward(self, source, target, 148 | height=None, 149 | width=None, 150 | shifted_window_attn_mask=None, 151 | attn_num_splits=None, 152 | **kwargs, 153 | ): 154 | # source, target: [B, L, C] 155 | query, key, value = source, target, target 156 | 157 | # single-head attention 158 | query = self.q_proj(query) # [B, L, C] 159 | key = self.k_proj(key) # [B, L, C] 160 | value = self.v_proj(value) # [B, L, C] 161 | 162 | if self.attention_type == 'swin' and attn_num_splits > 1: 163 | if self.nhead > 
1: 164 | # we observe that multihead attention slows down the speed and increases the memory consumption 165 | # without bringing obvious performance gains and thus the implementation is removed 166 | raise NotImplementedError 167 | else: 168 | message = single_head_split_window_attention(query, key, value, 169 | num_splits=attn_num_splits, 170 | with_shift=self.with_shift, 171 | h=height, 172 | w=width, 173 | attn_mask=shifted_window_attn_mask, 174 | ) 175 | else: 176 | message = single_head_full_attention(query, key, value) # [B, L, C] 177 | 178 | message = self.merge(message) # [B, L, C] 179 | message = self.norm1(message) 180 | 181 | if not self.no_ffn: 182 | message = self.mlp(torch.cat([source, message], dim=-1)) 183 | message = self.norm2(message) 184 | 185 | return source + message 186 | 187 | 188 | class TransformerBlock(nn.Module): 189 | """self attention + cross attention + FFN""" 190 | 191 | def __init__(self, 192 | d_model=256, 193 | nhead=1, 194 | attention_type='swin', 195 | ffn_dim_expansion=4, 196 | with_shift=False, 197 | **kwargs, 198 | ): 199 | super(TransformerBlock, self).__init__() 200 | 201 | self.self_attn = TransformerLayer(d_model=d_model, 202 | nhead=nhead, 203 | attention_type=attention_type, 204 | no_ffn=True, 205 | ffn_dim_expansion=ffn_dim_expansion, 206 | with_shift=with_shift, 207 | ) 208 | 209 | self.cross_attn_ffn = TransformerLayer(d_model=d_model, 210 | nhead=nhead, 211 | attention_type=attention_type, 212 | ffn_dim_expansion=ffn_dim_expansion, 213 | with_shift=with_shift, 214 | ) 215 | 216 | def forward(self, source, target, 217 | height=None, 218 | width=None, 219 | shifted_window_attn_mask=None, 220 | attn_num_splits=None, 221 | **kwargs, 222 | ): 223 | # source, target: [B, L, C] 224 | 225 | # self attention 226 | source = self.self_attn(source, source, 227 | height=height, 228 | width=width, 229 | shifted_window_attn_mask=shifted_window_attn_mask, 230 | attn_num_splits=attn_num_splits, 231 | ) 232 | 233 | # cross attention and ffn 234 | source = self.cross_attn_ffn(source, target, 235 | height=height, 236 | width=width, 237 | shifted_window_attn_mask=shifted_window_attn_mask, 238 | attn_num_splits=attn_num_splits, 239 | ) 240 | 241 | return source 242 | 243 | 244 | class FeatureTransformer(nn.Module): 245 | def __init__(self, 246 | num_layers=6, 247 | d_model=128, 248 | nhead=1, 249 | attention_type='swin', 250 | ffn_dim_expansion=4, 251 | **kwargs, 252 | ): 253 | super(FeatureTransformer, self).__init__() 254 | 255 | self.attention_type = attention_type 256 | 257 | self.d_model = d_model 258 | self.nhead = nhead 259 | 260 | self.layers = nn.ModuleList([ 261 | TransformerBlock(d_model=d_model, 262 | nhead=nhead, 263 | attention_type=attention_type, 264 | ffn_dim_expansion=ffn_dim_expansion, 265 | with_shift=True if attention_type == 'swin' and i % 2 == 1 else False, 266 | ) 267 | for i in range(num_layers)]) 268 | 269 | for p in self.parameters(): 270 | if p.dim() > 1: 271 | nn.init.xavier_uniform_(p) 272 | 273 | def forward(self, feature0, feature1, 274 | attn_num_splits=None, 275 | **kwargs, 276 | ): 277 | 278 | b, c, h, w = feature0.shape 279 | assert self.d_model == c 280 | 281 | feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 282 | feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 283 | 284 | if self.attention_type == 'swin' and attn_num_splits > 1: 285 | # global and refine use different number of splits 286 | window_size_h = h // attn_num_splits 287 | window_size_w = w // attn_num_splits 288 | 289 | # compute 
attn mask once 290 | shifted_window_attn_mask = generate_shift_window_attn_mask( 291 | input_resolution=(h, w), 292 | window_size_h=window_size_h, 293 | window_size_w=window_size_w, 294 | shift_size_h=window_size_h // 2, 295 | shift_size_w=window_size_w // 2, 296 | device=feature0.device, 297 | ) # [K*K, H/K*W/K, H/K*W/K] 298 | else: 299 | shifted_window_attn_mask = None 300 | 301 | # concat feature0 and feature1 in batch dimension to compute in parallel 302 | concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] 303 | concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] 304 | 305 | for layer in self.layers: 306 | concat0 = layer(concat0, concat1, 307 | height=h, 308 | width=w, 309 | shifted_window_attn_mask=shifted_window_attn_mask, 310 | attn_num_splits=attn_num_splits, 311 | ) 312 | 313 | # update feature1 314 | concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) 315 | 316 | feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] 317 | 318 | # reshape back 319 | feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] 320 | feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] 321 | 322 | return feature0, feature1 323 | 324 | 325 | class FeatureFlowAttention(nn.Module): 326 | """ 327 | flow propagation with self-attention on feature 328 | query: feature0, key: feature0, value: flow 329 | """ 330 | 331 | def __init__(self, in_channels, 332 | **kwargs, 333 | ): 334 | super(FeatureFlowAttention, self).__init__() 335 | 336 | self.q_proj = nn.Linear(in_channels, in_channels) 337 | self.k_proj = nn.Linear(in_channels, in_channels) 338 | 339 | for p in self.parameters(): 340 | if p.dim() > 1: 341 | nn.init.xavier_uniform_(p) 342 | 343 | def forward(self, feature0, flow, 344 | local_window_attn=False, 345 | local_window_radius=1, 346 | **kwargs, 347 | ): 348 | # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] 349 | if local_window_attn: 350 | return self.forward_local_window_attn(feature0, flow, 351 | local_window_radius=local_window_radius) 352 | 353 | b, c, h, w = feature0.size() 354 | 355 | query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] 356 | 357 | # a note: the ``correct'' implementation should be: 358 | # ``query = self.q_proj(query), key = self.k_proj(query)'' 359 | # this problem is observed while cleaning up the code 360 | # however, this doesn't affect the performance since the projection is a linear operation, 361 | # thus the two projection matrices for key can be merged 362 | # so I just leave it as is in order to not re-train all models :) 363 | query = self.q_proj(query) # [B, H*W, C] 364 | key = self.k_proj(query) # [B, H*W, C] 365 | 366 | value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] 367 | 368 | scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] 369 | prob = torch.softmax(scores, dim=-1) 370 | 371 | out = torch.matmul(prob, value) # [B, H*W, 2] 372 | out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] 373 | 374 | return out 375 | 376 | def forward_local_window_attn(self, feature0, flow, 377 | local_window_radius=1, 378 | ): 379 | assert flow.size(1) == 2 380 | assert local_window_radius > 0 381 | 382 | b, c, h, w = feature0.size() 383 | 384 | feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) 385 | ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] 386 | 387 | kernel_size = 2 * local_window_radius + 1 388 | 389 | feature0_proj = self.k_proj(feature0.view(b, c, 
-1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) 390 | 391 | feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, 392 | padding=local_window_radius) # [B, C*(2R+1)^2), H*W] 393 | 394 | feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( 395 | 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] 396 | 397 | flow_window = F.unfold(flow, kernel_size=kernel_size, 398 | padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] 399 | 400 | flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute( 401 | 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2] 402 | 403 | scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] 404 | 405 | prob = torch.softmax(scores, dim=-1) 406 | 407 | out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] 408 | 409 | return out -------------------------------------------------------------------------------- /third_party/dover/models/conv_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import trunc_normal_, DropPath 5 | 6 | 7 | class GRN(nn.Module): 8 | """ GRN (Global Response Normalization) layer 9 | """ 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 13 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 14 | 15 | def forward(self, x): 16 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 17 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 18 | return self.gamma * (x * Nx) + self.beta + x 19 | 20 | class Block(nn.Module): 21 | r""" ConvNeXt Block. There are two equivalent implementations: 22 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 23 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 24 | We use (2) as we find it slightly faster in PyTorch 25 | 26 | Args: 27 | dim (int): Number of input channels. 28 | drop_path (float): Stochastic depth rate. Default: 0.0 29 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 30 | """ 31 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 32 | super().__init__() 33 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 34 | self.norm = LayerNorm(dim, eps=1e-6) 35 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 36 | self.act = nn.GELU() 37 | self.pwconv2 = nn.Linear(4 * dim, dim) 38 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 39 | requires_grad=True) if layer_scale_init_value > 0 else None 40 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 41 | 42 | def forward(self, x): 43 | input = x 44 | x = self.dwconv(x) 45 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 46 | x = self.norm(x) 47 | x = self.pwconv1(x) 48 | x = self.act(x) 49 | x = self.pwconv2(x) 50 | if self.gamma is not None: 51 | x = self.gamma * x 52 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 53 | 54 | x = input + self.drop_path(x) 55 | return x 56 | 57 | class ConvNeXt(nn.Module): 58 | r""" ConvNeXt 59 | A PyTorch impl of : `A ConvNet for the 2020s` - 60 | https://arxiv.org/pdf/2201.03545.pdf 61 | Args: 62 | in_chans (int): Number of input image channels. Default: 3 63 | num_classes (int): Number of classes for classification head. Default: 1000 64 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 65 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 66 | drop_path_rate (float): Stochastic depth rate. Default: 0. 67 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 68 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 69 | """ 70 | def __init__(self, in_chans=3, num_classes=1000, 71 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 72 | layer_scale_init_value=1e-6, head_init_scale=1., 73 | ): 74 | super().__init__() 75 | 76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 77 | stem = nn.Sequential( 78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 80 | ) 81 | self.downsample_layers.append(stem) 82 | for i in range(3): 83 | downsample_layer = nn.Sequential( 84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 86 | ) 87 | self.downsample_layers.append(downsample_layer) 88 | 89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 90 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 91 | cur = 0 92 | for i in range(4): 93 | stage = nn.Sequential( 94 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 95 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 96 | ) 97 | self.stages.append(stage) 98 | cur += depths[i] 99 | 100 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 101 | self.head = nn.Linear(dims[-1], num_classes) 102 | 103 | self.apply(self._init_weights) 104 | self.head.weight.data.mul_(head_init_scale) 105 | self.head.bias.data.mul_(head_init_scale) 106 | 107 | def _init_weights(self, m): 108 | if isinstance(m, (nn.Conv2d, nn.Linear)): 109 | trunc_normal_(m.weight, std=.02) 110 | nn.init.constant_(m.bias, 0) 111 | 112 | def forward_features(self, x): 113 | for i in range(4): 114 | x = self.downsample_layers[i](x) 115 | x = self.stages[i](x) 116 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 117 | 118 | def forward(self, x): 119 | x = self.forward_features(x) 120 | x = self.head(x) 121 | return x 122 | 123 | class LayerNorm(nn.Module): 124 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 125 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 126 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 127 | with shape (batch_size, channels, height, width). 
128 | """ 129 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 130 | super().__init__() 131 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 132 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 133 | self.eps = eps 134 | self.data_format = data_format 135 | if self.data_format not in ["channels_last", "channels_first"]: 136 | raise NotImplementedError 137 | self.normalized_shape = (normalized_shape, ) 138 | 139 | def forward(self, x): 140 | if self.data_format == "channels_last": 141 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 142 | elif self.data_format == "channels_first": 143 | u = x.mean(1, keepdim=True) 144 | s = (x - u).pow(2).mean(1, keepdim=True) 145 | x = (x - u) / torch.sqrt(s + self.eps) 146 | if len(x.shape) == 4: 147 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 148 | elif len(x.shape) == 5: 149 | x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] 150 | return x 151 | 152 | 153 | class Block3D(nn.Module): 154 | r""" ConvNeXt Block. There are two equivalent implementations: 155 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 156 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 157 | We use (2) as we find it slightly faster in PyTorch 158 | 159 | Args: 160 | dim (int): Number of input channels. 161 | drop_path (float): Stochastic depth rate. Default: 0.0 162 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 163 | """ 164 | def __init__(self, dim, drop_path=0., inflate_len=3, layer_scale_init_value=1e-6): 165 | super().__init__() 166 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv 167 | self.norm = LayerNorm(dim, eps=1e-6) 168 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 169 | self.act = nn.GELU() 170 | self.pwconv2 = nn.Linear(4 * dim, dim) 171 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 172 | requires_grad=True) if layer_scale_init_value > 0 else None 173 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 174 | 175 | def forward(self, x): 176 | input = x 177 | x = self.dwconv(x) 178 | x = x.permute(0, 2, 3, 4, 1) # (N, C, H, W) -> (N, H, W, C) 179 | x = self.norm(x) 180 | x = self.pwconv1(x) 181 | x = self.act(x) 182 | x = self.pwconv2(x) 183 | if self.gamma is not None: 184 | x = self.gamma * x 185 | x = x.permute(0, 4, 1, 2, 3) # (N, H, W, C) -> (N, C, H, W) 186 | 187 | x = input + self.drop_path(x) 188 | return x 189 | 190 | class BlockV2(nn.Module): 191 | """ ConvNeXtV2 Block. 192 | 193 | Args: 194 | dim (int): Number of input channels. 195 | drop_path (float): Stochastic depth rate. Default: 0.0 196 | """ 197 | def __init__(self, dim, drop_path=0.): 198 | super().__init__() 199 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 200 | self.norm = LayerNorm(dim, eps=1e-6) 201 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 202 | self.act = nn.GELU() 203 | self.grn = GRN(4 * dim) 204 | self.pwconv2 = nn.Linear(4 * dim, dim) 205 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 206 | 207 | def forward(self, x): 208 | input = x 209 | x = self.dwconv(x) 210 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 211 | x = self.norm(x) 212 | x = self.pwconv1(x) 213 | x = self.act(x) 214 | x = self.grn(x) 215 | x = self.pwconv2(x) 216 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 217 | 218 | x = input + self.drop_path(x) 219 | return x 220 | 221 | class BlockV23D(nn.Module): 222 | """ ConvNeXtV2 Block. 223 | 224 | Args: 225 | dim (int): Number of input channels. 226 | drop_path (float): Stochastic depth rate. Default: 0.0 227 | """ 228 | def __init__(self, dim, drop_path=0., inflate_len=3,): 229 | super().__init__() 230 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv 231 | self.norm = LayerNorm(dim, eps=1e-6) 232 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 233 | self.act = nn.GELU() 234 | self.grn = GRN(4 * dim) 235 | self.pwconv2 = nn.Linear(4 * dim, dim) 236 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 237 | 238 | def forward(self, x): 239 | input = x 240 | x = self.dwconv(x) 241 | x = x.permute(0, 2, 3, 4, 1) # (N, C, H, W) -> (N, H, W, C) 242 | x = self.norm(x) 243 | x = self.pwconv1(x) 244 | x = self.act(x) 245 | x = self.grn(x) 246 | x = self.pwconv2(x) 247 | x = x.permute(0, 4, 1, 2, 3) # (N, H, W, C) -> (N, C, H, W) 248 | 249 | x = input + self.drop_path(x) 250 | return x 251 | 252 | class ConvNeXtV2(nn.Module): 253 | """ ConvNeXt V2 254 | 255 | Args: 256 | in_chans (int): Number of input image channels. Default: 3 257 | num_classes (int): Number of classes for classification head. Default: 1000 258 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 259 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 260 | drop_path_rate (float): Stochastic depth rate. Default: 0. 261 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 262 | """ 263 | def __init__(self, in_chans=3, num_classes=1000, 264 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 265 | drop_path_rate=0., head_init_scale=1. 
266 | ): 267 | super().__init__() 268 | self.depths = depths 269 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 270 | stem = nn.Sequential( 271 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 272 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 273 | ) 274 | self.downsample_layers.append(stem) 275 | for i in range(3): 276 | downsample_layer = nn.Sequential( 277 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 278 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 279 | ) 280 | self.downsample_layers.append(downsample_layer) 281 | 282 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 283 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 284 | cur = 0 285 | for i in range(4): 286 | stage = nn.Sequential( 287 | *[BlockV2(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 288 | ) 289 | self.stages.append(stage) 290 | cur += depths[i] 291 | 292 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 293 | self.head = nn.Linear(dims[-1], num_classes) 294 | 295 | self.apply(self._init_weights) 296 | self.head.weight.data.mul_(head_init_scale) 297 | self.head.bias.data.mul_(head_init_scale) 298 | 299 | def _init_weights(self, m): 300 | if isinstance(m, (nn.Conv2d, nn.Linear)): 301 | trunc_normal_(m.weight, std=.02) 302 | nn.init.constant_(m.bias, 0) 303 | 304 | def forward_features(self, x): 305 | for i in range(4): 306 | x = self.downsample_layers[i](x) 307 | x = self.stages[i](x) 308 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 309 | 310 | def forward(self, x): 311 | x = self.forward_features(x) 312 | x = self.head(x) 313 | return x 314 | 315 | def convnextv2_atto(**kwargs): 316 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 317 | return model 318 | 319 | def convnextv2_femto(**kwargs): 320 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 321 | return model 322 | 323 | def convnext_pico(**kwargs): 324 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 325 | return model 326 | 327 | def convnextv2_nano(**kwargs): 328 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) 329 | return model 330 | 331 | def convnextv2_tiny(**kwargs): 332 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 333 | return model 334 | 335 | def convnextv2_base(**kwargs): 336 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 337 | return model 338 | 339 | def convnextv2_large(**kwargs): 340 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 341 | return model 342 | 343 | def convnextv2_huge(**kwargs): 344 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) 345 | return model 346 | 347 | class ConvNeXt3D(nn.Module): 348 | r""" ConvNeXt 349 | A PyTorch impl of : `A ConvNet for the 2020s` - 350 | https://arxiv.org/pdf/2201.03545.pdf 351 | Args: 352 | in_chans (int): Number of input image channels. Default: 3 353 | num_classes (int): Number of classes for classification head. Default: 1000 354 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 355 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 356 | drop_path_rate (float): Stochastic depth rate. Default: 0. 
357 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 358 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 359 | """ 360 | def __init__(self, in_chans=3, num_classes=1000, 361 | inflate_strategy='131', 362 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 363 | layer_scale_init_value=1e-6, head_init_scale=1., 364 | ): 365 | super().__init__() 366 | 367 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 368 | stem = nn.Sequential( 369 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)), 370 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 371 | ) 372 | self.downsample_layers.append(stem) 373 | for i in range(3): 374 | downsample_layer = nn.Sequential( 375 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 376 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)), 377 | ) 378 | self.downsample_layers.append(downsample_layer) 379 | 380 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 381 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 382 | cur = 0 383 | for i in range(4): 384 | stage = nn.Sequential( 385 | *[Block3D(dim=dims[i], inflate_len=int(inflate_strategy[j%len(inflate_strategy)]), 386 | drop_path=dp_rates[cur + j], 387 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 388 | ) 389 | self.stages.append(stage) 390 | cur += depths[i] 391 | 392 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 393 | 394 | self.apply(self._init_weights) 395 | 396 | def inflate_weights(self, s_state_dict): 397 | t_state_dict = self.state_dict() 398 | from collections import OrderedDict 399 | for key in t_state_dict.keys(): 400 | if key not in s_state_dict: 401 | # print(key) 402 | continue 403 | if t_state_dict[key].shape != s_state_dict[key].shape: 404 | t = t_state_dict[key].shape[2] 405 | s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1,1,t,1,1) / t 406 | self.load_state_dict(s_state_dict, strict=False) 407 | 408 | def _init_weights(self, m): 409 | if isinstance(m, (nn.Conv3d, nn.Linear)): 410 | trunc_normal_(m.weight, std=.02) 411 | nn.init.constant_(m.bias, 0) 412 | 413 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1): 414 | if multi: 415 | xs = [] 416 | for i in range(4): 417 | x = self.downsample_layers[i](x) 418 | x = self.stages[i](x) 419 | if multi: 420 | xs.append(x) 421 | if return_spatial: 422 | if multi: 423 | shape = xs[-1].shape[2:] 424 | return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1) 425 | elif layer > -1: 426 | return xs[layer] 427 | else: 428 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) 429 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) 430 | 431 | def forward(self, x, multi=False, layer=-1): 432 | x = self.forward_features(x, True, multi=multi, layer=layer) 433 | return x 434 | 435 | 436 | class ConvNeXtV23D(nn.Module): 437 | """ ConvNeXt V2 438 | 439 | Args: 440 | in_chans (int): Number of input image channels. Default: 3 441 | num_classes (int): Number of classes for classification head. Default: 1000 442 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 443 | dims (int): Feature dimension at each stage. 
Default: [96, 192, 384, 768] 444 | drop_path_rate (float): Stochastic depth rate. Default: 0. 445 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 446 | """ 447 | def __init__(self, in_chans=3, num_classes=1000, 448 | inflate_strategy='131', 449 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 450 | drop_path_rate=0., head_init_scale=1. 451 | ): 452 | super().__init__() 453 | self.depths = depths 454 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 455 | stem = nn.Sequential( 456 | nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)), 457 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 458 | ) 459 | self.downsample_layers.append(stem) 460 | for i in range(3): 461 | downsample_layer = nn.Sequential( 462 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 463 | nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)), 464 | ) 465 | self.downsample_layers.append(downsample_layer) 466 | 467 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 468 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 469 | cur = 0 470 | for i in range(4): 471 | stage = nn.Sequential( 472 | *[BlockV23D(dim=dims[i], drop_path=dp_rates[cur + j], 473 | inflate_len=int(inflate_strategy[j%len(inflate_strategy)]), 474 | ) for j in range(depths[i])] 475 | ) 476 | self.stages.append(stage) 477 | cur += depths[i] 478 | 479 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 480 | self.head = nn.Linear(dims[-1], num_classes) 481 | 482 | self.apply(self._init_weights) 483 | self.head.weight.data.mul_(head_init_scale) 484 | self.head.bias.data.mul_(head_init_scale) 485 | 486 | def inflate_weights(self, pretrained_path): 487 | t_state_dict = self.state_dict() 488 | s_state_dict = torch.load(pretrained_path)["model"] 489 | from collections import OrderedDict 490 | for key in t_state_dict.keys(): 491 | if key not in s_state_dict: 492 | # print(key) 493 | continue 494 | if t_state_dict[key].shape != s_state_dict[key].shape: 495 | # print(t_state_dict[key].shape, s_state_dict[key].shape) 496 | t = t_state_dict[key].shape[2] 497 | s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1,1,t,1,1) / t 498 | self.load_state_dict(s_state_dict, strict=False) 499 | 500 | def _init_weights(self, m): 501 | if isinstance(m, (nn.Conv3d, nn.Linear)): 502 | trunc_normal_(m.weight, std=.02) 503 | nn.init.constant_(m.bias, 0) 504 | 505 | def forward_features(self, x, return_spatial=False, multi=False, layer=-1): 506 | if multi: 507 | xs = [] 508 | for i in range(4): 509 | x = self.downsample_layers[i](x) 510 | x = self.stages[i](x) 511 | if multi: 512 | xs.append(x) 513 | if return_spatial: 514 | if multi: 515 | shape = xs[-1].shape[2:] 516 | return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1) 517 | elif layer > -1: 518 | return xs[layer] 519 | else: 520 | return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) 521 | return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) 522 | 523 | def forward(self, x, multi=False, layer=-1): 524 | x = self.forward_features(x, True, multi=multi, layer=layer) 525 | return x 526 | 527 | 528 | model_urls = { 529 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 530 | "convnext_small_1k": 
"https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 531 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 532 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 533 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 534 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 535 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 536 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 537 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 538 | } 539 | 540 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 541 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 542 | if pretrained: 543 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 544 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 545 | model.load_state_dict(checkpoint["model"]) 546 | return model 547 | 548 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 549 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 550 | if pretrained: 551 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 552 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 553 | model.load_state_dict(checkpoint["model"]) 554 | return model 555 | 556 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 557 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 558 | if pretrained: 559 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 560 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 561 | model.load_state_dict(checkpoint["model"]) 562 | return model 563 | 564 | 565 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 566 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 567 | if pretrained: 568 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 569 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 570 | model.load_state_dict(checkpoint["model"]) 571 | return model 572 | 573 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 574 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 575 | if pretrained: 576 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 577 | url = model_urls['convnext_xlarge_22k'] 578 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 579 | model.load_state_dict(checkpoint["model"]) 580 | 581 | return model 582 | 583 | def convnext_3d_tiny(pretrained=False, in_22k=False, **kwargs): 584 | # print("Using Imagenet 22K pretrain", in_22k) 585 | model = ConvNeXt3D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 586 | if pretrained: 587 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 588 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 589 | model.inflate_weights(checkpoint["model"]) 590 | return model 591 | 592 | def convnext_3d_small(pretrained=False, in_22k=False, **kwargs): 593 | model = ConvNeXt3D(depths=[3, 3, 
592 | def convnext_3d_small(pretrained=False, in_22k=False, **kwargs):
593 |     model = ConvNeXt3D(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
594 |     if pretrained:
595 |         url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
596 |         checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
597 |         model.inflate_weights(checkpoint["model"])
598 | 
599 |     return model
600 | 
601 | def convnextv2_3d_atto(**kwargs):
602 |     model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
603 | 
604 |     return model
605 | 
606 | def convnextv2_3d_femto(pretrained="../pretrained/convnextv2_femto_1k_224_ema.pt", **kwargs):
607 |     model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
608 |     #model.inflate_weights(pretrained)
609 |     return model
610 | 
611 | def convnextv2_3d_pico(pretrained="../pretrained/convnextv2_pico_1k_224_ema.pt", **kwargs):
612 |     model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
613 |     #model.inflate_weights(pretrained)
614 |     return model
615 | 
616 | def convnextv2_3d_nano(pretrained="../pretrained/convnextv2_nano_1k_224_ema.pt", **kwargs):
617 |     model = ConvNeXtV23D(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
618 |     #model.inflate_weights(pretrained)
619 |     return model
620 | 
621 | def convnextv2_tiny(**kwargs):
622 |     model = ConvNeXtV23D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
623 |     return model
624 | 
625 | def convnextv2_base(**kwargs):
626 |     model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
627 |     return model
628 | 
629 | def convnextv2_large(**kwargs):
630 |     model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
631 |     return model
632 | 
633 | def convnextv2_huge(**kwargs):
634 |     model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
635 |     return model
--------------------------------------------------------------------------------