├── dataset ├── __init__.py └── lavdf.py ├── model ├── __init__.py ├── frame_classifier.py ├── fusion_module.py ├── boundary_module.py ├── audio_encoder.py ├── video_encoder.py ├── batfd.py ├── boundary_module_plus.py └── batfd_plus.py ├── requirements.txt ├── config ├── batfd_default.toml └── batfd_plus_default.toml ├── CITATION.cff ├── TERMS_AND_CONDITIONS.md ├── post_process.py ├── loss.py ├── evaluate.py ├── train.py ├── metrics.py ├── inference.py ├── README.md ├── utils.py ├── .gitignore └── LICENSE /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .lavdf import Lavdf, LavdfDataModule 2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .batfd import Batfd 2 | from .batfd_plus import BatfdPlus 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.0 2 | numpy>=1.20.1 3 | torchaudio>=0.13.0 4 | torchvision>=0.14.0 5 | tqdm>=4.41.1 6 | av>=8.0.3 7 | einops>=0.3.0 8 | pytorch_lightning==1.7.* 9 | torchmetrics==0.7.* 10 | scipy>=1.7.3 11 | 12 | pandas>=1.2.0 13 | toml>=0.10.0 14 | -------------------------------------------------------------------------------- /model/frame_classifier.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import Module 3 | 4 | from utils import Conv1d 5 | 6 | 7 | class FrameLogisticRegression(Module): 8 | """ 9 | Frame classifier (FC_v and FC_a) for video feature (F_v) and audio feature (F_a). 10 | Input: 11 | F_v or F_a: (B, C_f, T) 12 | Output: 13 | Y^: (B, 1, T) 14 | """ 15 | 16 | def __init__(self, n_features: int): 17 | super().__init__() 18 | self.lr_layer = Conv1d(n_features, 1, kernel_size=1) 19 | 20 | def forward(self, features: Tensor) -> Tensor: 21 | return self.lr_layer(features) 22 | -------------------------------------------------------------------------------- /config/batfd_default.toml: -------------------------------------------------------------------------------- 1 | name = "batfd_default" 2 | num_frames = 512 # T 3 | max_duration = 40 # D 4 | model_type = "batfd" 5 | dataset = "lavdf" 6 | 7 | [model.video_encoder] 8 | type = "c3d" 9 | hidden_dims = [64, 96, 128, 128] 10 | cla_feature_in = 256 # C_f 11 | 12 | [model.audio_encoder] 13 | type = "cnn" 14 | hidden_dims = [32, 64, 64] 15 | cla_feature_in = 256 # C_f 16 | 17 | [model.frame_classifier] 18 | type = "lr" 19 | 20 | [model.boundary_module] 21 | hidden_dims = [512, 128] 22 | samples = 10 # N 23 | 24 | [optimizer] 25 | learning_rate = 0.00001 26 | frame_loss_weight = 2.0 27 | modal_bm_loss_weight = 1.0 28 | contrastive_loss_weight = 0.1 29 | contrastive_loss_margin = 0.99 30 | weight_decay = 0.0001 31 | 32 | [soft_nms] 33 | alpha = 0.7234 34 | t1 = 0.1968 35 | t2 = 0.4123 36 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find this work useful in your research, please cite it." 
3 | preferred-citation: 4 | type: article 5 | title: "Glitch in the Matrix: A Large Scale Benchmark for Content Driven Audio-Visual Forgery Detection and Localization" 6 | authors: 7 | - family-names: "Cai" 8 | given-names: "Zhixi" 9 | - family-names: "Ghosh" 10 | given-names: "Shreya" 11 | - family-names: "Dhall" 12 | given-names: "Abhinav" 13 | - family-names: "Gedeon" 14 | given-names: "Tom" 15 | - family-names: "Stefanov" 16 | given-names: "Kalin" 17 | - family-names: "Hayat" 18 | given-names: "Munawar" 19 | journal: "Computer Vision and Image Understanding" 20 | year: 2023 21 | volume: 236 22 | start: 103818 23 | doi: "10.1016/j.cviu.2023.103818" 24 | -------------------------------------------------------------------------------- /config/batfd_plus_default.toml: -------------------------------------------------------------------------------- 1 | name = "batfd_plus_default" 2 | num_frames = 512 # T 3 | max_duration = 40 # D 4 | model_type = "batfd_plus" 5 | dataset = "lavdf" 6 | 7 | [model.video_encoder] 8 | type = "mvit_b" 9 | hidden_dims = [] # handled by model type 10 | cla_feature_in = 256 # C_f 11 | 12 | [model.audio_encoder] 13 | type = "vit_b" 14 | hidden_dims = [] # handled by model type 15 | cla_feature_in = 256 # C_f 16 | 17 | [model.frame_classifier] 18 | type = "lr" 19 | 20 | [model.boundary_module] 21 | hidden_dims = [512, 128] 22 | samples = 10 # N 23 | 24 | [optimizer] 25 | learning_rate = 0.00001 26 | frame_loss_weight = 2.0 27 | modal_bm_loss_weight = 1.0 28 | cbg_feature_weight = 0.0 29 | prb_weight_forward = 1.0 30 | contrastive_loss_weight = 0.1 31 | contrastive_loss_margin = 0.99 32 | weight_decay = 0.0001 33 | 34 | [soft_nms] 35 | alpha = 0.7234 36 | t1 = 0.1968 37 | t2 = 0.4123 38 | -------------------------------------------------------------------------------- /TERMS_AND_CONDITIONS.md: -------------------------------------------------------------------------------- 1 | ## Terms and Conditions of LAV-DF 2 | 3 | The users should agree to the terms and conditions to use the LAV-DF dataset. In exchange for such permission, the users hereby agree to the following terms and conditions: 4 | 5 | - The dataset can only be used for non-commercial research and educational purposes. 6 | - You understand that the LAV-DF dataset is a deepfake dataset generated based on Voxceleb2. You also agree to all agreements of the VoxCeleb2 dataset. 7 | - The authors of the dataset make no representations or warranties regarding the dataset, including but not limited to warranties of non-infringement or fitness for a particular purpose. 8 | - You accept full responsibility for your use of the dataset and shall defend and indemnify the Authors of LAV-DF, against any and all claims arising from your use of the dataset, including but not limited to your use of any copies of copyrighted images that you may create from the dataset. 9 | -------------------------------------------------------------------------------- /model/fusion_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Sigmoid, Module 4 | 5 | from utils import Conv1d 6 | 7 | 8 | class ModalFeatureAttnBoundaryMapFusion(Module): 9 | """ 10 | Fusion module for video and audio boundary maps. 
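    The two ModalMapAttnBlock branches below produce attention weights from both modalities'
    features and boundary maps; the fused map is the normalized weighted sum
    M^ = (w_v * M_v^ + w_a * M_a^) / (w_v + w_a), computed element-wise in `forward`.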
11 | 12 | Input: 13 | F_v: (B, C_f, T) 14 | F_a: (B, C_f, T) 15 | M_v^: (B, D, T) 16 | M_a^: (B, D, T) 17 | 18 | Output: 19 | M^: (B, D, T) 20 | """ 21 | 22 | def __init__(self, n_video_features: int = 257, n_audio_features: int = 257, max_duration: int = 40): 23 | super().__init__() 24 | 25 | self.a_attn_block = ModalMapAttnBlock(n_audio_features, n_video_features, max_duration) 26 | self.v_attn_block = ModalMapAttnBlock(n_video_features, n_audio_features, max_duration) 27 | 28 | def forward(self, video_feature: Tensor, audio_feature: Tensor, video_bm: Tensor, audio_bm: Tensor) -> Tensor: 29 | a_attn = self.a_attn_block(audio_bm, audio_feature, video_feature) 30 | v_attn = self.v_attn_block(video_bm, video_feature, audio_feature) 31 | 32 | sum_attn = a_attn + v_attn 33 | 34 | a_w = a_attn / sum_attn 35 | v_w = v_attn / sum_attn 36 | 37 | fusion_bm = video_bm * v_w + audio_bm * a_w 38 | return fusion_bm 39 | 40 | 41 | class ModalMapAttnBlock(Module): 42 | 43 | def __init__(self, n_self_features: int, n_another_features: int, max_duration: int = 40): 44 | super().__init__() 45 | self.attn_from_self_features = Conv1d(n_self_features, max_duration, kernel_size=1) 46 | self.attn_from_another_features = Conv1d(n_another_features, max_duration, kernel_size=1) 47 | self.attn_from_bm = Conv1d(max_duration, max_duration, kernel_size=1) 48 | self.sigmoid = Sigmoid() 49 | 50 | def forward(self, self_bm: Tensor, self_features: Tensor, another_features: Tensor) -> Tensor: 51 | w_bm = self.attn_from_bm(self_bm) 52 | w_self_feat = self.attn_from_self_features(self_features) 53 | w_another_feat = self.attn_from_another_features(another_features) 54 | w_stack = torch.stack((w_bm, w_self_feat, w_another_feat), dim=3) 55 | w = w_stack.mean(dim=3) 56 | return self.sigmoid(w) 57 | 58 | 59 | class ModalFeatureAttnCfgFusion(ModalFeatureAttnBoundaryMapFusion): 60 | 61 | def __init__(self, n_video_features: int = 257, n_audio_features: int = 257): 62 | super().__init__() 63 | self.a_attn_block = ModalCbgAttnBlock(n_audio_features, n_video_features) 64 | self.v_attn_block = ModalCbgAttnBlock(n_video_features, n_audio_features) 65 | 66 | def forward(self, video_feature: Tensor, audio_feature: Tensor, video_cfg: Tensor, audio_cfg: Tensor) -> Tensor: 67 | video_cfg = video_cfg.unsqueeze(1) 68 | audio_cfg = audio_cfg.unsqueeze(1) 69 | fusion_cfg = super().forward(video_feature, audio_feature, video_cfg, audio_cfg) 70 | return fusion_cfg.squeeze(1) 71 | 72 | 73 | class ModalCbgAttnBlock(ModalMapAttnBlock): 74 | 75 | def __init__(self, n_self_features: int, n_another_features: int): 76 | super().__init__(n_self_features, n_another_features, 1) 77 | -------------------------------------------------------------------------------- /post_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | from concurrent.futures import ProcessPoolExecutor 4 | from os import cpu_count 5 | from typing import List 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from tqdm.auto import tqdm 10 | 11 | from dataset.lavdf import Metadata 12 | from utils import iou_with_anchors 13 | 14 | 15 | def soft_nms(df, alpha, t1, t2, fps): 16 | df = df.sort_values(by="score", ascending=False) 17 | t_start = list(df.begin.values[:] / fps) 18 | t_end = list(df.end.values[:] / fps) 19 | t_score = list(df.score.values[:]) 20 | 21 | r_start = [] 22 | r_end = [] 23 | r_score = [] 24 | 25 | while len(t_score) > 1 and len(r_score) < 101: 26 | max_index = t_score.index(max(t_score)) 
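        # Gaussian soft-NMS: proposals overlapping the current best proposal are not discarded,
        # but their scores decay by exp(-iou^2 / alpha) once the IoU exceeds the threshold
        # t1 + (t2 - t1) * width, where width is the duration (in seconds) of the best proposal.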
27 | tmp_iou_list = iou_with_anchors( 28 | np.array(t_start), 29 | np.array(t_end), t_start[max_index], t_end[max_index]) 30 | for idx in range(0, len(t_score)): 31 | if idx != max_index: 32 | tmp_iou = tmp_iou_list[idx] 33 | tmp_width = t_end[max_index] - t_start[max_index] 34 | if tmp_iou > t1 + (t2 - t1) * tmp_width: 35 | t_score[idx] *= np.exp(-np.square(tmp_iou) / alpha) 36 | 37 | r_start.append(t_start[max_index]) 38 | r_end.append(t_end[max_index]) 39 | r_score.append(t_score[max_index]) 40 | t_start.pop(max_index) 41 | t_end.pop(max_index) 42 | t_score.pop(max_index) 43 | 44 | new_df = pd.DataFrame() 45 | new_df['score'] = r_score 46 | new_df['begin'] = r_start 47 | new_df['end'] = r_end 48 | 49 | new_df['begin'] *= fps 50 | new_df['end'] *= fps 51 | return new_df 52 | 53 | 54 | def video_post_process(meta, model_name, fps=25, alpha=0.4, t1=0.2, t2=0.9, dataset_name="lavdf"): 55 | file = resolve_csv_file_name(meta, dataset_name) 56 | df = pd.read_csv(os.path.join("output", "results", model_name, file)) 57 | 58 | if len(df) > 1: 59 | df = soft_nms(df, alpha, t1, t2, fps) 60 | 61 | df = df.sort_values(by="score", ascending=False) 62 | 63 | proposal_list = [] 64 | 65 | for j in range(len(df)): 66 | proposal_list.append([ 67 | df.score.values[j], 68 | df.begin.values[j].item(), 69 | df.end.values[j].item() 70 | ]) 71 | 72 | return [meta.file, proposal_list] 73 | 74 | 75 | def resolve_csv_file_name(meta: Metadata, dataset_name: str = "lavdf") -> str: 76 | if dataset_name == "lavdf": 77 | return meta.file.split("/")[-1].replace(".mp4", ".csv") 78 | else: 79 | raise NotImplementedError 80 | 81 | 82 | def post_process(model_name: str, metadata: List[Metadata], fps=25, 83 | alpha=0.4, t1=0.2, t2=0.9, dataset_name="lavdf" 84 | ): 85 | with ProcessPoolExecutor(cpu_count() // 2 - 1) as executor: 86 | futures = [] 87 | for meta in metadata: 88 | futures.append(executor.submit(video_post_process, meta, model_name, fps, 89 | alpha, t1, t2, dataset_name 90 | )) 91 | 92 | results = dict(map(lambda x: x.result(), tqdm(futures))) 93 | 94 | with open(os.path.join("output", "results", f"{model_name}.json"), "w") as f: 95 | json.dump(results, f, indent=4) 96 | -------------------------------------------------------------------------------- /model/boundary_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops.layers.torch import Rearrange 4 | from torch import Tensor 5 | from torch.nn import Sequential, LeakyReLU, Sigmoid, Module 6 | 7 | from utils import Conv3d, Conv2d 8 | 9 | 10 | class BoundaryModule(Module): 11 | """ 12 | Boundary matching module for video or audio features. 
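    Each entry M^[d, t] of the output boundary map is the confidence of a candidate forged
    segment starting at frame t with duration index d (one of the D predefined durations).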
13 |     Input:
14 |         F_v or F_a: (B, C_f, T)
15 |     Output:
16 |         M_v^ or M_a^: (B, D, T)
17 | 
18 |     """
19 | 
20 |     def __init__(self, n_feature_in, n_features=(512, 128), num_samples: int = 10, temporal_dim: int = 512,
21 |                  max_duration: int = 40
22 |                  ):
23 |         super().__init__()
24 | 
25 |         dim0, dim1 = n_features
26 | 
27 |         # (B, n_feature_in, temporal_dim) -> (B, n_feature_in, sample, max_duration, temporal_dim)
28 |         self.bm_layer = BMLayer(temporal_dim, num_samples, max_duration)
29 | 
30 |         # (B, n_feature_in, sample, max_duration, temporal_dim) -> (B, dim0, max_duration, temporal_dim)
31 |         self.block0 = Sequential(
32 |             Conv3d(n_feature_in, dim0, kernel_size=(num_samples, 1, 1), stride=(num_samples, 1, 1),
33 |                 build_activation=LeakyReLU
34 |             ),
35 |             Rearrange("b c n d t -> b c (n d) t")
36 |         )
37 | 
38 |         # (B, dim0, max_duration, temporal_dim) -> (B, max_duration, temporal_dim)
39 |         self.block1 = Sequential(
40 |             Conv2d(dim0, dim1, kernel_size=1, build_activation=LeakyReLU),
41 |             Conv2d(dim1, dim1, kernel_size=3, padding=1, build_activation=LeakyReLU),
42 |             Conv2d(dim1, 1, kernel_size=1, build_activation=Sigmoid),
43 |             Rearrange("b c d t -> b (c d) t")
44 |         )
45 | 
46 |     def forward(self, feature: Tensor) -> Tensor:
47 |         feature = self.bm_layer(feature)
48 |         feature = self.block0(feature)
49 |         feature = self.block1(feature)
50 |         return feature
51 | 
52 | 
53 | class BMLayer(Module):
54 |     """BM Layer"""
55 | 
56 |     def __init__(self, temporal_dim: int, num_sample: int, max_duration: int, roi_expand_ratio: float = 0.5):
57 |         super().__init__()
58 |         self.temporal_dim = temporal_dim
59 |         # self.feat_dim = opt['bmn_feat_dim']
60 |         self.num_sample = num_sample
61 |         self.duration = max_duration
62 |         self.roi_expand_ratio = roi_expand_ratio
63 |         self.smp_weight = self.get_pem_smp_weight()
64 | 
65 |     def get_pem_smp_weight(self):
66 |         T = self.temporal_dim
67 |         N = self.num_sample
68 |         D = self.duration
69 |         w = torch.zeros([T, N, D, T])  # T * N * D * T
70 |         # In each temporal location i, there are D predefined proposals,
71 |         # with length ranging between 1 and D
72 |         # the j-th proposal is [i, i+j+1], 0<=j<D, 0<=i<T
73 |         for i in range(T):
74 |             for j in range(D):
75 |                 # expand the proposal [i, i + j + 1] on both sides by roi_expand_ratio,
76 |                 # then take N uniformly spaced sample points from the expanded region
77 |                 plen = float(j + 1)
78 |                 xmin = i - plen * self.roi_expand_ratio
79 |                 xmax = i + plen + plen * self.roi_expand_ratio
80 |                 sample_step = (xmax - xmin) / (N - 1)
81 |                 for k in range(N):
82 |                     xp = xmin + k * sample_step
83 |                     # each sample position is linearly interpolated between its two
84 |                     # neighbouring integer frame locations (see left/right weights below);
85 |                     # samples that fall outside [0, T - 1] are skipped
86 |                     if xp < 0 or xp > T - 1:
87 |                         continue
88 |                     left, right = int(np.floor(xp)), int(np.ceil(xp))
89 |                     left_weight = 1 - (xp - left)
90 |                     right_weight = 1 - (right - xp)
91 |                     w[left, k, j, i] += left_weight
92 |                     w[right, k, j, i] += right_weight
93 |         return w.view(T, -1).float()
94 | 
95 |     def _apply(self, fn):
96 |         self.smp_weight = fn(self.smp_weight)
97 | 
98 |     def forward(self, X):
99 |         input_size = X.size()
100 |         assert (input_size[-1] == self.temporal_dim)
101 |         # assert(len(input_size) == 3 and
102 |         X_view = X.view(-1, input_size[-1])
103 |         # feature [bs*C, T]
104 |         # smp_w [T, N*D*T]
105 |         # out [bs*C, N*D*T] --> [bs, C, N, D, T]
106 |         result = torch.matmul(X_view, self.smp_weight)
107 |         return result.view(-1, input_size[1], self.num_sample, self.duration, self.temporal_dim)
108 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Module, MSELoss
4 | 
5 | 
6 | class MaskedBMLoss(Module):
7 | 
8 |     def __init__(self, loss_fn: Module):
9 |         super().__init__()
10 |         self.loss_fn = loss_fn
11 | 
12 |     def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor):
13 |         loss = []
14 |         for i, frame in enumerate(n_frames):
15 |             loss.append(self.loss_fn(pred[i, :, :frame], true[i, :, :frame]))
16 |         return torch.mean(torch.stack(loss))
17 | 
18 | 
19 | 
class MaskedFrameLoss(Module): 20 | 21 | def __init__(self, loss_fn: Module): 22 | super().__init__() 23 | self.loss_fn = loss_fn 24 | 25 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor): 26 | # input: (B, T) 27 | loss = [] 28 | for i, frame in enumerate(n_frames): 29 | loss.append(self.loss_fn(pred[i, :frame], true[i, :frame])) 30 | return torch.mean(torch.stack(loss)) 31 | 32 | 33 | class MaskedContrastLoss(Module): 34 | 35 | def __init__(self, margin: float = 0.99): 36 | super().__init__() 37 | self.margin = margin 38 | 39 | def forward(self, pred1: Tensor, pred2: Tensor, labels: Tensor, n_frames: Tensor): 40 | # input: (B, C, T) 41 | loss = [] 42 | for i, frame in enumerate(n_frames): 43 | # mean L2 distance squared 44 | d = torch.dist(pred1[i, :, :frame], pred2[i, :, :frame], 2) 45 | if labels[i]: 46 | # if is positive pair, minimize distance 47 | loss.append(d ** 2) 48 | else: 49 | # if is negative pair, minimize (margin - distance) if distance < margin 50 | loss.append(torch.clip(self.margin - d, min=0.) ** 2) 51 | return torch.mean(torch.stack(loss)) 52 | 53 | 54 | class MaskedMSE(Module): 55 | 56 | def __init__(self): 57 | super().__init__() 58 | self.loss_fn = MSELoss() 59 | 60 | def forward(self, pred: Tensor, true: Tensor, n_frames: Tensor): 61 | loss = [] 62 | for i, frame in enumerate(n_frames): 63 | loss.append(self.loss_fn(pred[i, :frame], true[i, :frame])) 64 | return torch.mean(torch.stack(loss)) 65 | 66 | 67 | class MaskedBsnppLoss(Module): 68 | """Simplified version of BSN++ loss function.""" 69 | 70 | def __init__(self, cbg_feature_weight=0.01, prb_weight_forward=1): 71 | super().__init__() 72 | self.cbg_feature_weight = cbg_feature_weight 73 | self.prb_weight_forward = prb_weight_forward 74 | 75 | self.cbg_loss_func = MaskedMSE() 76 | self.cbg_feature_loss = MaskedBMLoss(MSELoss()) 77 | self.bsnpp_pem_reg_loss_func = self.cbg_feature_loss 78 | 79 | def forward(self, pred_bm_p, pred_bm_c, pred_bm_p_c, pred_start, pred_end, 80 | pred_start_backward, pred_end_backward, gt_iou_map, gt_start, gt_end, n_frames, 81 | feature_forward=None, feature_backward=None 82 | ): 83 | if self.cbg_feature_weight > 0: 84 | cbg_loss_forward = self.cbg_loss_func(pred_start, gt_start, n_frames) + \ 85 | self.cbg_loss_func(pred_end, gt_end, n_frames) 86 | cbg_loss_backward = self.cbg_loss_func(torch.flip(pred_end_backward, dims=(1,)), gt_start, n_frames) + \ 87 | self.cbg_loss_func(torch.flip(pred_start_backward, dims=(1,)), gt_end, n_frames) 88 | 89 | cbg_loss = cbg_loss_forward + cbg_loss_backward 90 | if feature_forward is not None and feature_backward is not None: 91 | inter_feature_loss = self.cbg_feature_weight * self.cbg_feature_loss(feature_forward, 92 | torch.flip(feature_backward, dims=(2,)), n_frames) 93 | cbg_loss += inter_feature_loss 94 | else: 95 | inter_feature_loss = None 96 | else: 97 | cbg_loss = None 98 | cbg_loss_forward = None 99 | cbg_loss_backward = None 100 | inter_feature_loss = None 101 | 102 | prb_reg_loss_p = self.bsnpp_pem_reg_loss_func(pred_bm_p, gt_iou_map, n_frames) 103 | prb_reg_loss_c = self.bsnpp_pem_reg_loss_func(pred_bm_c, gt_iou_map, n_frames) 104 | prb_reg_loss_p_c = self.bsnpp_pem_reg_loss_func(pred_bm_p_c, gt_iou_map, n_frames) 105 | prb_loss = prb_reg_loss_p + prb_reg_loss_c + prb_reg_loss_p_c 106 | 107 | loss = cbg_loss + prb_loss if cbg_loss is not None else prb_loss 108 | return loss, cbg_loss, prb_loss, cbg_loss_forward, cbg_loss_backward, inter_feature_loss 109 | 
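
# Minimal usage sketch for the masked losses above: only the first `n_frames[i]` valid frames
# of each padded sequence are scored (T = 512 and D = 40 in the default config); the tensors
# below are random placeholders for illustration only.
if __name__ == "__main__":
    from torch.nn import BCELoss

    frame_loss = MaskedFrameLoss(BCELoss())
    bm_loss = MaskedBMLoss(MSELoss())

    pred_frames = torch.rand(2, 512)                     # (B, T) frame-level scores
    true_frames = torch.randint(0, 2, (2, 512)).float()  # (B, T) frame-level labels
    pred_map = torch.rand(2, 40, 512)                    # (B, D, T) boundary map
    true_map = torch.rand(2, 40, 512)
    n_frames = torch.tensor([300, 512])                  # valid lengths before padding

    print(frame_loss(pred_frames, true_frames, n_frames))
    print(bm_loss(pred_map, true_map, n_frames))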
-------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import toml 5 | import torch 6 | 7 | from dataset.lavdf import LavdfDataModule 8 | from inference import inference_batfd 9 | from metrics import AP, AR 10 | from model import Batfd, BatfdPlus 11 | from post_process import post_process 12 | from utils import generate_metadata_min, read_json 13 | 14 | parser = argparse.ArgumentParser(description="BATFD evaluation") 15 | parser.add_argument("--config", type=str) 16 | parser.add_argument("--data_root", type=str) 17 | parser.add_argument("--checkpoint", type=str) 18 | parser.add_argument("--batch_size", type=int, default=4) 19 | parser.add_argument("--num_workers", type=int, default=8) 20 | parser.add_argument("--modalities", type=str, nargs="+", default=["fusion"]) 21 | parser.add_argument("--subset", type=str, nargs="+", default=["full"]) 22 | parser.add_argument("--gpus", type=int, default=1) 23 | 24 | 25 | def visual_subset_condition(meta): 26 | return not (meta.modify_video is False and meta.modify_audio is True) 27 | 28 | 29 | def audio_subset_condition(meta): 30 | return not (meta.modify_video is True and meta.modify_audio is False) 31 | 32 | 33 | conditions = { 34 | "full": None, 35 | "subset_for_visual_only": visual_subset_condition, 36 | "subset_for_audio_only": audio_subset_condition 37 | } 38 | 39 | 40 | def evaluate_lavdf(config, args): 41 | for modal in args.modalities: 42 | assert modal in ["fusion", "audio", "visual"] 43 | 44 | for subset in args.subset: 45 | assert subset in ["full", "subset_for_visual_only", "subset_for_audio_only"] 46 | 47 | model_name = config["name"] 48 | alpha = config["soft_nms"]["alpha"] 49 | t1 = config["soft_nms"]["t1"] 50 | t2 = config["soft_nms"]["t2"] 51 | 52 | model_type = config["model_type"] 53 | v_feature = None 54 | a_feature = None 55 | 56 | # prepare model 57 | if config["model_type"] == "batfd_plus": 58 | model = BatfdPlus.load_from_checkpoint(args.checkpoint) 59 | require_match_scores = True 60 | get_meta_attr = BatfdPlus.get_meta_attr 61 | elif config["model_type"] == "batfd": 62 | model = Batfd.load_from_checkpoint(args.checkpoint) 63 | require_match_scores = False 64 | get_meta_attr = Batfd.get_meta_attr 65 | else: 66 | raise ValueError("Invalid model type") 67 | 68 | # prepare dataset 69 | dm = LavdfDataModule( 70 | root=args.data_root, 71 | frame_padding=config["num_frames"], 72 | require_match_scores=require_match_scores, 73 | feature_types=(v_feature, a_feature), 74 | max_duration=config["max_duration"], 75 | batch_size=args.batch_size, num_workers=args.num_workers, 76 | get_meta_attr=get_meta_attr, 77 | return_file_name=True 78 | ) 79 | dm.setup() 80 | 81 | # inference and save dense proposals as csv file 82 | inference_batfd(model_name, model, dm, config["max_duration"], model_type, args.modalities, args.gpus) 83 | 84 | # postprocess by soft-nms 85 | for modality in args.modalities: 86 | proposal_file_name = f"{model_name}{'' if modality == 'fusion' else '_' + modality[0]}" 87 | post_process(proposal_file_name, dm.test_dataset.metadata, 25, alpha, t1, t2) 88 | 89 | for modality in args.modalities: 90 | proposal_file_name = f"{model_name}{'' if modality == 'fusion' else '_' + modality[0]}" 91 | proposals = read_json(f"output/results/{proposal_file_name}.json") 92 | 93 | for subset_name in args.subset: 94 | 95 | dm_subset = LavdfDataModule( 96 | root=args.data_root, 97 
|                 frame_padding=config["num_frames"],
98 |                 require_match_scores=require_match_scores,
99 |                 max_duration=config["max_duration"],
100 |                 batch_size=1, num_workers=3,
101 |                 get_meta_attr=get_meta_attr,
102 |                 cond=conditions[subset_name]
103 |             )
104 |             dm_subset.setup()
105 | 
106 |             metadata = dm_subset.test_dataset.metadata
107 |             # evaluate AP
108 |             iou_thresholds = [0.5, 0.75, 0.95]
109 |             print("--------------------------------------------------")
110 |             ap_score = AP(iou_thresholds=iou_thresholds)(metadata, proposals)
111 |             for iou_threshold in iou_thresholds:
112 |                 print(f"AP@{iou_threshold} Score for {modality} modality in {subset_name} set: "
113 |                       f"{ap_score[iou_threshold]}")
114 |             print("--------------------------------------------------")
115 | 
116 |             # evaluate AR
117 |             iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
118 |             n_proposals_list = [100, 50, 20, 10]
119 | 
120 |             ar_score = AR(n_proposals_list, iou_thresholds=iou_thresholds)(metadata, proposals)
121 | 
122 |             for n_proposals in n_proposals_list:
123 |                 print(f"AR@{n_proposals} Score for {modality} modality in {subset_name} set: "
124 |                       f"{ar_score[n_proposals]}")
125 |             print("--------------------------------------------------")
126 | 
127 | if __name__ == '__main__':
128 |     args = parser.parse_args()
129 | 
130 |     if not os.path.exists(os.path.join(args.data_root, "metadata.min.json")):
131 |         generate_metadata_min(args.data_root)
132 | 
133 |     config = toml.load(args.config)
134 |     torch.backends.cudnn.benchmark = True
135 |     if config["dataset"] == "lavdf":
136 |         evaluate_lavdf(config, args)
137 |     else:
138 |         raise NotImplementedError
139 | 
--------------------------------------------------------------------------------
/model/audio_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 | 
3 | from einops import rearrange
4 | from einops.layers.torch import Rearrange
5 | from torch import Tensor
6 | from torch.nn import Module, Sequential, LeakyReLU, MaxPool2d, Linear
7 | from torchvision.models.vision_transformer import Encoder as ViTEncoder
8 | 
9 | from utils import Conv2d
10 | 
11 | 
12 | class CNNAudioEncoder(Module):
13 |     """
14 |     Audio encoder (E_a): Process log mel spectrogram to extract features.
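    With the default settings, a (B, 64, 2048) log mel spectrogram (64 mel bins, 2048 audio
    frames) is reduced to a (B, 256, 512) feature map aligned with the T = 512 video frames,
    as traced by the shape comments on the blocks below.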
15 | Input: 16 | A': (B, F_m, T_a) 17 | Output: 18 | E_a: (B, C_f, T) 19 | """ 20 | 21 | def __init__(self, n_features=(32, 64, 64)): 22 | super().__init__() 23 | 24 | n_dim0, n_dim1, n_dim2 = n_features 25 | 26 | # (B, 64, 2048) -> (B, 1, 64, 2048) -> (B, 32, 32, 1024) 27 | self.block0 = Sequential( 28 | Rearrange("b c t -> b 1 c t"), 29 | Conv2d(1, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 30 | MaxPool2d(2) 31 | ) 32 | 33 | # (B, 32, 32, 1024) -> (B, 64, 16, 512) 34 | self.block1 = Sequential( 35 | Conv2d(n_dim0, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 36 | Conv2d(n_dim1, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 37 | MaxPool2d(2) 38 | ) 39 | 40 | # (B, 64, 16, 512) -> (B, 64, 4, 512) -> (B, 256, 512) 41 | self.block2 = Sequential( 42 | Conv2d(n_dim1, n_dim2, kernel_size=(2, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU), 43 | MaxPool2d((2, 1)), 44 | Conv2d(n_dim2, n_dim2, kernel_size=(3, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU), 45 | MaxPool2d((2, 1)), 46 | Rearrange("b f c t -> b (f c) t") 47 | ) 48 | 49 | def forward(self, audio: Tensor) -> Tensor: 50 | x = self.block0(audio) 51 | x = self.block1(x) 52 | x = self.block2(x) 53 | return x 54 | 55 | 56 | class SelfAttentionAudioEncoder(Module): 57 | 58 | def __init__(self, block_type: Literal["vit_t", "vit_s", "vit_b"], a_cla_feature_in: int = 256, temporal_size: int = 512): 59 | super().__init__() 60 | # The ViT configurations are from: 61 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 62 | if block_type == "vit_t": 63 | self.n_features = 192 64 | self.block = ViTEncoder( 65 | seq_length=temporal_size, 66 | num_layers=12, 67 | num_heads=3, 68 | hidden_dim=self.n_features, 69 | mlp_dim=self.n_features * 4, 70 | dropout=0., 71 | attention_dropout=0. 72 | ) 73 | elif block_type == "vit_s": 74 | self.n_features = 384 75 | self.block = ViTEncoder( 76 | seq_length=temporal_size, 77 | num_layers=12, 78 | num_heads=6, 79 | hidden_dim=self.n_features, 80 | mlp_dim=self.n_features * 4, 81 | dropout=0., 82 | attention_dropout=0. 83 | ) 84 | elif block_type == "vit_b": 85 | self.n_features = 768 86 | self.block = ViTEncoder( 87 | seq_length=temporal_size, 88 | num_layers=12, 89 | num_heads=12, 90 | hidden_dim=self.n_features, 91 | mlp_dim=self.n_features * 4, 92 | dropout=0., 93 | attention_dropout=0. 
94 | ) 95 | else: 96 | raise ValueError(f"Unknown block type: {block_type}") 97 | 98 | self.input_proj = Conv2d(1, self.n_features, kernel_size=(64, 4), stride=(64, 4)) 99 | self.output_proj = Linear(self.n_features, a_cla_feature_in) 100 | 101 | def forward(self, audio: Tensor) -> Tensor: 102 | x = audio.unsqueeze(1) # (B, 64, 2048) -> (B, 1, 64, 2048) 103 | x = self.input_proj(x) # (B, 1, 64, 2048) -> (B, feat, 1, 512) 104 | x = rearrange(x, "b f 1 t -> b t f") # (B, feat, 1, 512) -> (B, 512, feat) 105 | x = self.block(x) 106 | x = self.output_proj(x) # (B, 512, feat) -> (B, 512, 256) 107 | x = x.permute(0, 2, 1) # (B, 512, 256) -> (B, 256, 512) 108 | return x 109 | 110 | 111 | class AudioFeatureProjection(Module): 112 | 113 | def __init__(self, input_feature_dim: int, a_cla_feature_in: int = 256): 114 | super().__init__() 115 | self.proj = Linear(input_feature_dim, a_cla_feature_in) 116 | 117 | def forward(self, x: Tensor) -> Tensor: 118 | x = self.proj(x) 119 | return x.permute(0, 2, 1) 120 | 121 | 122 | def get_audio_encoder(a_cla_feature_in, temporal_size, a_encoder, ae_features): 123 | if a_encoder == "cnn": 124 | audio_encoder = CNNAudioEncoder(n_features=ae_features) 125 | elif a_encoder == "vit_t": 126 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_t", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size) 127 | elif a_encoder == "vit_s": 128 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_s", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size) 129 | elif a_encoder == "vit_b": 130 | audio_encoder = SelfAttentionAudioEncoder(block_type="vit_b", a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size) 131 | elif a_encoder == "wav2vec2": 132 | audio_encoder = AudioFeatureProjection(input_feature_dim=1536, a_cla_feature_in=a_cla_feature_in) 133 | elif a_encoder == "trillsson3": 134 | audio_encoder = AudioFeatureProjection(input_feature_dim=1280, a_cla_feature_in=a_cla_feature_in) 135 | else: 136 | raise ValueError(f"Invalid audio encoder: {a_encoder}") 137 | return audio_encoder 138 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import toml 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | 8 | from dataset.lavdf import LavdfDataModule 9 | from model import Batfd, BatfdPlus 10 | from utils import LrLogger, EarlyStoppingLR, generate_metadata_min 11 | 12 | parser = argparse.ArgumentParser(description="BATFD training") 13 | parser.add_argument("--config", type=str) 14 | parser.add_argument("--data_root", type=str) 15 | parser.add_argument("--batch_size", type=int, default=4) 16 | parser.add_argument("--num_workers", type=int, default=8) 17 | parser.add_argument("--gpus", type=int, default=1) 18 | parser.add_argument("--precision", default=32) 19 | parser.add_argument("--num_train", type=int, default=None) 20 | parser.add_argument("--num_val", type=int, default=1000) 21 | parser.add_argument("--max_epochs", type=int, default=500) 22 | parser.add_argument("--resume", type=str, default=None) 23 | 24 | if __name__ == '__main__': 25 | args = parser.parse_args() 26 | config = toml.load(args.config) 27 | 28 | if not os.path.exists(os.path.join(args.data_root, "metadata.min.json")): 29 | generate_metadata_min(args.data_root) 30 | 31 | learning_rate = config["optimizer"]["learning_rate"] 32 | gpus = args.gpus 33 | 
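    # Linear learning-rate scaling: the learning rate in the config assumes an effective
    # batch size of 4, so it is scaled by (batch_size * gpus) / 4 below.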
total_batch_size = args.batch_size * gpus 34 | learning_rate = learning_rate * total_batch_size / 4 35 | dataset = config["dataset"] 36 | 37 | v_encoder_type = config["model"]["video_encoder"]["type"] 38 | a_encoder_type = config["model"]["audio_encoder"]["type"] 39 | 40 | v_feature = None 41 | a_feature = None 42 | 43 | if config["model_type"] == "batfd_plus": 44 | model = BatfdPlus( 45 | v_encoder=v_encoder_type, 46 | a_encoder=config["model"]["audio_encoder"]["type"], 47 | frame_classifier=config["model"]["frame_classifier"]["type"], 48 | ve_features=config["model"]["video_encoder"]["hidden_dims"], 49 | ae_features=config["model"]["audio_encoder"]["hidden_dims"], 50 | v_cla_feature_in=config["model"]["video_encoder"]["cla_feature_in"], 51 | a_cla_feature_in=config["model"]["audio_encoder"]["cla_feature_in"], 52 | boundary_features=config["model"]["boundary_module"]["hidden_dims"], 53 | boundary_samples=config["model"]["boundary_module"]["samples"], 54 | temporal_dim=config["num_frames"], 55 | max_duration=config["max_duration"], 56 | weight_frame_loss=config["optimizer"]["frame_loss_weight"], 57 | weight_modal_bm_loss=config["optimizer"]["modal_bm_loss_weight"], 58 | weight_contrastive_loss=config["optimizer"]["contrastive_loss_weight"], 59 | contrast_loss_margin=config["optimizer"]["contrastive_loss_margin"], 60 | cbg_feature_weight=config["optimizer"]["cbg_feature_weight"], 61 | prb_weight_forward=config["optimizer"]["prb_weight_forward"], 62 | weight_decay=config["optimizer"]["weight_decay"], 63 | learning_rate=learning_rate, 64 | distributed=args.gpus > 1 65 | ) 66 | require_match_scores = True 67 | get_meta_attr = BatfdPlus.get_meta_attr 68 | elif config["model_type"] == "batfd": 69 | model = Batfd( 70 | v_encoder=config["model"]["video_encoder"]["type"], 71 | a_encoder=config["model"]["audio_encoder"]["type"], 72 | frame_classifier=config["model"]["frame_classifier"]["type"], 73 | ve_features=config["model"]["video_encoder"]["hidden_dims"], 74 | ae_features=config["model"]["audio_encoder"]["hidden_dims"], 75 | v_cla_feature_in=config["model"]["video_encoder"]["cla_feature_in"], 76 | a_cla_feature_in=config["model"]["audio_encoder"]["cla_feature_in"], 77 | boundary_features=config["model"]["boundary_module"]["hidden_dims"], 78 | boundary_samples=config["model"]["boundary_module"]["samples"], 79 | temporal_dim=config["num_frames"], 80 | max_duration=config["max_duration"], 81 | weight_frame_loss=config["optimizer"]["frame_loss_weight"], 82 | weight_modal_bm_loss=config["optimizer"]["modal_bm_loss_weight"], 83 | weight_contrastive_loss=config["optimizer"]["contrastive_loss_weight"], 84 | contrast_loss_margin=config["optimizer"]["contrastive_loss_margin"], 85 | weight_decay=config["optimizer"]["weight_decay"], 86 | learning_rate=learning_rate, 87 | distributed=args.gpus > 1 88 | ) 89 | require_match_scores = False 90 | get_meta_attr = Batfd.get_meta_attr 91 | else: 92 | raise ValueError("Invalid model type") 93 | 94 | if dataset == "lavdf": 95 | dm = LavdfDataModule( 96 | root=args.data_root, 97 | frame_padding=config["num_frames"], 98 | require_match_scores=require_match_scores, 99 | feature_types=(v_feature, a_feature), 100 | max_duration=config["max_duration"], 101 | batch_size=args.batch_size, num_workers=args.num_workers, 102 | take_train=args.num_train, take_dev=args.num_val, 103 | get_meta_attr=get_meta_attr 104 | ) 105 | else: 106 | raise ValueError("Invalid dataset type") 107 | 108 | try: 109 | precision = int(args.precision) 110 | except ValueError: 111 | precision = 
args.precision 112 | 113 | monitor = "val_fusion_bm_loss" 114 | 115 | trainer = Trainer(log_every_n_steps=50, precision=precision, max_epochs=args.max_epochs, 116 | callbacks=[ 117 | ModelCheckpoint( 118 | dirpath=f"./ckpt/{config['name']}", save_last=True, filename=config["name"] + "-{epoch}-{val_loss:.3f}", 119 | monitor=monitor, mode="min" 120 | ), 121 | LrLogger(), 122 | EarlyStoppingLR(lr_threshold=1e-7) 123 | ], enable_checkpointing=True, 124 | benchmark=True, 125 | accelerator="auto", 126 | devices=args.gpus, 127 | strategy=None if args.gpus < 2 else "ddp", 128 | resume_from_checkpoint=args.resume, 129 | ) 130 | 131 | trainer.fit(model, dm) 132 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | from torch import Tensor 5 | from tqdm.auto import tqdm 6 | 7 | from dataset.lavdf import Metadata 8 | from utils import iou_1d 9 | 10 | 11 | class AP: 12 | """ 13 | Average Precision 14 | 15 | The mean precision in Precision-Recall curve. 16 | """ 17 | 18 | def __init__(self, iou_thresholds: Union[float, List[float]] = 0.5, tqdm_pos: int = 1): 19 | super().__init__() 20 | self.iou_thresholds: List[float] = iou_thresholds if type(iou_thresholds) is list else [iou_thresholds] 21 | self.tqdm_pos = tqdm_pos 22 | self.n_labels = 0 23 | self.ap: dict = {} 24 | 25 | def __call__(self, metadata: List[Metadata], proposals_dict: dict) -> dict: 26 | 27 | for iou_threshold in self.iou_thresholds: 28 | values = [] 29 | self.n_labels = 0 30 | 31 | for meta in tqdm(metadata): 32 | proposals = torch.tensor(proposals_dict[meta.file]) 33 | labels = torch.tensor(meta.fake_periods) 34 | values.append(AP.get_values(iou_threshold, proposals, labels, 25.)) 35 | self.n_labels += len(labels) 36 | 37 | # sort proposals 38 | values = torch.cat(values) 39 | ind = values[:, 0].sort(stable=True, descending=True).indices 40 | values = values[ind] 41 | 42 | # accumulate to calculate precision and recall 43 | curve = self.calculate_curve(values) 44 | ap = self.calculate_ap(curve) 45 | self.ap[iou_threshold] = ap 46 | 47 | return self.ap 48 | 49 | def calculate_curve(self, values): 50 | is_TP = values[:, 1] 51 | acc_TP = torch.cumsum(is_TP, dim=0) 52 | precision = acc_TP / (torch.arange(len(is_TP)) + 1) 53 | recall = acc_TP / self.n_labels 54 | curve = torch.stack([recall, precision]).T 55 | curve = torch.cat([torch.tensor([[1., 0.]]), torch.flip(curve, dims=(0,))]) 56 | return curve 57 | 58 | @staticmethod 59 | def calculate_ap(curve): 60 | x, y = curve.T 61 | y_max = y.cummax(dim=0).values 62 | x_diff = x.diff().abs() 63 | ap = (x_diff * y_max[:-1]).sum() 64 | return ap 65 | 66 | @staticmethod 67 | def get_values( 68 | iou_threshold: float, 69 | proposals: Tensor, 70 | labels: Tensor, 71 | fps: float, 72 | ) -> Tensor: 73 | n_labels = len(labels) 74 | n_proposals = len(proposals) 75 | if n_labels > 0: 76 | ious = iou_1d(proposals[:, 1:] / fps, labels) 77 | else: 78 | ious = torch.zeros((n_proposals, 0)) 79 | 80 | # values: (confidence, is_TP) rows 81 | n_labels = ious.shape[1] 82 | detected = torch.full((n_labels,), False) 83 | confidence = proposals[:, 0] 84 | potential_TP = ious > iou_threshold 85 | 86 | tp_indexes = [] 87 | 88 | for i in range(n_labels): 89 | potential_TP_index = potential_TP[:, i].nonzero() 90 | for (j,) in potential_TP_index: 91 | if j not in tp_indexes: 92 | tp_indexes.append(j) 93 | break 94 | 95 | is_TP = 
torch.zeros(n_proposals, dtype=torch.bool) 96 | if len(tp_indexes) > 0: 97 | tp_indexes = torch.stack(tp_indexes) 98 | is_TP[tp_indexes] = True 99 | values = torch.column_stack([confidence, is_TP]) 100 | return values 101 | 102 | 103 | class AR: 104 | """ 105 | Average Recall 106 | 107 | Args: 108 | n_proposals_list: Number of proposals. 100 for AR@100. 109 | iou_thresholds: IOU threshold samples for the curve. Default: [0.5:0.05:0.95] 110 | 111 | """ 112 | 113 | def __init__(self, n_proposals_list: Union[List[int], int] = 100, iou_thresholds: List[float] = None): 114 | super().__init__() 115 | if iou_thresholds is None: 116 | iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 117 | self.n_proposals_list = n_proposals_list if type(n_proposals_list) is list else [n_proposals_list] 118 | self.n_proposals_list = torch.tensor(self.n_proposals_list) 119 | self.iou_thresholds = iou_thresholds 120 | self.ar: dict = {} 121 | 122 | def __call__(self, metadata: List[Metadata], proposals_dict: dict) -> dict: 123 | # shape: (n_metadata, n_iou_thresholds, n_proposal_thresholds, 2) 124 | values = torch.zeros((len(metadata), len(self.iou_thresholds), len(self.n_proposals_list), 2)) 125 | for i, meta in enumerate(tqdm(metadata)): 126 | proposals = torch.tensor(proposals_dict[meta.file]) 127 | labels = torch.tensor(meta.fake_periods) 128 | values[i] = self.get_values(self.iou_thresholds, proposals, labels, 25.) 129 | 130 | values_sum = values.sum(dim=0) 131 | 132 | TP = values_sum[:, :, 0] 133 | FN = values_sum[:, :, 1] 134 | recall = TP / (TP + FN) # (n_iou_thresholds, n_proposal_thresholds) 135 | for i, n_proposals in enumerate(self.n_proposals_list): 136 | self.ar[n_proposals.item()] = recall[:, i].mean().item() 137 | 138 | return self.ar 139 | 140 | def get_values( 141 | self, 142 | iou_thresholds: List[float], 143 | proposals: Tensor, 144 | labels: Tensor, 145 | fps: float, 146 | ): 147 | n_proposals_list = self.n_proposals_list 148 | max_proposals = max(n_proposals_list) 149 | 150 | proposals = proposals[:max_proposals] 151 | n_labels = len(labels) 152 | 153 | if n_labels > 0: 154 | ious = iou_1d(proposals[:, 1:] / fps, labels) 155 | else: 156 | ious = torch.zeros((max_proposals, 0)) 157 | 158 | # values: matrix of (TP, FN), shapes (n_iou_thresholds, n_proposal_thresholds, 2) 159 | iou_max = ious.cummax(0).values[n_proposals_list - 1] # shape (n_iou_thresholds, n_labels) 160 | iou_max = iou_max[None] 161 | 162 | iou_thresholds = torch.tensor(iou_thresholds)[:, None, None] 163 | TP = (iou_max > iou_thresholds).sum(-1) 164 | FN = n_labels - TP 165 | values = torch.stack([TP, FN], dim=-1) 166 | 167 | return values 168 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from pathlib import Path 3 | from typing import Any, List, Literal, Optional 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from pytorch_lightning import LightningModule, Trainer, Callback 8 | from torch import Tensor 9 | 10 | from dataset import LavdfDataModule 11 | from dataset.lavdf import Metadata 12 | 13 | 14 | def nullable_index(obj, index): 15 | if obj is None: 16 | return None 17 | return obj[index] 18 | 19 | 20 | class SaveToCsvCallback(Callback): 21 | 22 | def __init__(self, max_duration: int, metadata: List[Metadata], model_name: str, model_type: str, 23 | modalities: List[Literal["fusion", "visual", "audio"]] 24 | ): 25 | super().__init__() 26 | 
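        # During Trainer.predict, this callback writes one CSV of candidate segments
        # (begin, end, score) per video into a separate output directory for each
        # requested modality (fusion / visual / audio).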
self.max_duration = max_duration 27 | self.metadata = metadata 28 | self.model_name = model_name 29 | self.model_type = model_type 30 | self.save_fusion = "fusion" in modalities 31 | self.save_visual = "visual" in modalities 32 | self.save_audio = "audio" in modalities 33 | 34 | def on_predict_batch_end( 35 | self, 36 | trainer: Trainer, 37 | pl_module: LightningModule, 38 | outputs: Any, 39 | batch: Any, 40 | batch_idx: int, 41 | dataloader_idx: int, 42 | ) -> None: 43 | if self.model_type == "batfd": 44 | fusion_bm_map, v_bm_map, a_bm_map = outputs 45 | batch_size = fusion_bm_map.shape[0] 46 | 47 | for i in range(batch_size): 48 | n_frames = batch[3][i] 49 | video_name = batch[9][i] 50 | assert isinstance(video_name, str) 51 | assert video_name == self.metadata[batch_idx * batch_size + i].file 52 | if self.save_fusion: 53 | self.gen_df_for_batfd(fusion_bm_map[i], n_frames, os.path.join( 54 | "output", "results", self.model_name, video_name.split('/')[-1].replace(".mp4", ".csv") 55 | )) 56 | if self.save_visual: 57 | self.gen_df_for_batfd(v_bm_map[i], n_frames, os.path.join( 58 | "output", "results", f"{self.model_name}_v", video_name.split('/')[-1].replace(".mp4", ".csv") 59 | )) 60 | if self.save_audio: 61 | self.gen_df_for_batfd(a_bm_map[i], n_frames, os.path.join( 62 | "output", "results", f"{self.model_name}_a", video_name.split('/')[-1].replace(".mp4", ".csv") 63 | )) 64 | elif self.model_type == "batfd_plus": 65 | fusion_bm_map, fusion_start, fusion_end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end = outputs 66 | batch_size = fusion_bm_map.shape[0] 67 | 68 | for i in range(batch_size): 69 | n_frames = batch[5][i] 70 | video_name = batch[-1][i] 71 | assert isinstance(video_name, str) 72 | 73 | if self.save_fusion: 74 | self.gen_df_for_batfd_plus(fusion_bm_map[i], nullable_index(fusion_start, i), nullable_index(fusion_end, i), 75 | n_frames, os.path.join("output", "results", self.model_name, video_name.split('/')[-1].replace(".mp4", ".csv") 76 | )) 77 | if self.save_visual: 78 | self.gen_df_for_batfd_plus(v_bm_map[i], nullable_index(v_start, i), nullable_index(v_end, i), 79 | n_frames, os.path.join("output", "results", f"{self.model_name}_v", video_name.split('/')[-1].replace(".mp4", ".csv") 80 | )) 81 | if self.save_audio: 82 | self.gen_df_for_batfd_plus(a_bm_map[i], nullable_index(a_start, i), nullable_index(a_end, i), 83 | n_frames, os.path.join("output", "results", f"{self.model_name}_a", video_name.split('/')[-1].replace(".mp4", ".csv") 84 | )) 85 | else: 86 | raise ValueError("Invalid model type") 87 | 88 | def gen_df_for_batfd(self, bm_map: Tensor, n_frames: int, output_file: str): 89 | bm_map = bm_map.cpu().numpy() 90 | n_frames = n_frames.cpu().item() 91 | # for each boundary proposal in boundary map 92 | df = pd.DataFrame(bm_map) 93 | df = df.stack().reset_index() 94 | df.columns = ["duration", "begin", "score"] 95 | df["end"] = df.duration + df.begin 96 | df = df[(df.duration > 0) & (df.end <= n_frames)] 97 | df = df.sort_values(["begin", "end"]) 98 | df = df.reset_index()[["begin", "end", "score"]] 99 | df.to_csv(output_file, index=False) 100 | 101 | def gen_df_for_batfd_plus(self, bm_map: Tensor, start: Optional[Tensor], end: Optional[Tensor], n_frames: int, 102 | output_file: str 103 | ): 104 | bm_map = bm_map.cpu().numpy() 105 | n_frames = n_frames.cpu().item() 106 | if start is not None and end is not None: 107 | start = start.cpu().numpy() 108 | end = end.cpu().numpy() 109 | 110 | # for each boundary proposal in boundary map 111 | df = pd.DataFrame(bm_map) 112 
| df = df.stack().reset_index() 113 | df.columns = ["duration", "begin", "score"] 114 | df["end"] = df.duration + df.begin 115 | df = df[(df.duration > 0) & (df.end <= n_frames)] 116 | df = df.sort_values(["begin", "end"]) 117 | df = df.reset_index()[["begin", "end", "score"]] 118 | if start is not None and end is not None: 119 | df["score"] = df["score"] * start[df.begin] * end[df.end] 120 | df.to_csv(output_file, index=False) 121 | 122 | 123 | def inference_batfd(model_name: str, model: LightningModule, dm: LavdfDataModule, 124 | max_duration: int, model_type: str, 125 | modalities: Optional[List[Literal["fusion", "visual", "audio"]]] = None, 126 | gpus: int = 1 127 | ): 128 | modalities = modalities or ["fusion"] 129 | 130 | if "fusion" in modalities: 131 | Path(os.path.join("output", "results", model_name)).mkdir(parents=True, exist_ok=True) 132 | if "visual" in modalities: 133 | Path(os.path.join("output", "results", f"{model_name}_v")).mkdir(parents=True, exist_ok=True) 134 | if "audio" in modalities: 135 | Path(os.path.join("output", "results", f"{model_name}_a")).mkdir(parents=True, exist_ok=True) 136 | 137 | model.eval() 138 | 139 | test_dataset = dm.test_dataset 140 | 141 | trainer = Trainer(logger=False, 142 | enable_checkpointing=False, devices=1 if gpus > 1 else None, 143 | accelerator="gpu" if gpus > 0 else "cpu", 144 | callbacks=[SaveToCsvCallback(max_duration, test_dataset.metadata, model_name, model_type, modalities)] 145 | ) 146 | 147 | trainer.predict(model, dm.test_dataloader()) 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Localized Audio Visual DeepFake Dataset (LAV-DF) 2 | 3 |
37 | 
38 | This repo is the official PyTorch implementation for the DICTA paper [Do You Really Mean That? Content Driven Audio-Visual
39 | Deepfake Dataset and Multimodal Method for Temporal Forgery Localization](https://ieeexplore.ieee.org/document/10034605)
40 | (Best Award), and the journal paper [_Glitch in the Matrix_: A Large Scale Benchmark for Content Driven Audio-Visual
41 | Forgery Detection and Localization](https://www.sciencedirect.com/science/article/pii/S1077314223001984) accepted by CVIU.
42 | 
43 | ## LAV-DF Dataset
44 | 
45 | ### Download
46 | 
47 | To use the LAV-DF dataset, you should agree to the [terms and conditions](https://github.com/ControlNet/LAV-DF/blob/master/TERMS_AND_CONDITIONS.md).
48 | 
49 | Download link: [OneDrive](https://monashuni-my.sharepoint.com/:f:/g/personal/zhixi_cai_monash_edu/EklD-8lD_GRNl0yyJJ-cF3kBWEiHRmH4U5Dtg7eJjAOUlg?e=wowDpd), [Google Drive](https://drive.google.com/drive/folders/1U8asIMb0bpH6-zMR_5FaJmPnC53lomq7?usp=sharing), [HuggingFace](https://huggingface.co/datasets/ControlNet/LAV-DF).
50 | 
51 | ### Baseline Benchmark
52 | 
53 | | Method  | AP@0.5 | AP@0.75 | AP@0.95 | AR@100 | AR@50 | AR@20 | AR@10 |
54 | |---------|--------|---------|---------|--------|-------|-------|-------|
55 | | BA-TFD  | 79.15  | 38.57   | 00.24   | 67.03  | 64.18 | 60.89 | 58.51 |
56 | | BA-TFD+ | 96.30  | 84.96   | 04.44   | 81.62  | 80.48 | 79.40 | 78.75 |
57 | 
58 | Please note that this BA-TFD result is slightly better than the one reported in the paper,
59 | because this repository uses better hyperparameters.
60 | 
61 | ## Baseline Models
62 | 
63 | ### Requirements
64 | 
65 | The main dependency versions are:
66 | - Python >= 3.7, < 3.11
67 | - PyTorch >= 1.13
68 | - torchvision >= 0.14
69 | - pytorch_lightning == 1.7.*
70 | 
71 | Run the following command to install the required packages.
72 | 
73 | ```bash
74 | pip install -r requirements.txt
75 | ```
76 | 
77 | ### Training BA-TFD
78 | 
79 | Train the BA-TFD introduced in the paper [Do You Really Mean That? Content Driven Audio-Visual
80 | Deepfake Dataset and Multimodal Method for Temporal Forgery Localization](https://ieeexplore.ieee.org/document/10034605) with the default hyperparameters on the LAV-DF dataset.
81 | 
82 | ```bash
83 | python train.py \
84 |     --config ./config/batfd_default.toml \
85 |     --data_root <DATA_ROOT> \
86 |     --batch_size 4 --num_workers 8 --gpus 1 --precision 16
87 | ```
88 | 
89 | The checkpoint will be saved in the `ckpt` directory, and the tensorboard log will be saved in the `lightning_logs` directory. If you meet a NaN issue when training BA-TFD+, it might be caused by a bug in the PyTorch self-attention ops; upgrading or changing the PyTorch version can solve it.
90 | 
91 | ### Training BA-TFD+
92 | 
93 | Train the BA-TFD+ introduced in the paper [_Glitch in the Matrix_: A Large Scale Benchmark for Content Driven Audio-Visual Forgery Detection and Localization](https://www.sciencedirect.com/science/article/pii/S1077314223001984) with the default hyperparameters on the LAV-DF dataset.
94 | 
95 | ```bash
96 | python train.py \
97 |     --config ./config/batfd_plus_default.toml \
98 |     --data_root <DATA_ROOT> \
99 |     --batch_size 4 --num_workers 8 --gpus 2 --precision 32
100 | ```
101 | 
102 | Please use `FP32` for training BA-TFD+, as `FP16` will cause inf and NaN values.
103 | 
104 | The checkpoint will be saved in the `ckpt` directory, and the tensorboard log will be saved in the `lightning_logs` directory.
105 | 
106 | 
107 | ### Evaluation
108 | 
109 | Please run the following command to evaluate the model with the checkpoint saved in the `ckpt` directory.
110 | 
111 | Besides, you can also download the [BA-TFD](https://github.com/ControlNet/LAV-DF/releases/download/pretrained_model/batfd_default.ckpt) and [BA-TFD+](https://github.com/ControlNet/LAV-DF/releases/download/pretrained_model_v2/batfd_plus_default.ckpt) pretrained models.
112 | 
113 | ```bash
114 | python evaluate.py \
115 |     --config <CONFIG> \
116 |     --data_root <DATA_ROOT> \
117 |     --checkpoint <CHECKPOINT> \
118 |     --batch_size 1 --num_workers 4
119 | ```
120 | 
121 | The script will generate the temporal inference results in the `output` directory, and the AP and AR scores will
122 | be printed in the console.
123 | 
124 | Note: please make sure only one GPU is visible to the evaluation script.
125 | 
126 | ## License
127 | 
128 | This project is under the CC BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
129 | 
130 | ## References
131 | 
132 | If you find this work useful in your research, please cite the following papers.
133 | 
134 | The conference paper,
135 | ```bibtex
136 | @inproceedings{cai2022you,
137 |     title = {Do You Really Mean That? Content Driven Audio-Visual Deepfake Dataset and Multimodal Method for Temporal Forgery Localization},
138 |     author = {Cai, Zhixi and Stefanov, Kalin and Dhall, Abhinav and Hayat, Munawar},
139 |     booktitle = {2022 International Conference on Digital Image Computing: Techniques and Applications (DICTA)},
140 |     year = {2022},
141 |     doi = {10.1109/DICTA56598.2022.10034605},
142 |     pages = {1--10},
143 |     address = {Sydney, Australia},
144 | }
145 | ```
146 | 
147 | The extended journal version is accepted by CVIU,
148 | ```bibtex
149 | @article{cai2023glitch,
150 |     title = {Glitch in the Matrix: A Large Scale Benchmark for Content Driven Audio-Visual Forgery Detection and Localization},
151 |     author = {Cai, Zhixi and Ghosh, Shreya and Dhall, Abhinav and Gedeon, Tom and Stefanov, Kalin and Hayat, Munawar},
152 |     journal = {Computer Vision and Image Understanding},
153 |     year = {2023},
154 |     volume = {236},
155 |     pages = {103818},
156 |     issn = {1077-3142},
157 |     doi = {10.1016/j.cviu.2023.103818},
158 | }
159 | ```
160 | 
161 | ## Acknowledgements
162 | 
163 | Some code related to the boundary matching mechanism is borrowed from
164 | [JJBOY/BMN-Boundary-Matching-Network](https://github.com/JJBOY/BMN-Boundary-Matching-Network) and
165 | [xxcheng0708/BSNPlusPlus-boundary-sensitive-network](https://github.com/xxcheng0708/BSNPlusPlus-boundary-sensitive-network).
166 | 
--------------------------------------------------------------------------------
/model/video_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 | 
3 | import numpy as np
4 | from einops.layers.torch import Rearrange
5 | from torch import Tensor
6 | from torch.nn import Sequential, LeakyReLU, MaxPool3d, Module, Linear
7 | from torchvision.models.video.mvit import MSBlockConfig, _mvit
8 | 
9 | from utils import Conv3d, Conv1d
10 | 
11 | 
12 | class C3DVideoEncoder(Module):
13 |     """
14 |     Video encoder (E_v): Process video frames to extract features.
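    With the default settings, a clip of shape (B, 3, 512, 96, 96) is reduced to a
    (B, 256, 512) feature map, i.e. one C_f = 256 dimensional feature vector per frame,
    as traced by the shape comments on the blocks below.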
15 | Input: 16 | V: (B, C, T, H, W) 17 | Output: 18 | F_v: (B, C_f, T) 19 | """ 20 | 21 | def __init__(self, n_features=(64, 96, 128, 128), v_cla_feature_in: int = 256): 22 | super().__init__() 23 | 24 | n_dim0, n_dim1, n_dim2, n_dim3 = n_features 25 | 26 | # (B, 3, 512, 96, 96) -> (B, 64, 512, 32, 32) 27 | self.block0 = Sequential( 28 | Conv3d(3, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 29 | Conv3d(n_dim0, n_dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 30 | MaxPool3d((1, 3, 3)) 31 | ) 32 | 33 | # (B, 64, 512, 32, 32) -> (B, 96, 512, 16, 16) 34 | self.block1 = Sequential( 35 | Conv3d(n_dim0, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 36 | Conv3d(n_dim1, n_dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 37 | MaxPool3d((1, 2, 2)) 38 | ) 39 | 40 | # (B, 96, 512, 16, 16) -> (B, 128, 512, 8, 8) 41 | self.block2 = Sequential( 42 | Conv3d(n_dim1, n_dim2, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 43 | Conv3d(n_dim2, n_dim2, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 44 | MaxPool3d((1, 2, 2)) 45 | ) 46 | 47 | # (B, 128, 512, 8, 8) -> (B, 128, 512, 2, 2) -> (B, 512, 512) -> (B, 256, 512) 48 | self.block3 = Sequential( 49 | Conv3d(n_dim2, n_dim3, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 50 | MaxPool3d((1, 2, 2)), 51 | Conv3d(n_dim3, n_dim3, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU), 52 | MaxPool3d((1, 2, 2)), 53 | Rearrange("b c t h w -> b (c h w) t"), 54 | Conv1d(n_dim3 * 4, v_cla_feature_in, kernel_size=1, stride=1, build_activation=LeakyReLU) 55 | ) 56 | 57 | def forward(self, video: Tensor) -> Tensor: 58 | x = self.block0(video) 59 | x = self.block1(x) 60 | x = self.block2(x) 61 | x = self.block3(x) 62 | return x 63 | 64 | 65 | class MvitVideoEncoder(Module): 66 | 67 | def __init__(self, v_cla_feature_in: int = 256, 68 | temporal_size: int = 512, 69 | mvit_type: Literal["mvit_v2_t", "mvit_v2_s", "mvit_v2_b"] = "mvit_v2_t" 70 | ): 71 | super().__init__() 72 | if mvit_type == "mvit_v2_t": 73 | self.mvit = mvit_v2_t(v_cla_feature_in, temporal_size) 74 | elif mvit_type == "mvit_v2_s": 75 | self.mvit = mvit_v2_s(v_cla_feature_in, temporal_size) 76 | elif mvit_type == "mvit_v2_b": 77 | self.mvit = mvit_v2_b(v_cla_feature_in, temporal_size) 78 | else: 79 | raise ValueError(f"Invalid mvit_type: {mvit_type}") 80 | del self.mvit.head 81 | 82 | def forward(self, video: Tensor) -> Tensor: 83 | feat = self.mvit.conv_proj(video) 84 | feat = feat.flatten(2).transpose(1, 2) 85 | feat = self.mvit.pos_encoding(feat) 86 | thw = (self.mvit.pos_encoding.temporal_size,) + self.mvit.pos_encoding.spatial_size 87 | for block in self.mvit.blocks: 88 | feat, thw = block(feat, thw) 89 | 90 | feat = self.mvit.norm(feat) 91 | feat = feat[:, 1:] 92 | feat = feat.permute(0, 2, 1) 93 | return feat 94 | 95 | 96 | def generate_config(blocks, heads, channels, out_dim): 97 | num_heads = [] 98 | input_channels = [] 99 | kernel_qkv = [] 100 | stride_q = [[1, 1, 1]] * sum(blocks) 101 | blocks_cum = np.cumsum(blocks) 102 | stride_kv = [] 103 | 104 | for i in range(len(blocks)): 105 | num_heads.extend([heads[i]] * blocks[i]) 106 | input_channels.extend([channels[i]] * blocks[i]) 107 | kernel_qkv.extend([[3, 3, 3]] * blocks[i]) 108 | 109 | if i != len(blocks) - 1: 110 | stride_q[blocks_cum[i]] = [1, 2, 2] 111 | 112 | stride_kv_value = 2 ** (len(blocks) - 1 - i) 113 | stride_kv.extend([[1, stride_kv_value, stride_kv_value]] * blocks[i]) 114 
| 115 | return { 116 | "num_heads": num_heads, 117 | "input_channels": [input_channels[0]] + input_channels[:-1], 118 | "output_channels": input_channels[:-1] + [out_dim], 119 | "kernel_q": kernel_qkv, 120 | "kernel_kv": kernel_qkv, 121 | "stride_q": stride_q, 122 | "stride_kv": stride_kv 123 | } 124 | 125 | 126 | def build_mvit(config, kwargs, temporal_size=512): 127 | block_setting = [] 128 | for i in range(len(config["num_heads"])): 129 | block_setting.append( 130 | MSBlockConfig( 131 | num_heads=config["num_heads"][i], 132 | input_channels=config["input_channels"][i], 133 | output_channels=config["output_channels"][i], 134 | kernel_q=config["kernel_q"][i], 135 | kernel_kv=config["kernel_kv"][i], 136 | stride_q=config["stride_q"][i], 137 | stride_kv=config["stride_kv"][i], 138 | ) 139 | ) 140 | return _mvit( 141 | spatial_size=(96, 96), 142 | temporal_size=temporal_size, 143 | block_setting=block_setting, 144 | residual_pool=True, 145 | residual_with_cls_embed=False, 146 | rel_pos_embed=True, 147 | proj_after_attn=True, 148 | stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), 149 | weights=None, 150 | progress=False, 151 | patch_embed_kernel=(3, 15, 15), 152 | patch_embed_stride=(1, 12, 12), 153 | patch_embed_padding=(1, 3, 3), 154 | **kwargs, 155 | ) 156 | 157 | 158 | def mvit_v2_b(out_dim: int, temporal_size: int, **kwargs): 159 | config = generate_config([2, 3, 16, 3], [1, 2, 4, 8], [96, 192, 384, 768], out_dim) 160 | return build_mvit(config, kwargs, temporal_size=temporal_size) 161 | 162 | 163 | def mvit_v2_s(out_dim: int, temporal_size: int, **kwargs): 164 | config = generate_config([1, 2, 11, 2], [1, 2, 4, 8], [96, 192, 384, 768], out_dim) 165 | return build_mvit(config, kwargs, temporal_size=temporal_size) 166 | 167 | 168 | def mvit_v2_t(out_dim: int, temporal_size: int, **kwargs): 169 | config = generate_config([1, 2, 5, 2], [1, 2, 4, 8], [96, 192, 384, 768], out_dim) 170 | return build_mvit(config, kwargs, temporal_size=temporal_size) 171 | 172 | 173 | class VideoFeatureProjection(Module): 174 | 175 | def __init__(self, input_feature_dim: int, v_cla_feature_in: int = 256): 176 | super().__init__() 177 | self.proj = Linear(input_feature_dim, v_cla_feature_in) 178 | 179 | def forward(self, x: Tensor) -> Tensor: 180 | x = self.proj(x) 181 | return x.permute(0, 2, 1) 182 | 183 | 184 | def get_video_encoder(v_cla_feature_in, temporal_size, v_encoder, ve_features): 185 | if v_encoder == "c3d": 186 | video_encoder = C3DVideoEncoder(n_features=ve_features, v_cla_feature_in=v_cla_feature_in) 187 | elif v_encoder == "mvit_t": 188 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_t") 189 | elif v_encoder == "mvit_s": 190 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_s") 191 | elif v_encoder == "mvit_b": 192 | video_encoder = MvitVideoEncoder(v_cla_feature_in=v_cla_feature_in, temporal_size=temporal_size, mvit_type="mvit_v2_b") 193 | elif v_encoder == "marlin_vit_small": 194 | video_encoder = VideoFeatureProjection(input_feature_dim=13824, v_cla_feature_in=v_cla_feature_in) 195 | elif v_encoder == "i3d": 196 | video_encoder = VideoFeatureProjection(input_feature_dim=2048, v_cla_feature_in=v_cla_feature_in) 197 | elif v_encoder == "3dmm": 198 | video_encoder = VideoFeatureProjection(input_feature_dim=393, v_cla_feature_in=v_cla_feature_in) 199 | else: 200 | raise ValueError(f"Invalid video encoder: {v_encoder}") 201 | return video_encoder 
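For orientation, here is a minimal usage sketch for the video encoders defined above (not part of the original file): it assumes the repository root is on the Python path and uses a toy clip length purely for illustration; the expected shapes follow the docstring and the per-block comments.

import torch

from model.video_encoder import C3DVideoEncoder, get_video_encoder

# Toy batch: 2 RGB clips, 8 frames, 96x96 pixels (the training pipeline uses T = 512).
video = torch.randn(2, 3, 8, 96, 96)

# Direct construction with the default widths used by Batfd.
encoder = C3DVideoEncoder(n_features=(64, 96, 128, 128), v_cla_feature_in=256)
features = encoder(video)
print(features.shape)  # expected torch.Size([2, 256, 8]), i.e. (B, C_f, T)

# Equivalent construction through the factory used by Batfd / BatfdPlus;
# temporal_size is only needed by the MViT variants.
encoder = get_video_encoder(v_cla_feature_in=256, temporal_size=8, v_encoder="c3d",
                            ve_features=(64, 96, 128, 128))

The MViT variants use a (1, 12, 12) patch-embedding stride and spatial-only query strides at the stage boundaries, so the final token grid reduces to one token per frame, which forward() then permutes to (B, C_f, T).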
202 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from importlib import metadata
2 | import json
3 | import os
4 | import re
5 | from abc import ABC
6 | from typing import List, Tuple, Optional
7 | 
8 | import numpy as np
9 | import torch
10 | import torchaudio
11 | import torchvision
12 | from einops import rearrange
13 | from pytorch_lightning import Callback, Trainer, LightningModule
14 | from torch import Tensor
15 | from torch.nn import functional as F, Module
16 | 
17 | 
18 | def read_json(path: str, object_hook=None):
19 |     with open(path, 'r') as f:
20 |         return json.load(f, object_hook=object_hook)
21 | 
22 | 
23 | def read_video(path: str):
24 |     video, audio, info = torchvision.io.read_video(path, pts_unit="sec")
25 |     video = video.permute(0, 3, 1, 2) / 255
26 |     audio = audio.permute(1, 0)
27 |     return video, audio, info
28 | 
29 | 
30 | def read_audio(path: str):
31 |     return torchaudio.load(path)
32 | 
33 | 
34 | def read_image(path: str):
35 |     return torchvision.io.read_image(path).float() / 255.0
36 | 
37 | 
38 | def padding_video(tensor: Tensor, target: int, padding_method: str = "zero", padding_position: str = "tail") -> Tensor:
39 |     t, c, h, w = tensor.shape
40 |     padding_size = target - t
41 | 
42 |     pad = _get_padding_pair(padding_size, padding_position)
43 | 
44 |     if padding_method == "zero":
45 |         return F.pad(tensor, pad=[0, 0, 0, 0, 0, 0] + pad)
46 |     elif padding_method == "same":
47 |         tensor = rearrange(tensor, "t c h w -> c h w t")
48 |         tensor = F.pad(tensor, pad=pad + [0, 0], mode="replicate")
49 |         return rearrange(tensor, "c h w t -> t c h w")
50 |     else:
51 |         raise ValueError("Wrong padding method. It should be zero or same.")
52 | 
53 | 
54 | def padding_audio(tensor: Tensor, target: int,
55 |                   padding_method: str = "zero",
56 |                   padding_position: str = "tail"
57 |                   ) -> Tensor:
58 |     t, c = tensor.shape
59 |     padding_size = target - t
60 |     pad = _get_padding_pair(padding_size, padding_position)
61 | 
62 |     if padding_method == "zero":
63 |         return F.pad(tensor, pad=[0, 0] + pad)
64 |     elif padding_method == "same":
65 |         tensor = rearrange(tensor, "t c -> 1 c t")
66 |         tensor = F.pad(tensor, pad=pad, mode="replicate")
67 |         return rearrange(tensor, "1 c t -> t c")
68 |     else:
69 |         raise ValueError("Wrong padding method. It should be zero or same.")
70 | 
71 | 
72 | def _get_padding_pair(padding_size: int, padding_position: str) -> List[int]:
73 |     if padding_position == "tail":
74 |         pad = [0, padding_size]
75 |     elif padding_position == "head":
76 |         pad = [padding_size, 0]
77 |     elif padding_position == "average":
78 |         padding_head = padding_size // 2
79 |         padding_tail = padding_size - padding_head
80 |         pad = [padding_head, padding_tail]
81 |     else:
82 |         raise ValueError("Wrong padding position. It should be head, tail or average.")
83 |     return pad
84 | 
85 | 
86 | def resize_video(tensor: Tensor, size: Tuple[int, int], resize_method: str = "bicubic") -> Tensor:
87 |     return F.interpolate(tensor, size=size, mode=resize_method)
88 | 
89 | 
90 | class _ConvNd(Module, ABC):
91 | 
92 |     def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0,
93 |                  build_activation: Optional[callable] = None
94 |                  ):
95 |         super().__init__()
96 |         self.conv = self.PtConv(
97 |             in_channels, out_channels, kernel_size, stride=stride, padding=padding
98 |         )
99 |         if build_activation is not None:
100 |             self.activation = build_activation()
101 |         else:
102 |             self.activation = None
103 | 
104 |     def forward(self, x: Tensor) -> Tensor:
105 |         x = self.conv(x)
106 |         if self.activation is not None:
107 |             x = self.activation(x)
108 |         return x
109 | 
110 | 
111 | class Conv1d(_ConvNd):
112 |     PtConv = torch.nn.Conv1d
113 | 
114 | 
115 | class Conv2d(_ConvNd):
116 |     PtConv = torch.nn.Conv2d
117 | 
118 | 
119 | class Conv3d(_ConvNd):
120 |     PtConv = torch.nn.Conv3d
121 | 
122 | 
123 | def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
124 |     """Compute jaccard score between a box and the anchors."""
125 | 
126 |     len_anchors = anchors_max - anchors_min
127 |     int_xmin = np.maximum(anchors_min, box_min)
128 |     int_xmax = np.minimum(anchors_max, box_max)
129 |     inter_len = np.maximum(int_xmax - int_xmin, 0.)
130 |     union_len = len_anchors - inter_len + box_max - box_min
131 |     iou = inter_len / union_len
132 |     return iou
133 | 
134 | 
135 | def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
136 |     # calculate the overlap proportion (IoA) between each anchor and all ground-truth
137 |     # boxes, used as the supervision signal for the boundary start/end scores
138 |     len_anchors = anchors_max - anchors_min
139 |     int_xmin = np.maximum(anchors_min, box_min)
140 |     int_xmax = np.minimum(anchors_max, box_max)
141 |     inter_len = np.maximum(int_xmax - int_xmin, 0.)
142 |     scores = np.divide(inter_len, len_anchors)
143 |     return scores
144 | 
145 | 
146 | def iou_1d(proposal, target) -> Tensor:
147 |     """
148 |     Calculate 1D IoU for M proposals against N targets.
149 | 
150 |     Args:
151 |         proposal (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The predicted array with [M, 2]. First column is
152 |             beginning, second column is end.
153 |         target (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The label array with [N, 2]. First column is
154 |             beginning, second column is end.
155 | 
156 |     Returns:
157 |         :class:`~torch.Tensor`: The iou result with [M, N].
158 |     """
159 |     if type(proposal) is np.ndarray:
160 |         proposal = torch.from_numpy(proposal)
161 | 
162 |     if type(target) is np.ndarray:
163 |         target = torch.from_numpy(target)
164 | 
165 |     proposal_begin = proposal[:, 0].unsqueeze(0).T
166 |     proposal_end = proposal[:, 1].unsqueeze(0).T
167 |     target_begin = target[:, 0]
168 |     target_end = target[:, 1]
169 | 
170 |     inner_begin = torch.maximum(proposal_begin, target_begin)
171 |     inner_end = torch.minimum(proposal_end, target_end)
172 |     outer_begin = torch.minimum(proposal_begin, target_begin)
173 |     outer_end = torch.maximum(proposal_end, target_end)
174 | 
175 |     inter = torch.clamp(inner_end - inner_begin, min=0.)
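    # Note: for 1-D segments the "union" below is the outer span (earliest begin to
    # latest end); when proposal and target are disjoint the intersection is clamped
    # to 0, so the returned IoU is 0 in that case anyway.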
176 | union = outer_end - outer_begin 177 | return inter / union 178 | 179 | 180 | class LrLogger(Callback): 181 | """Log learning rate in each epoch start.""" 182 | 183 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 184 | for i, optimizer in enumerate(trainer.optimizers): 185 | for j, params in enumerate(optimizer.param_groups): 186 | key = f"opt{i}_lr{j}" 187 | value = params["lr"] 188 | pl_module.logger.log_metrics({key: value}, step=trainer.global_step) 189 | pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed) 190 | 191 | 192 | class EarlyStoppingLR(Callback): 193 | """Early stop model training when the LR is lower than threshold.""" 194 | 195 | def __init__(self, lr_threshold: float, mode="all"): 196 | self.lr_threshold = lr_threshold 197 | 198 | if mode in ("any", "all"): 199 | self.mode = mode 200 | else: 201 | raise ValueError(f"mode must be one of ('any', 'all')") 202 | 203 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 204 | self._run_early_stop_checking(trainer) 205 | 206 | def _run_early_stop_checking(self, trainer: Trainer) -> None: 207 | metrics = trainer._logger_connector.callback_metrics 208 | if len(metrics) == 0: 209 | return 210 | all_lr = [] 211 | for key, value in metrics.items(): 212 | if re.match(r"opt\d+_lr\d+", key): 213 | all_lr.append(value) 214 | 215 | if len(all_lr) == 0: 216 | return 217 | 218 | if self.mode == "all": 219 | if all(lr <= self.lr_threshold for lr in all_lr): 220 | trainer.should_stop = True 221 | elif self.mode == "any": 222 | if any(lr <= self.lr_threshold for lr in all_lr): 223 | trainer.should_stop = True 224 | 225 | 226 | def generate_metadata_min(data_root: str): 227 | metadata_full = read_json(os.path.join(data_root, "metadata.json")) 228 | metadata_min = [] 229 | for meta in metadata_full: 230 | del meta["timestamps"] 231 | del meta["transcript"] 232 | metadata_min.append(meta) 233 | with open(os.path.join(data_root, "metadata.min.json"), "w") as f: 234 | json.dump(metadata_min, f) 235 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/intellij,python,windows,macos,linux,jupyternotebooks 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=intellij,python,windows,macos,linux,jupyternotebooks 3 | 4 | ### Intellij ### 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # AWS User-specific 16 | .idea/**/aws.xml 17 | 18 | # Generated files 19 | .idea/**/contentModel.xml 20 | 21 | # Sensitive or high-churn files 22 | .idea/**/dataSources/ 23 | .idea/**/dataSources.ids 24 | .idea/**/dataSources.local.xml 25 | .idea/**/sqlDataSources.xml 26 | .idea/**/dynamic.xml 27 | .idea/**/uiDesigner.xml 28 | .idea/**/dbnavigator.xml 29 | 30 | # Gradle 31 | .idea/**/gradle.xml 32 | .idea/**/libraries 33 | 34 | # Gradle and Maven with auto-import 35 | # When using Gradle or Maven with auto-import, you should exclude module files, 36 | # since they will be recreated, and may cause churn. Uncomment if using 37 | # auto-import. 
38 | # .idea/artifacts 39 | # .idea/compiler.xml 40 | # .idea/jarRepositories.xml 41 | # .idea/modules.xml 42 | # .idea/*.iml 43 | # .idea/modules 44 | # *.iml 45 | # *.ipr 46 | 47 | # CMake 48 | cmake-build-*/ 49 | 50 | # Mongo Explorer plugin 51 | .idea/**/mongoSettings.xml 52 | 53 | # File-based project format 54 | *.iws 55 | 56 | # IntelliJ 57 | out/ 58 | 59 | # mpeltonen/sbt-idea plugin 60 | .idea_modules/ 61 | 62 | # JIRA plugin 63 | atlassian-ide-plugin.xml 64 | 65 | # Cursive Clojure plugin 66 | .idea/replstate.xml 67 | 68 | # SonarLint plugin 69 | .idea/sonarlint/ 70 | 71 | # Crashlytics plugin (for Android Studio and IntelliJ) 72 | com_crashlytics_export_strings.xml 73 | crashlytics.properties 74 | crashlytics-build.properties 75 | fabric.properties 76 | 77 | # Editor-based Rest Client 78 | .idea/httpRequests 79 | 80 | # Android studio 3.1+ serialized cache file 81 | .idea/caches/build_file_checksums.ser 82 | 83 | ### Intellij Patch ### 84 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 85 | 86 | # *.iml 87 | # modules.xml 88 | # .idea/misc.xml 89 | # *.ipr 90 | .idea 91 | # Sonarlint plugin 92 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 93 | .idea/**/sonarlint/ 94 | 95 | # SonarQube Plugin 96 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 97 | .idea/**/sonarIssues.xml 98 | 99 | # Markdown Navigator plugin 100 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 101 | .idea/**/markdown-navigator.xml 102 | .idea/**/markdown-navigator-enh.xml 103 | .idea/**/markdown-navigator/ 104 | 105 | # Cache file creation bug 106 | # See https://youtrack.jetbrains.com/issue/JBR-2257 107 | .idea/$CACHE_FILE$ 108 | 109 | # CodeStream plugin 110 | # https://plugins.jetbrains.com/plugin/12206-codestream 111 | .idea/codestream.xml 112 | 113 | # Azure Toolkit for IntelliJ plugin 114 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 115 | .idea/**/azureSettings.xml 116 | 117 | ### JupyterNotebooks ### 118 | # gitignore template for Jupyter Notebooks 119 | # website: http://jupyter.org/ 120 | 121 | .ipynb_checkpoints 122 | */.ipynb_checkpoints/* 123 | 124 | # IPython 125 | profile_default/ 126 | ipython_config.py 127 | 128 | # Remove previous ipynb_checkpoints 129 | # git rm -r .ipynb_checkpoints/ 130 | 131 | ### Linux ### 132 | *~ 133 | 134 | # temporary files which can be created if a process still has a handle open of a deleted file 135 | .fuse_hidden* 136 | 137 | # KDE directory preferences 138 | .directory 139 | 140 | # Linux trash folder which might appear on any partition or disk 141 | .Trash-* 142 | 143 | # .nfs files are created when an open file is removed but is still being accessed 144 | .nfs* 145 | 146 | ### macOS ### 147 | # General 148 | .DS_Store 149 | .AppleDouble 150 | .LSOverride 151 | 152 | # Icon must end with two \r 153 | Icon 154 | 155 | 156 | # Thumbnails 157 | ._* 158 | 159 | # Files that might appear in the root of a volume 160 | .DocumentRevisions-V100 161 | .fseventsd 162 | .Spotlight-V100 163 | .TemporaryItems 164 | .Trashes 165 | .VolumeIcon.icns 166 | .com.apple.timemachine.donotpresent 167 | 168 | # Directories potentially created on remote AFP share 169 | .AppleDB 170 | .AppleDesktop 171 | Network Trash Folder 172 | Temporary Items 173 | .apdisk 174 | 175 | ### macOS Patch ### 176 | # iCloud generated files 177 | *.icloud 178 | 179 | ### Python ### 180 | # Byte-compiled / optimized / DLL files 181 | __pycache__/ 182 | *.py[cod] 183 | 
*$py.class 184 | 185 | # C extensions 186 | *.so 187 | 188 | # Distribution / packaging 189 | .Python 190 | build/ 191 | develop-eggs/ 192 | dist/ 193 | downloads/ 194 | eggs/ 195 | .eggs/ 196 | lib/ 197 | lib64/ 198 | parts/ 199 | sdist/ 200 | var/ 201 | wheels/ 202 | share/python-wheels/ 203 | *.egg-info/ 204 | .installed.cfg 205 | *.egg 206 | MANIFEST 207 | 208 | # PyInstaller 209 | # Usually these files are written by a python script from a template 210 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 211 | *.manifest 212 | *.spec 213 | 214 | # Installer logs 215 | pip-log.txt 216 | pip-delete-this-directory.txt 217 | 218 | # Unit test / coverage reports 219 | htmlcov/ 220 | .tox/ 221 | .nox/ 222 | .coverage 223 | .coverage.* 224 | .cache 225 | nosetests.xml 226 | coverage.xml 227 | *.cover 228 | *.py,cover 229 | .hypothesis/ 230 | .pytest_cache/ 231 | cover/ 232 | 233 | # Translations 234 | *.mo 235 | *.pot 236 | 237 | # Django stuff: 238 | *.log 239 | local_settings.py 240 | db.sqlite3 241 | db.sqlite3-journal 242 | 243 | # Flask stuff: 244 | instance/ 245 | .webassets-cache 246 | 247 | # Scrapy stuff: 248 | .scrapy 249 | 250 | # Sphinx documentation 251 | docs/_build/ 252 | 253 | # PyBuilder 254 | .pybuilder/ 255 | target/ 256 | 257 | # Jupyter Notebook 258 | 259 | # IPython 260 | 261 | # pyenv 262 | # For a library or package, you might want to ignore these files since the code is 263 | # intended to run in multiple environments; otherwise, check them in: 264 | # .python-version 265 | 266 | # pipenv 267 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 268 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 269 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 270 | # install all needed dependencies. 271 | #Pipfile.lock 272 | 273 | # poetry 274 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 275 | # This is especially recommended for binary packages to ensure reproducibility, and is more 276 | # commonly ignored for libraries. 277 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 278 | #poetry.lock 279 | 280 | # pdm 281 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 282 | #pdm.lock 283 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 284 | # in version control. 285 | # https://pdm.fming.dev/#use-with-ide 286 | .pdm.toml 287 | 288 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 289 | __pypackages__/ 290 | 291 | # Celery stuff 292 | celerybeat-schedule 293 | celerybeat.pid 294 | 295 | # SageMath parsed files 296 | *.sage.py 297 | 298 | # Environments 299 | .env 300 | .venv 301 | env/ 302 | venv/ 303 | ENV/ 304 | env.bak/ 305 | venv.bak/ 306 | 307 | # Spyder project settings 308 | .spyderproject 309 | .spyproject 310 | 311 | # Rope project settings 312 | .ropeproject 313 | 314 | # mkdocs documentation 315 | /site 316 | 317 | # mypy 318 | .mypy_cache/ 319 | .dmypy.json 320 | dmypy.json 321 | 322 | # Pyre type checker 323 | .pyre/ 324 | 325 | # pytype static type analyzer 326 | .pytype/ 327 | 328 | # Cython debug symbols 329 | cython_debug/ 330 | 331 | # PyCharm 332 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 333 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 334 | # and can be added to the global gitignore or merged into this file. For a more nuclear 335 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 336 | #.idea/ 337 | 338 | ### Windows ### 339 | # Windows thumbnail cache files 340 | Thumbs.db 341 | Thumbs.db:encryptable 342 | ehthumbs.db 343 | ehthumbs_vista.db 344 | 345 | # Dump file 346 | *.stackdump 347 | 348 | # Folder config file 349 | [Dd]esktop.ini 350 | 351 | # Recycle Bin used on file shares 352 | $RECYCLE.BIN/ 353 | 354 | # Windows Installer files 355 | *.cab 356 | *.msi 357 | *.msix 358 | *.msm 359 | *.msp 360 | 361 | # Windows shortcuts 362 | *.lnk 363 | 364 | # End of https://www.toptal.com/developers/gitignore/api/intellij,python,windows,macos,linux,jupyternotebooks 365 | 366 | /output 367 | /lightning_logs 368 | /ckpt 369 | *.ckpt 370 | *.pth -------------------------------------------------------------------------------- /model/batfd.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union, Sequence, Tuple 2 | 3 | import torch 4 | from pytorch_lightning import LightningModule 5 | from torch import Tensor 6 | from torch.nn import BCEWithLogitsLoss, MSELoss 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | 10 | from dataset.lavdf import Metadata 11 | from loss import MaskedFrameLoss, MaskedBMLoss, MaskedContrastLoss 12 | from .audio_encoder import get_audio_encoder 13 | from .boundary_module import BoundaryModule 14 | from .frame_classifier import FrameLogisticRegression 15 | from .fusion_module import ModalFeatureAttnBoundaryMapFusion 16 | from .video_encoder import get_video_encoder 17 | 18 | 19 | class Batfd(LightningModule): 20 | 21 | def __init__(self, 22 | v_encoder: str = "c3d", a_encoder: str = "cnn", frame_classifier: str = "lr", 23 | ve_features=(64, 96, 128, 128), ae_features=(32, 64, 64), v_cla_feature_in=256, a_cla_feature_in=256, 24 | boundary_features=(512, 128), boundary_samples=10, temporal_dim=512, max_duration=40, 25 | weight_frame_loss=2., weight_modal_bm_loss=1., weight_contrastive_loss=0.1, contrast_loss_margin=0.99, 26 | weight_decay=0.0001, learning_rate=0.0002, distributed=False 27 | ): 28 | super().__init__() 29 | self.save_hyperparameters() 30 | self.cla_feature_in = v_cla_feature_in 31 | self.temporal_dim = temporal_dim 32 | 33 | self.video_encoder = get_video_encoder(v_cla_feature_in, temporal_dim, v_encoder, ve_features) 34 | self.audio_encoder = get_audio_encoder(a_cla_feature_in, temporal_dim, a_encoder, 
ae_features) 35 | 36 | if frame_classifier == "lr": 37 | self.video_frame_classifier = FrameLogisticRegression(n_features=v_cla_feature_in) 38 | self.audio_frame_classifier = FrameLogisticRegression(n_features=a_cla_feature_in) 39 | 40 | assert self.video_encoder and self.audio_encoder and self.video_frame_classifier and self.audio_frame_classifier 41 | 42 | assert v_cla_feature_in == a_cla_feature_in 43 | 44 | v_bm_in = v_cla_feature_in + 1 45 | a_bm_in = a_cla_feature_in + 1 46 | 47 | self.video_boundary_module = BoundaryModule(v_bm_in, boundary_features, boundary_samples, temporal_dim, 48 | max_duration 49 | ) 50 | self.audio_boundary_module = BoundaryModule(a_bm_in, boundary_features, boundary_samples, temporal_dim, 51 | max_duration 52 | ) 53 | 54 | self.fusion = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 55 | 56 | self.frame_loss = MaskedFrameLoss(BCEWithLogitsLoss()) 57 | self.contrast_loss = MaskedContrastLoss(margin=contrast_loss_margin) 58 | self.bm_loss = MaskedBMLoss(MSELoss()) 59 | self.weight_frame_loss = weight_frame_loss 60 | self.weight_modal_bm_loss = weight_modal_bm_loss 61 | self.weight_contrastive_loss = weight_contrastive_loss / (v_cla_feature_in * temporal_dim) 62 | self.weight_decay = weight_decay 63 | self.learning_rate = learning_rate 64 | self.distributed = distributed 65 | 66 | def forward(self, video: Tensor, audio: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: 67 | # encoders 68 | v_features = self.video_encoder(video) 69 | a_features = self.audio_encoder(audio) 70 | 71 | # frame classifiers 72 | v_frame_cla = self.video_frame_classifier(v_features) 73 | a_frame_cla = self.audio_frame_classifier(a_features) 74 | 75 | # concat classification result to features 76 | v_bm_in = torch.column_stack([v_features, v_frame_cla]) 77 | a_bm_in = torch.column_stack([a_features, a_frame_cla]) 78 | 79 | # modal boundary module 80 | v_bm_map = self.video_boundary_module(v_bm_in) 81 | a_bm_map = self.audio_boundary_module(a_bm_in) 82 | 83 | # boundary map modal attention fusion 84 | fusion_bm_map = self.fusion(v_bm_in, a_bm_in, v_bm_map, a_bm_map) 85 | 86 | return fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features 87 | 88 | def loss_fn(self, fusion_bm_map: Tensor, v_bm_map: Tensor, a_bm_map: Tensor, 89 | v_frame_cla: Tensor, a_frame_cla: Tensor, label: Tensor, n_frames: Tensor, 90 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features 91 | ) -> Dict[str, Tensor]: 92 | fusion_bm_loss = self.bm_loss(fusion_bm_map, label, n_frames) 93 | 94 | v_bm_loss = self.bm_loss(v_bm_map, v_bm_label, n_frames) 95 | a_bm_loss = self.bm_loss(a_bm_map, a_bm_label, n_frames) 96 | 97 | v_frame_loss = self.frame_loss(v_frame_cla.squeeze(1), v_frame_label, n_frames) 98 | a_frame_loss = self.frame_loss(a_frame_cla.squeeze(1), a_frame_label, n_frames) 99 | 100 | contrast_loss = torch.clip(self.contrast_loss(v_features, a_features, contrast_label, n_frames) 101 | / (self.cla_feature_in * self.temporal_dim), max=1.) 
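        # Total objective: the fused boundary-map loss plus the weighted averages of the
        # two per-modal boundary losses and the two frame-classification losses, plus the
        # weighted (normalised and clipped) audio-visual contrastive term.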
102 | 103 | loss = fusion_bm_loss + \ 104 | self.weight_modal_bm_loss * (a_bm_loss + v_bm_loss) / 2 + \ 105 | self.weight_frame_loss * (a_frame_loss + v_frame_loss) / 2 + \ 106 | self.weight_contrastive_loss * contrast_loss 107 | 108 | return { 109 | "loss": loss, "fusion_bm_loss": fusion_bm_loss, "v_bm_loss": v_bm_loss, "a_bm_loss": a_bm_loss, 110 | "v_frame_loss": v_frame_loss, "a_frame_loss": a_frame_loss, "contrast_loss": contrast_loss 111 | } 112 | 113 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 114 | optimizer_idx: Optional[int] = None, hiddens: Optional[Tensor] = None 115 | ) -> Tensor: 116 | video, audio, label, n_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label = batch 117 | 118 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio) 119 | loss_dict = self.loss_fn(fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, label, n_frames, 120 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features 121 | ) 122 | 123 | self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 124 | prog_bar=False, sync_dist=self.distributed) 125 | return loss_dict["loss"] 126 | 127 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 128 | dataloader_idx: Optional[int] = None 129 | ) -> Tensor: 130 | video, audio, label, n_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label = batch 131 | 132 | fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, v_features, a_features = self(video, audio) 133 | loss_dict = self.loss_fn(fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla, label, n_frames, 134 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, v_features, a_features 135 | ) 136 | 137 | self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 138 | prog_bar=False, sync_dist=self.distributed) 139 | return loss_dict["loss"] 140 | 141 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None 142 | ) -> Tuple[Tensor, Tensor, Tensor]: 143 | video, audio, *_ = batch 144 | fusion_bm_map, v_bm_map, a_bm_map, *_ = self(video, audio) 145 | return fusion_bm_map, v_bm_map, a_bm_map 146 | 147 | def configure_optimizers(self): 148 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9), weight_decay=self.weight_decay) 149 | return { 150 | "optimizer": optimizer, 151 | "lr_scheduler": { 152 | "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=3, verbose=True, min_lr=1e-8), 153 | "monitor": "val_loss" 154 | } 155 | } 156 | 157 | @staticmethod 158 | def get_meta_attr(meta: Metadata, video: Tensor, audio: Tensor, label: Tensor): 159 | label_fake = label 160 | label_real = torch.zeros(label.size(), dtype=label.dtype, device=label.device) 161 | 162 | v_bm_label = label_fake if meta.modify_video else label_real 163 | a_bm_label = label_fake if meta.modify_audio else label_real 164 | 165 | frame_label_real = torch.zeros(512) 166 | frame_label_fake = torch.zeros(512) 167 | for begin, end in meta.fake_periods: 168 | begin = int(begin * 25) 169 | end = int(end * 25) 170 | frame_label_fake[begin: end] = 1 171 | 172 | v_frame_label = frame_label_fake if meta.modify_video else frame_label_real 173 | a_frame_label = frame_label_fake if meta.modify_audio else frame_label_real 174 | 175 
| contrast_label = 0 if meta.modify_audio or meta.modify_video else 1 176 | 177 | return [meta.video_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label] 178 | -------------------------------------------------------------------------------- /model/boundary_module_plus.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops.layers.torch import Rearrange 7 | from torch import Tensor 8 | from torch.nn import Sequential, LeakyReLU 9 | 10 | from model.boundary_module import BoundaryModule 11 | from utils import Conv2d 12 | 13 | 14 | class ConvUnit(nn.Module): 15 | """ 16 | Unit in NestedUNet 17 | """ 18 | 19 | def __init__(self, in_ch, out_ch, is_output=False): 20 | super(ConvUnit, self).__init__() 21 | module_list = [nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)] 22 | if is_output is False: 23 | module_list.append(nn.BatchNorm1d(out_ch)) 24 | module_list.append(nn.ReLU(inplace=True)) 25 | self.conv = nn.Sequential(*module_list) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | return x 30 | 31 | 32 | class NestedUNet(nn.Module): 33 | """ 34 | UNet - Basic Implementation 35 | Paper : https://arxiv.org/abs/1505.04597 36 | """ 37 | def __init__(self, in_ch=400, out_ch=2): 38 | super(NestedUNet, self).__init__() 39 | 40 | self.pool = nn.MaxPool1d(kernel_size=2, stride=2) 41 | self.up = nn.Upsample(scale_factor=2) 42 | 43 | n1 = 512 44 | filters = [n1, n1 * 2, n1 * 3] 45 | self.conv0_0 = ConvUnit(in_ch, filters[0], is_output=False) 46 | self.conv1_0 = ConvUnit(filters[0], filters[0], is_output=False) 47 | self.conv2_0 = ConvUnit(filters[0], filters[0], is_output=False) 48 | 49 | self.conv0_1 = ConvUnit(filters[1], filters[0], is_output=False) 50 | self.conv1_1 = ConvUnit(filters[1], filters[0], is_output=False) 51 | 52 | self.conv0_2 = ConvUnit(filters[2], filters[0], is_output=False) 53 | 54 | self.final = nn.Conv1d(filters[0] * 3, out_ch, kernel_size=1) 55 | # self.final = ConvUnit(filters[0] * 3, out_ch, is_output=True) 56 | self.out = nn.Sigmoid() 57 | 58 | def forward(self, x): 59 | x0_0 = self.conv0_0(x) 60 | x1_0 = self.conv1_0(self.pool(x0_0)) 61 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 62 | x2_0 = self.conv2_0(self.pool(x1_0)) 63 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 64 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 65 | out_feature = torch.cat([x0_0, x0_1, x0_2], 1) # for calculating loss 66 | final_feature = self.final(out_feature) 67 | out = self.out(final_feature) 68 | 69 | return out, out_feature 70 | 71 | 72 | class PositionAwareAttentionModule(nn.Module): 73 | def __init__(self, in_channels, inter_channels=None, sub_sample=None, dim=2): 74 | super(PositionAwareAttentionModule, self).__init__() 75 | 76 | self.sub_sample = sub_sample 77 | self.in_channels = in_channels 78 | self.inter_channels = inter_channels 79 | self.dim = dim 80 | 81 | if self.inter_channels is None: 82 | self.inter_channels = in_channels // 2 83 | if self.inter_channels == 0: 84 | self.inter_channels = 1 85 | 86 | if self.dim == 2: 87 | conv_nd = nn.Conv2d 88 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 89 | bn = nn.BatchNorm2d 90 | else: 91 | conv_nd = nn.Conv1d 92 | max_pool_layer = nn.MaxPool1d(kernel_size=(2,)) 93 | bn = nn.BatchNorm1d 94 | 95 | self.g = nn.Sequential( 96 | conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 97 | bn(self.inter_channels), 98 | nn.ReLU(inplace=True) 99 | ) 100 | self.theta = nn.Sequential( 101 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 102 | bn(self.inter_channels), 103 | nn.ReLU(inplace=True) 104 | ) 105 | self.phi = nn.Sequential( 106 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 107 | bn(self.inter_channels), 108 | nn.ReLU(inplace=True) 109 | ) 110 | self.W = nn.Sequential( 111 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 112 | kernel_size=1, stride=1, padding=0), 113 | bn(self.in_channels) 114 | ) 115 | if self.sub_sample: 116 | self.g = nn.Sequential(self.g, max_pool_layer) 117 | self.phi = nn.Sequential(self.phi, max_pool_layer) 118 | 119 | def forward(self, x): 120 | batch_size = x.size(0) 121 | # value 122 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 123 | g_x = g_x.permute(0, 2, 1) 124 | 125 | # query 126 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 127 | theta_x = theta_x.permute(0, 2, 1) 128 | 129 | # key 130 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 131 | 132 | f = torch.matmul(theta_x, phi_x) 133 | f = F.softmax(f, dim=2) 134 | 135 | y = torch.matmul(f, g_x) 136 | y = y.permute(0, 2, 1).contiguous() 137 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 138 | y = self.W(y) 139 | 140 | z = y + x 141 | return z 142 | 143 | 144 | class ChannelAwareAttentionModule(nn.Module): 145 | def __init__(self, in_channels, inter_channels=None, dim=2): 146 | super(ChannelAwareAttentionModule, self).__init__() 147 | 148 | self.in_channels = in_channels 149 | self.inter_channels = inter_channels 150 | self.dim = dim 151 | 152 | if self.inter_channels is None: 153 | self.inter_channels = in_channels // 2 154 | if self.inter_channels == 0: 155 | self.inter_channels = 1 156 | 157 | if self.dim == 2: 158 | conv_nd = nn.Conv2d 159 | bn = nn.BatchNorm2d 160 | else: 161 | conv_nd = nn.Conv1d 162 | bn = nn.BatchNorm1d 163 | 164 | self.g = nn.Sequential( 165 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 166 | bn(self.inter_channels), 167 | nn.ReLU(inplace=True) 168 | ) 169 | self.theta = nn.Sequential( 170 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 171 | bn(self.inter_channels), 172 | nn.ReLU(inplace=True) 173 | ) 174 | self.phi = nn.Sequential( 175 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0), 176 | bn(self.inter_channels), 177 | nn.ReLU(inplace=True) 178 | ) 179 | self.W = nn.Sequential( 180 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 181 | kernel_size=1, stride=1, padding=0), 182 | bn(self.in_channels) 183 | ) 184 | 185 | def forward(self, x): 186 | batch_size = x.size(0) 187 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 188 | 189 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 190 | 191 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 192 | phi_x = phi_x.permute(0, 2, 1) 193 | 194 | f = torch.matmul(theta_x, phi_x) 195 | f = F.softmax(f, dim=2) 196 | 197 | y = torch.matmul(f, g_x) 198 | y = y.permute(0, 2, 1).contiguous() 199 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 200 | y = self.W(y) 201 | 202 | z 
= y + x 203 | return z 204 | 205 | 206 | def conv_block(in_ch, out_ch, kernel_size=3, stride=1, bn_layer=False, activate=False): 207 | module_list = [nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=1)] 208 | if bn_layer: 209 | module_list.append(nn.BatchNorm2d(out_ch)) 210 | module_list.append(nn.ReLU(inplace=True)) 211 | if activate: 212 | module_list.append(nn.Sigmoid()) 213 | conv = nn.Sequential(*module_list) 214 | return conv 215 | 216 | 217 | class ProposalRelationBlock(nn.Module): 218 | def __init__(self, in_channels, inter_channels=128, out_channels=2, sub_sample=False): 219 | super(ProposalRelationBlock, self).__init__() 220 | self.p_net = PositionAwareAttentionModule(in_channels, inter_channels=inter_channels, sub_sample=sub_sample, dim=2) 221 | self.c_net = ChannelAwareAttentionModule(in_channels, inter_channels=inter_channels, dim=2) 222 | self.conv0_0 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 223 | self.conv0_1 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 224 | 225 | self.conv1 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 226 | self.conv2 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True) 227 | self.conv3 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True) 228 | self.conv4 = conv_block(in_channels, in_channels, 3, 1, bn_layer=True, activate=False) 229 | self.conv5 = conv_block(in_channels, out_channels, 3, 1, bn_layer=False, activate=True) 230 | 231 | def forward(self, x): 232 | x_p = self.conv0_0(x) 233 | x_c = self.conv0_1(x) 234 | 235 | x_p = self.p_net(x_p) 236 | x_c = self.c_net(x_c) 237 | 238 | x_p_0 = self.conv1(x_p) 239 | x_p_1 = self.conv2(x_p_0) 240 | 241 | x_c_0 = self.conv4(x_c) 242 | x_c_1 = self.conv5(x_c_0) 243 | 244 | x_p_c = self.conv3(x_p_0 + x_c_0) 245 | return x_p_1, x_c_1, x_p_c 246 | 247 | 248 | class BoundaryModulePlus(BoundaryModule): 249 | def __init__(self, n_feature_in, n_features=(512, 128), num_samples: int = 10, temporal_dim: int = 512, 250 | max_duration: int = 40 251 | ): 252 | super().__init__(n_feature_in, n_features, num_samples, temporal_dim, max_duration) 253 | del self.block1 254 | dim0, dim1 = n_features 255 | # (B, dim0, max_duration, temporal_dim) -> (B, max_duration, temporal_dim) 256 | self.block1 = Sequential( 257 | Conv2d(dim0, dim1, kernel_size=1, build_activation=LeakyReLU), 258 | Conv2d(dim1, dim1, kernel_size=3, padding=1, build_activation=LeakyReLU) 259 | ) 260 | # Proposal Relation Block in BSN++ mechanism 261 | self.proposal_block = ProposalRelationBlock(dim1, dim1, 1, sub_sample=True) 262 | self.out = Rearrange("b c d t -> b (c d) t") 263 | 264 | def forward(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 265 | confidence_map = self.bm_layer(feature) 266 | confidence_map = self.block0(confidence_map) 267 | confidence_map = self.block1(confidence_map) 268 | confidence_map_p, confidence_map_c, confidence_map_p_c = self.proposal_block(confidence_map) 269 | 270 | confidence_map_p = self.out(confidence_map_p) 271 | confidence_map_c = self.out(confidence_map_c) 272 | confidence_map_p_c = self.out(confidence_map_p_c) 273 | return confidence_map_p, confidence_map_c, confidence_map_p_c 274 | -------------------------------------------------------------------------------- /dataset/lavdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import 
Optional, List, Callable, Any, Union, Tuple 5 | 6 | import einops 7 | import numpy as np 8 | import scipy as sp 9 | import torch 10 | import torchaudio 11 | from einops import rearrange 12 | from pytorch_lightning import LightningDataModule 13 | from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS 14 | from torch import Tensor 15 | from torch.nn import functional as F, Identity 16 | from torch.utils.data import DataLoader, RandomSampler 17 | from torch.utils.data import Dataset 18 | 19 | from utils import read_json, read_video, padding_video, padding_audio, resize_video, iou_with_anchors, ioa_with_anchors 20 | 21 | 22 | @dataclass 23 | class Metadata: 24 | file: str 25 | n_fakes: int 26 | fake_periods: List[List[int]] 27 | duration: float 28 | original: Optional[str] 29 | modify_video: bool 30 | modify_audio: bool 31 | split: str 32 | video_frames: int 33 | audio_channels: int 34 | audio_frames: int 35 | 36 | 37 | T_LABEL = Union[Tensor, Tuple[Tensor, Tensor, Tensor]] 38 | 39 | 40 | class Lavdf(Dataset): 41 | 42 | def __init__(self, subset: str, root: str = "data", frame_padding: int = 512, 43 | max_duration: int = 40, fps: int = 25, 44 | video_transform: Callable[[Tensor], Tensor] = Identity(), 45 | audio_transform: Callable[[Tensor], Tensor] = Identity(), 46 | metadata: Optional[List[Metadata]] = None, 47 | get_meta_attr: Callable[[Metadata, Tensor, Tensor, T_LABEL], List[Any]] = None, 48 | require_match_scores: bool = False, 49 | return_file_name: bool = False 50 | ): 51 | self.subset = subset 52 | self.root = root 53 | self.video_padding = frame_padding 54 | self.audio_padding = int(frame_padding / fps * 16000) 55 | self.max_duration = max_duration 56 | self.video_transform = video_transform 57 | self.audio_transform = audio_transform 58 | self.get_meta_attr = get_meta_attr 59 | self.require_match_scores = require_match_scores 60 | self.return_file_name = return_file_name 61 | 62 | label_dir = os.path.join(self.root, "label") 63 | if not os.path.exists(label_dir): 64 | os.mkdir(label_dir) 65 | 66 | if metadata is None: 67 | metadata: List[Metadata] = read_json(os.path.join(self.root, "metadata.min.json"), lambda x: Metadata(**x)) 68 | self.metadata: List[Metadata] = [each for each in metadata if each.split == subset] 69 | 70 | else: 71 | self.metadata: List[Metadata] = metadata 72 | 73 | if self.require_match_scores: 74 | temporal_gap = 1 / self.max_duration 75 | # [-0.05, ..., 0.985] 76 | self.anchor_x_min = [temporal_gap * (i - 0.5) for i in range(self.video_padding)] 77 | # [0.05, ..., 0.995] 78 | self.anchor_x_max = [temporal_gap * (i + 0.5) for i in range(self.video_padding)] 79 | else: 80 | self.anchor_x_min = None 81 | self.anchor_x_max = None 82 | print(f"Load {len(self.metadata)} data in {subset}.") 83 | 84 | def __getitem__(self, index: int) -> List[Tensor]: 85 | meta = self.metadata[index] 86 | video, audio, _ = read_video(os.path.join(self.root, meta.file)) 87 | video = padding_video(video, target=self.video_padding) 88 | audio = padding_audio(audio, target=self.audio_padding) 89 | 90 | video = self.video_transform(video) 91 | audio = self.audio_transform(audio) 92 | 93 | video = rearrange(resize_video(video, (96, 96)), "t c h w -> c t h w") 94 | audio = self._get_log_mel_spectrogram(audio) 95 | 96 | if not self.require_match_scores: 97 | label = self.get_label(meta) 98 | outputs = [video, audio, label] + self.get_meta_attr(meta, video, audio, label) 99 | else: 100 | label = self.get_label_with_match_scores(meta) 101 | outputs = [video, audio, 
*label] + self.get_meta_attr(meta, video, audio, label) 102 | 103 | if self.return_file_name: 104 | outputs.append(meta.file) 105 | 106 | return outputs 107 | 108 | def get_label(self, meta: Metadata) -> Tensor: 109 | file_name = meta.file.split("/")[-1].split(".")[0] + ".npy" 110 | path = os.path.join(self.root, "label", file_name) 111 | if os.path.exists(path): 112 | try: 113 | arr = np.load(path) 114 | except ValueError: 115 | pass 116 | else: 117 | return torch.tensor(arr) 118 | 119 | label = self._get_train_label(meta.video_frames, meta.fake_periods, meta.video_frames).numpy() 120 | # cache label 121 | np.save(path, label) 122 | return torch.tensor(label) 123 | 124 | def get_label_with_match_scores(self, meta: Metadata) -> Tuple[Tensor, Tensor, Tensor]: 125 | Path(os.path.join(self.root, "label")).mkdir(parents=True, exist_ok=True) 126 | Path(os.path.join(self.root, "match_scores")).mkdir(parents=True, exist_ok=True) 127 | 128 | boundary_map_file_name = meta.file.split("/")[-1].split(".")[0] + ".npy" 129 | boundary_map_file_path = os.path.join(self.root, "label", boundary_map_file_name) 130 | 131 | match_scores_file_name = meta.file.split("/")[-1].split(".")[0] + ".npz" 132 | match_scores_file_path = os.path.join(self.root, "match_scores", match_scores_file_name) 133 | 134 | if os.path.exists(boundary_map_file_path) and os.path.exists(match_scores_file_path): 135 | try: 136 | boundary_map = np.load(boundary_map_file_path) 137 | match_scores = np.load(match_scores_file_path) 138 | except ValueError: 139 | pass 140 | else: 141 | return ( 142 | torch.tensor(boundary_map), 143 | torch.tensor(match_scores["match_score_start"]), 144 | torch.tensor(match_scores["match_score_end"]) 145 | ) 146 | 147 | boundary_map, match_score_start, match_score_end = self._get_train_label( 148 | meta.video_frames, meta.fake_periods, meta.video_frames 149 | ) 150 | 151 | # cache label 152 | np.save(boundary_map_file_path, boundary_map.numpy()) 153 | np.savez( 154 | match_scores_file_path, 155 | match_score_start=match_score_start.numpy(), 156 | match_score_end=match_score_end.numpy() 157 | ) 158 | 159 | return boundary_map, match_score_start, match_score_end 160 | 161 | def gen_label(self) -> None: 162 | # manually pre-generate label as npy 163 | for meta in self.metadata: 164 | self.get_label(meta) 165 | 166 | def __len__(self) -> int: 167 | return len(self.metadata) 168 | 169 | @staticmethod 170 | def _get_log_mel_spectrogram(audio: Tensor) -> Tensor: 171 | ms = torchaudio.transforms.MelSpectrogram(n_fft=321, n_mels=64) 172 | spec = torch.log(ms(audio[:, 0]) + 0.01) 173 | assert spec.shape == (64, 2048), "Wrong log mel-spectrogram setup in Dataset" 174 | return spec 175 | 176 | def _get_train_label(self, frames, video_labels, temporal_scale, fps=25) -> T_LABEL: 177 | corrected_second = frames / fps 178 | temporal_gap = 1 / temporal_scale 179 | 180 | ############################################################################################## 181 | # change the measurement from second to percentage 182 | gt_bbox = [] 183 | for j in range(len(video_labels)): 184 | tmp_start = max(min(1, video_labels[j][0] / corrected_second), 0) 185 | tmp_end = max(min(1, video_labels[j][1] / corrected_second), 0) 186 | gt_bbox.append([tmp_start, tmp_end]) 187 | 188 | #################################################################################################### 189 | # generate R_s and R_e 190 | gt_bbox = torch.tensor(gt_bbox) 191 | if len(gt_bbox) > 0: 192 | gt_xmins = gt_bbox[:, 0] 193 | gt_xmaxs = gt_bbox[:, 1] 194 
| else: 195 | gt_xmins = np.array([]) 196 | gt_xmaxs = np.array([]) 197 | ##################################################################################################### 198 | 199 | gt_iou_map = torch.zeros([self.max_duration, temporal_scale]) 200 | if len(gt_bbox) > 0: 201 | for begin in range(temporal_scale): 202 | for duration in range(self.max_duration): 203 | end = begin + duration 204 | if end > temporal_scale: 205 | break 206 | gt_iou_map[duration, begin] = torch.max( 207 | iou_with_anchors(begin * temporal_gap, (end + 1) * temporal_gap, gt_xmins, gt_xmaxs)) 208 | # [i, j]: Start in i, end in j. 209 | 210 | ########################################################################################################## 211 | gt_iou_map = F.pad(gt_iou_map.float(), pad=[0, self.video_padding - frames, 0, 0]) 212 | 213 | if not self.require_match_scores: 214 | return gt_iou_map 215 | 216 | gt_len_small = 3 * temporal_gap 217 | gt_start_bboxs = np.stack((gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1) 218 | gt_end_bboxs = np.stack((gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1) 219 | 220 | ########################################################################################################## 221 | # calculate the ioa for all timestamp 222 | if len(gt_start_bboxs) > 0: 223 | match_score_start = [] 224 | for jdx in range(len(self.anchor_x_min)): 225 | match_score_start.append(np.max(ioa_with_anchors(self.anchor_x_min[jdx], self.anchor_x_max[jdx], 226 | gt_start_bboxs[:, 0], gt_start_bboxs[:, 1]))) 227 | 228 | match_score_end = [] 229 | for jdx in range(len(self.anchor_x_min)): 230 | match_score_end.append(np.max(ioa_with_anchors(self.anchor_x_min[jdx], self.anchor_x_max[jdx], 231 | gt_end_bboxs[:, 0], gt_end_bboxs[:, 1]))) 232 | match_score_start = torch.Tensor(match_score_start) 233 | match_score_end = torch.Tensor(match_score_end) 234 | else: 235 | match_score_start = torch.zeros(len(self.anchor_x_min)) 236 | match_score_end = torch.zeros(len(self.anchor_x_min)) 237 | ############################################################################################################ 238 | return gt_iou_map, match_score_start, match_score_end 239 | 240 | 241 | def _default_get_meta_attr(meta: Metadata, video: Tensor, audio: Tensor, label: Tensor) -> List[Any]: 242 | return [meta.video_frames] 243 | 244 | 245 | class LavdfDataModule(LightningDataModule): 246 | train_dataset: Lavdf 247 | dev_dataset: Lavdf 248 | test_dataset: Lavdf 249 | metadata: List[Metadata] 250 | 251 | def __init__(self, root: str = "data", frame_padding=512, max_duration=40, 252 | require_match_scores: bool = False, feature_types: Tuple[Optional[str], Optional[str]] = (None, None), 253 | batch_size: int = 1, num_workers: int = 0, 254 | take_train: int = None, take_dev: int = None, take_test: int = None, 255 | cond: Optional[Callable[[Metadata], bool]] = None, 256 | get_meta_attr: Callable[[Metadata, Tensor, Tensor, Tensor], List[Any]] = _default_get_meta_attr, 257 | return_file_name: bool = False 258 | ): 259 | super().__init__() 260 | self.root = root 261 | self.frame_padding = frame_padding 262 | self.max_duration = max_duration 263 | self.require_match_scores = require_match_scores 264 | self.batch_size = batch_size 265 | self.num_workers = num_workers 266 | self.take_train = take_train 267 | self.take_dev = take_dev 268 | self.take_test = take_test 269 | self.cond = cond 270 | self.get_meta_attr = get_meta_attr 271 | self.return_file_name = return_file_name 272 | self.Dataset = 
feature_type_to_dataset_type[feature_types] 273 | 274 | def setup(self, stage: Optional[str] = None) -> None: 275 | self.metadata: List[Metadata] = read_json(os.path.join(self.root, "metadata.min.json"), lambda x: Metadata(**x)) 276 | 277 | train_metadata = [] 278 | dev_metadata = [] 279 | test_metadata = [] 280 | 281 | for meta in self.metadata: 282 | if self.cond is None or self.cond(meta): 283 | if meta.split == "train": 284 | train_metadata.append(meta) 285 | elif meta.split == "dev": 286 | dev_metadata.append(meta) 287 | elif meta.split == "test": 288 | test_metadata.append(meta) 289 | 290 | if self.take_dev is not None: 291 | dev_metadata = dev_metadata[:self.take_dev] 292 | 293 | self.train_dataset = self.Dataset("train", self.root, self.frame_padding, self.max_duration, 294 | metadata=train_metadata, get_meta_attr=self.get_meta_attr, 295 | require_match_scores=self.require_match_scores, 296 | return_file_name=self.return_file_name 297 | ) 298 | self.dev_dataset = self.Dataset("dev", self.root, self.frame_padding, self.max_duration, 299 | metadata=dev_metadata, get_meta_attr=self.get_meta_attr, 300 | require_match_scores=self.require_match_scores, 301 | return_file_name=self.return_file_name 302 | ) 303 | self.test_dataset = self.Dataset("test", self.root, self.frame_padding, self.max_duration, 304 | metadata=test_metadata, get_meta_attr=self.get_meta_attr, 305 | require_match_scores=self.require_match_scores, 306 | return_file_name=self.return_file_name 307 | ) 308 | 309 | def train_dataloader(self) -> TRAIN_DATALOADERS: 310 | return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, 311 | sampler=RandomSampler(self.train_dataset, num_samples=self.take_train, replacement=True) 312 | ) 313 | 314 | def val_dataloader(self) -> EVAL_DATALOADERS: 315 | return DataLoader(self.dev_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 316 | 317 | def test_dataloader(self) -> EVAL_DATALOADERS: 318 | return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 319 | 320 | 321 | # The dictionary is used to map the feature type to the dataset type 322 | # The key is a tuple of (visual_feature_type, audio_feature_type), ``None`` means using end-to-end encoder. 
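# At present only the end-to-end pair (None, None) -> Lavdf is registered; dataset classes
# for pre-extracted features would be added as further entries in this mapping.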
323 | feature_type_to_dataset_type = { 324 | (None, None): Lavdf 325 | } 326 | -------------------------------------------------------------------------------- /model/batfd_plus.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union, Sequence, Tuple 2 | 3 | import torch 4 | from pytorch_lightning import LightningModule 5 | from torch import Tensor 6 | from torch.nn import BCEWithLogitsLoss 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import ExponentialLR 9 | 10 | from dataset.lavdf import Metadata 11 | from loss import MaskedFrameLoss, MaskedContrastLoss, MaskedBsnppLoss 12 | from .audio_encoder import get_audio_encoder 13 | from .boundary_module_plus import BoundaryModulePlus, NestedUNet 14 | from .frame_classifier import FrameLogisticRegression 15 | from .fusion_module import ModalFeatureAttnBoundaryMapFusion, ModalFeatureAttnCfgFusion 16 | from .video_encoder import get_video_encoder 17 | 18 | 19 | class BatfdPlus(LightningModule): 20 | 21 | def __init__(self, 22 | v_encoder: str = "c3d", a_encoder: str = "cnn", frame_classifier: str = "lr", 23 | ve_features=(64, 96, 128, 128), ae_features=(32, 64, 64), v_cla_feature_in=256, a_cla_feature_in=256, 24 | boundary_features=(512, 128), boundary_samples=10, temporal_dim=512, max_duration=40, 25 | weight_frame_loss=2., weight_modal_bm_loss=1., weight_contrastive_loss=0.1, contrast_loss_margin=0.99, 26 | cbg_feature_weight=0.01, prb_weight_forward=1., 27 | weight_decay=0.0001, learning_rate=0.0002, distributed=False 28 | ): 29 | super().__init__() 30 | self.save_hyperparameters() 31 | 32 | self.cla_feature_in = v_cla_feature_in 33 | self.temporal_dim = temporal_dim 34 | 35 | self.video_encoder = get_video_encoder(v_cla_feature_in, temporal_dim, v_encoder, ve_features) 36 | self.audio_encoder = get_audio_encoder(a_cla_feature_in, temporal_dim, a_encoder, ae_features) 37 | 38 | if frame_classifier == "lr": 39 | self.video_frame_classifier = FrameLogisticRegression(n_features=v_cla_feature_in) 40 | self.audio_frame_classifier = FrameLogisticRegression(n_features=a_cla_feature_in) 41 | 42 | assert self.video_encoder and self.audio_encoder and self.video_frame_classifier and self.audio_frame_classifier 43 | 44 | assert v_cla_feature_in == a_cla_feature_in 45 | 46 | v_bm_in = v_cla_feature_in + 1 47 | a_bm_in = a_cla_feature_in + 1 48 | 49 | # Complementary Boundary Generator in BSN++ mechanism 50 | self.video_comp_boundary_generator = NestedUNet(in_ch=v_bm_in, out_ch=2) 51 | self.audio_comp_boundary_generator = NestedUNet(in_ch=a_bm_in, out_ch=2) 52 | 53 | # Proposal Relation Block in BSN++ mechanism 54 | self.video_boundary_module = BoundaryModulePlus(v_bm_in, boundary_features, boundary_samples, temporal_dim, 55 | max_duration 56 | ) 57 | self.audio_boundary_module = BoundaryModulePlus(a_bm_in, boundary_features, boundary_samples, temporal_dim, 58 | max_duration 59 | ) 60 | 61 | if cbg_feature_weight > 0: 62 | self.cbg_fusion_start = ModalFeatureAttnCfgFusion(v_bm_in, a_bm_in) 63 | self.cbg_fusion_end = ModalFeatureAttnCfgFusion(v_bm_in, a_bm_in) 64 | else: 65 | self.cbg_fusion_start = None 66 | self.cbg_fusion_end = None 67 | 68 | self.prb_fusion_p = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 69 | self.prb_fusion_c = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 70 | self.prb_fusion_p_c = ModalFeatureAttnBoundaryMapFusion(v_bm_in, a_bm_in, max_duration) 71 | 72 | self.frame_loss = 
MaskedFrameLoss(BCEWithLogitsLoss()) 73 | self.contrast_loss = MaskedContrastLoss(margin=contrast_loss_margin) 74 | self.bm_loss = MaskedBsnppLoss(cbg_feature_weight, prb_weight_forward) 75 | self.weight_frame_loss = weight_frame_loss 76 | self.weight_modal_bm_loss = weight_modal_bm_loss 77 | self.weight_contrastive_loss = weight_contrastive_loss / (v_cla_feature_in * temporal_dim) 78 | self.weight_decay = weight_decay 79 | self.learning_rate = learning_rate 80 | self.distributed = distributed 81 | 82 | def forward(self, video: Tensor, audio: Tensor) -> Sequence[Tensor]: 83 | a_bm_in, a_features, a_frame_cla, v_bm_in, v_features, v_frame_cla = self.forward_features(audio, video) 84 | 85 | # modal boundary module 86 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c = self.video_boundary_module(v_bm_in) 87 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c = self.audio_boundary_module(a_bm_in) 88 | 89 | # complementary boundary generator 90 | if self.cbg_fusion_start is not None: 91 | v_cbg_feature, v_cbg_start, v_cbg_end = self.forward_video_cbg(v_bm_in) 92 | a_cbg_feature, a_cbg_start, a_cbg_end = self.forward_audio_cbg(a_bm_in) 93 | else: 94 | v_cbg_feature, v_cbg_start, v_cbg_end = None, None, None 95 | a_cbg_feature, a_cbg_start, a_cbg_end = None, None, None 96 | 97 | # boundary map modal attention fusion 98 | fusion_bm_map_p = self.prb_fusion_p(v_bm_in, a_bm_in, v_bm_map_p, a_bm_map_p) 99 | fusion_bm_map_c = self.prb_fusion_c(v_bm_in, a_bm_in, v_bm_map_c, a_bm_map_c) 100 | fusion_bm_map_p_c = self.prb_fusion_p_c(v_bm_in, a_bm_in, v_bm_map_p_c, a_bm_map_p_c) 101 | 102 | # complementary boundary generator modal attention fusion 103 | if self.cbg_fusion_start is not None: 104 | fusion_cbg_start = self.cbg_fusion_start(v_bm_in, a_bm_in, v_cbg_start, a_cbg_start) 105 | fusion_cbg_end = self.cbg_fusion_end(v_bm_in, a_bm_in, v_cbg_end, a_cbg_end) 106 | else: 107 | fusion_cbg_start = None 108 | fusion_cbg_end = None 109 | 110 | return ( 111 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end, 112 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end, 113 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end, 114 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature 115 | ) 116 | 117 | def forward_back(self, video: Tensor, audio: Tensor) -> Sequence[Optional[Tensor]]: 118 | if self.cbg_fusion_start is not None: 119 | a_bm_in, _, _, v_bm_in, _, _ = self.forward_features(audio, video) 120 | 121 | # complementary boundary generator 122 | v_cbg_feature, v_cbg_start, v_cbg_end = self.forward_video_cbg(v_bm_in) 123 | a_cbg_feature, a_cbg_start, a_cbg_end = self.forward_audio_cbg(a_bm_in) 124 | 125 | # complementary boundary generator modal attention fusion 126 | fusion_cbg_start = self.cbg_fusion_start(v_bm_in, a_bm_in, v_cbg_start, a_cbg_start) 127 | fusion_cbg_end = self.cbg_fusion_end(v_bm_in, a_bm_in, v_cbg_end, a_cbg_end) 128 | 129 | return ( 130 | fusion_cbg_start, fusion_cbg_end, v_cbg_start, v_cbg_end, a_cbg_start, a_cbg_end, 131 | v_cbg_feature, a_cbg_feature 132 | ) 133 | else: 134 | return None, None, None, None, None, None, None, None 135 | 136 | def forward_features(self, audio, video): 137 | # encoders 138 | v_features = self.video_encoder(video) 139 | a_features = self.audio_encoder(audio) 140 | # frame classifiers 141 | v_frame_cla = self.video_frame_classifier(v_features) 142 | a_frame_cla = self.audio_frame_classifier(a_features) 143 | # concat classification result to features 144 | v_bm_in = torch.column_stack([v_features, 
v_frame_cla])
145 |         a_bm_in = torch.column_stack([a_features, a_frame_cla])
146 |         return a_bm_in, a_features, a_frame_cla, v_bm_in, v_features, v_frame_cla
147 | 
148 |     def forward_video_cbg(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
149 |         cbg_prob, cbg_feature = self.video_comp_boundary_generator(feature)
150 |         start = cbg_prob[:, 0, :].squeeze(1)  # channel 0: start probability
151 |         end = cbg_prob[:, 1, :].squeeze(1)  # channel 1: end probability
152 |         return cbg_feature, start, end
153 | 
154 |     def forward_audio_cbg(self, feature: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
155 |         cbg_prob, cbg_feature = self.audio_comp_boundary_generator(feature)
156 |         start = cbg_prob[:, 0, :].squeeze(1)  # channel 0: start probability
157 |         end = cbg_prob[:, 1, :].squeeze(1)  # channel 1: end probability
158 |         return cbg_feature, start, end
159 | 
160 |     def loss_fn(self,
161 |             fusion_bm_map_p: Tensor, fusion_bm_map_c: Tensor, fusion_bm_map_p_c: Tensor,
162 |             fusion_cbg_start: Tensor, fusion_cbg_end: Tensor,
163 |             fusion_cbg_start_back: Tensor, fusion_cbg_end_back: Tensor,
164 |             v_bm_map_p: Tensor, v_bm_map_c: Tensor, v_bm_map_p_c: Tensor,
165 |             v_cbg_start: Tensor, v_cbg_end: Tensor, v_cbg_feature: Tensor,
166 |             v_cbg_start_back: Tensor, v_cbg_end_back: Tensor, v_cbg_feature_back: Tensor,
167 |             a_bm_map_p: Tensor, a_bm_map_c: Tensor, a_bm_map_p_c: Tensor,
168 |             a_cbg_start: Tensor, a_cbg_end: Tensor, a_cbg_feature: Tensor,
169 |             a_cbg_start_back: Tensor, a_cbg_end_back: Tensor, a_cbg_feature_back: Tensor,
170 |             v_frame_cla: Tensor, a_frame_cla: Tensor, n_frames: Tensor,
171 |             fusion_bm_label: Tensor, fusion_start_label: Tensor, fusion_end_label: Tensor,
172 |             v_bm_label, a_bm_label, v_start_label, a_start_label, v_end_label, a_end_label,
173 |             v_frame_label, a_frame_label, contrast_label, v_features, a_features
174 |     ) -> Dict[str, Tensor]:
175 |         (
176 |             fusion_bm_loss, fusion_cbg_loss, fusion_prb_loss, fusion_cbg_loss_forward, fusion_cbg_loss_backward, _
177 |         ) = self.bm_loss(
178 |             fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c,
179 |             fusion_cbg_start, fusion_cbg_end, fusion_cbg_start_back, fusion_cbg_end_back,
180 |             fusion_bm_label, fusion_start_label, fusion_end_label, n_frames
181 |         )
182 | 
183 |         (
184 |             v_bm_loss, v_cbg_loss, v_prb_loss, v_cbg_loss_forward, v_cbg_loss_backward, v_cbg_feature_loss
185 |         ) = self.bm_loss(
186 |             v_bm_map_p, v_bm_map_c, v_bm_map_p_c,
187 |             v_cbg_start, v_cbg_end, v_cbg_start_back, v_cbg_end_back,
188 |             v_bm_label, v_start_label, v_end_label, n_frames,
189 |             v_cbg_feature, v_cbg_feature_back
190 |         )
191 | 
192 |         (
193 |             a_bm_loss, a_cbg_loss, a_prb_loss, a_cbg_loss_forward, a_cbg_loss_backward, a_cbg_feature_loss
194 |         ) = self.bm_loss(
195 |             a_bm_map_p, a_bm_map_c, a_bm_map_p_c,
196 |             a_cbg_start, a_cbg_end, a_cbg_start_back, a_cbg_end_back,
197 |             a_bm_label, a_start_label, a_end_label, n_frames,
198 |             a_cbg_feature, a_cbg_feature_back
199 |         )
200 | 
201 |         v_frame_loss = self.frame_loss(v_frame_cla.squeeze(1), v_frame_label, n_frames)
202 |         a_frame_loss = self.frame_loss(a_frame_cla.squeeze(1), a_frame_label, n_frames)
203 | 
204 |         contrast_loss = torch.clip(self.contrast_loss(v_features, a_features, contrast_label, n_frames)
205 |                                    / (self.cla_feature_in * self.temporal_dim), max=1.)
206 | 207 | loss = fusion_bm_loss + \ 208 | self.weight_modal_bm_loss * (a_bm_loss + v_bm_loss) / 2 + \ 209 | self.weight_frame_loss * (a_frame_loss + v_frame_loss) / 2 + \ 210 | self.weight_contrastive_loss * contrast_loss 211 | 212 | loss_dict = { 213 | "loss": loss, "fusion_bm_loss": fusion_bm_loss, "v_bm_loss": v_bm_loss, "a_bm_loss": a_bm_loss, 214 | "v_frame_loss": v_frame_loss, "a_frame_loss": a_frame_loss, "contrast_loss": contrast_loss, 215 | "fusion_cbg_loss": fusion_cbg_loss, "v_cbg_loss": v_cbg_loss, "a_cbg_loss": a_cbg_loss, 216 | "fusion_prb_loss": fusion_prb_loss, "v_prb_loss": v_prb_loss, "a_prb_loss": a_prb_loss, 217 | "fusion_cbg_loss_forward": fusion_cbg_loss_forward, "v_cbg_loss_forward": v_cbg_loss_forward, 218 | "a_cbg_loss_forward": a_cbg_loss_forward, "fusion_cbg_loss_backward": fusion_cbg_loss_backward, 219 | "v_cbg_loss_backward": v_cbg_loss_backward, "a_cbg_loss_backward": a_cbg_loss_backward, 220 | "v_cbg_feature_loss": v_cbg_feature_loss, "a_cbg_feature_loss": a_cbg_feature_loss 221 | } 222 | return {k: v for k, v in loss_dict.items() if v is not None} 223 | 224 | def step(self, batch: Sequence[Tensor]) -> Dict[str, Tensor]: 225 | ( 226 | video, audio, fusion_bm_label, fusion_start_label, fusion_end_label, n_frames, 227 | v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, 228 | a_start_label, v_start_label, a_end_label, v_end_label 229 | ) = batch 230 | # forward 231 | ( 232 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end, 233 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end, 234 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end, 235 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature 236 | ) = self(video, audio) 237 | # BSN++ back 238 | video_back = torch.flip(video, dims=(2,)) 239 | audio_back = torch.flip(audio, dims=(2,)) 240 | ( 241 | fusion_cbg_start_back, fusion_cbg_end_back, v_cbg_start_back, v_cbg_end_back, 242 | a_cbg_start_back, a_cbg_end_back, v_cbg_feature_back, a_cbg_feature_back 243 | ) = self.forward_back(video_back, audio_back) 244 | 245 | # loss 246 | loss_dict = self.loss_fn( 247 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, 248 | fusion_cbg_start, fusion_cbg_end, 249 | fusion_cbg_start_back, fusion_cbg_end_back, 250 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, 251 | v_cbg_start, v_cbg_end, v_cbg_feature, 252 | v_cbg_start_back, v_cbg_end_back, v_cbg_feature_back, 253 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, 254 | a_cbg_start, a_cbg_end, a_cbg_feature, 255 | a_cbg_start_back, a_cbg_end_back, a_cbg_feature_back, 256 | v_frame_cla, a_frame_cla, n_frames, 257 | fusion_bm_label, fusion_start_label, fusion_end_label, 258 | v_bm_label, a_bm_label, v_start_label, a_start_label, v_end_label, a_end_label, 259 | v_frame_label, a_frame_label, contrast_label, v_features, a_features 260 | ) 261 | return loss_dict 262 | 263 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 264 | optimizer_idx: Optional[int] = None, hiddens: Optional[Tensor] = None 265 | ) -> Tensor: 266 | loss_dict = self.step(batch) 267 | 268 | self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 269 | prog_bar=False, sync_dist=self.distributed) 270 | return loss_dict["loss"] 271 | 272 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 273 | dataloader_idx: Optional[int] = None 274 | ) -> Tensor: 275 | 
loss_dict = self.step(batch) 276 | 277 | self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 278 | prog_bar=False, sync_dist=self.distributed) 279 | return loss_dict["loss"] 280 | 281 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None 282 | ) -> Tuple[ 283 | Tensor, Optional[Tensor], Optional[Tensor], 284 | Tensor, Optional[Tensor], Optional[Tensor], 285 | Tensor, Optional[Tensor], Optional[Tensor] 286 | ]: 287 | video, audio, *_ = batch 288 | # forward 289 | ( 290 | fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, fusion_cbg_start, fusion_cbg_end, 291 | v_bm_map_p, v_bm_map_c, v_bm_map_p_c, v_cbg_start, v_cbg_end, 292 | a_bm_map_p, a_bm_map_c, a_bm_map_p_c, a_cbg_start, a_cbg_end, 293 | v_frame_cla, a_frame_cla, v_features, a_features, v_cbg_feature, a_cbg_feature 294 | ) = self(video, audio) 295 | # BSN++ back 296 | video_back = torch.flip(video, dims=(2,)) 297 | audio_back = torch.flip(audio, dims=(2,)) 298 | ( 299 | fusion_cbg_start_back, fusion_cbg_end_back, v_cbg_start_back, v_cbg_end_back, 300 | a_cbg_start_back, a_cbg_end_back, v_cbg_feature_back, a_cbg_feature_back 301 | ) = self.forward_back(video_back, audio_back) 302 | 303 | fusion_bm_map, start, end = self.post_process_predict(fusion_bm_map_p, fusion_bm_map_c, fusion_bm_map_p_c, 304 | fusion_cbg_start, fusion_cbg_end, fusion_cbg_start_back, fusion_cbg_end_back 305 | ) 306 | 307 | v_bm_map, v_start, v_end = self.post_process_predict(v_bm_map_p, v_bm_map_c, v_bm_map_p_c, 308 | v_cbg_start, v_cbg_end, v_cbg_start_back, v_cbg_end_back 309 | ) 310 | 311 | a_bm_map, a_start, a_end = self.post_process_predict(a_bm_map_p, a_bm_map_c, a_bm_map_p_c, 312 | a_cbg_start, a_cbg_end, a_cbg_start_back, a_cbg_end_back 313 | ) 314 | 315 | return fusion_bm_map, start, end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end 316 | 317 | def post_process_predict(self, 318 | bm_map_p: Tensor, bm_map_c: Tensor, bm_map_p_c: Tensor, 319 | cbg_start: Tensor, cbg_end: Tensor, 320 | cbg_start_back: Tensor, cbg_end_back: Tensor 321 | ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 322 | 323 | bm_map = (bm_map_p + bm_map_c + bm_map_p_c) / 3 324 | if self.cbg_fusion_start is not None: 325 | start = torch.sqrt(cbg_start * torch.flip(cbg_end_back, dims=(1,))) 326 | end = torch.sqrt(cbg_end * torch.flip(cbg_start_back, dims=(1,))) 327 | else: 328 | start = None 329 | end = None 330 | 331 | return bm_map, start, end 332 | 333 | def configure_optimizers(self): 334 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9), weight_decay=self.weight_decay) 335 | return { 336 | "optimizer": optimizer, 337 | "lr_scheduler": { 338 | "scheduler": ExponentialLR(optimizer, gamma=0.992), 339 | "monitor": "val_loss" 340 | } 341 | } 342 | 343 | @classmethod 344 | def get_meta_attr(cls, meta: Metadata, video: Tensor, audio: Tensor, label: Tuple[Tensor, Tensor, Tensor]): 345 | fusion_bm_label, fusion_start_label, fusion_end_label = label 346 | 347 | a_bm_label, v_bm_label = cls.gen_audio_video_labels(fusion_bm_label, meta) 348 | a_start_label, v_start_label = cls.gen_audio_video_labels(fusion_start_label, meta) 349 | a_end_label, v_end_label = cls.gen_audio_video_labels(fusion_end_label, meta) 350 | 351 | frame_label_real = torch.zeros(512) 352 | frame_label_fake = torch.zeros(512) 353 | for begin, end in meta.fake_periods: 354 | begin = int(begin * 25) 355 | end = int(end * 25) 356 | frame_label_fake[begin: end] = 1 357 | 358 | v_frame_label = frame_label_fake if 
meta.modify_video else frame_label_real 359 | a_frame_label = frame_label_fake if meta.modify_audio else frame_label_real 360 | 361 | contrast_label = 0 if meta.modify_audio or meta.modify_video else 1 362 | 363 | return [ 364 | meta.video_frames, v_bm_label, a_bm_label, v_frame_label, a_frame_label, contrast_label, 365 | a_start_label, v_start_label, a_end_label, v_end_label 366 | ] 367 | 368 | @classmethod 369 | def gen_audio_video_labels(cls, label_fake: Tensor, meta: Metadata): 370 | label_real = torch.zeros(label_fake.size(), dtype=label_fake.dtype, device=label_fake.device) 371 | v_label = label_fake if meta.modify_video else label_real 372 | a_label = label_fake if meta.modify_audio else label_real 373 | return a_label, v_label 374 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. 
Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. 
Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. 
You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. 
If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 
336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 401 | --------------------------------------------------------------------------------
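
Note on the BSN++-style inference in model/batfd_plus.py above: `step` and `predict_step` run the network a second time on a temporally flipped copy of the inputs (`torch.flip(..., dims=(2,))`), and `post_process_predict` fuses the forward and backward complementary boundary generator (CBG) outputs with a geometric mean before averaging the three boundary maps. The minimal sketch below reproduces only that fusion arithmetic on dummy tensors; the shapes and random inputs are illustrative assumptions, not repository code.

import torch

B, T = 2, 512   # batch size and temporal length (T = num_frames = 512 in the configs)
D = 40          # assumed boundary-map duration axis (max_duration = 40 in the configs)

# dummy CBG probabilities from the forward pass and from the time-reversed pass
cbg_start, cbg_end = torch.rand(B, T), torch.rand(B, T)
cbg_start_back, cbg_end_back = torch.rand(B, T), torch.rand(B, T)

# a start boundary in the original clip is an end boundary in the reversed clip,
# hence the cross-pairing and the flip back to the original time order
start = torch.sqrt(cbg_start * torch.flip(cbg_end_back, dims=(1,)))
end = torch.sqrt(cbg_end * torch.flip(cbg_start_back, dims=(1,)))

# the three boundary maps produced by the proposal relation blocks are averaged,
# mirroring post_process_predict
bm_map_p, bm_map_c, bm_map_p_c = (torch.rand(B, D, T) for _ in range(3))
bm_map = (bm_map_p + bm_map_c + bm_map_p_c) / 3

print(start.shape, end.shape, bm_map.shape)
# torch.Size([2, 512]) torch.Size([2, 512]) torch.Size([2, 40, 512])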