├── tmp └── txt.txt ├── summary_videos └── txt.txt ├── weights ├── SumMe │ └── txt.txt └── TVSum │ └── txt.txt ├── requirements.txt ├── kts ├── LICENSE ├── README.md ├── cpd_auto.py ├── demo.py └── cpd_nonlin.py ├── knapsack_implementation.py ├── LICENSE ├── splits ├── SumMe_splits.txt └── TVSum_splits.txt ├── evaluation_metrics.py ├── generate_summary.py ├── dataset.py ├── config.py ├── video_helper.py ├── inference.py ├── generate_video.py ├── model.py ├── models ├── positional_encoding.py ├── GoogleNet.py ├── MobileNet.py ├── ResNet.py └── EfficientNet.py ├── utils.py ├── train.py └── README.md /tmp/txt.txt: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /summary_videos/txt.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weights/SumMe/txt.txt: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /weights/TVSum/txt.txt: -------------------------------------------------------------------------------- 1 | . 
2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.1.0 2 | numpy==1.19.5 3 | scipy==1.5.2 4 | torch==2.2.1 5 | torchvision==0.17.1 6 | tqdm==4.61.0 -------------------------------------------------------------------------------- /kts/LICENSE: -------------------------------------------------------------------------------- 1 | #### DISCLAIMER 2 | This kts library is downloaded from: 3 | - http://lear.inrialpes.fr/software 4 | - http://pascal.inrialpes.fr/data2/potapov/med_summaries/kts_ver1.1.tar.gz 5 | 6 | I just modified the original code to remove weave dependecy. Please follow the 7 | original LICENSE from LEAR if you are using kts. 8 | -------------------------------------------------------------------------------- /knapsack_implementation.py: -------------------------------------------------------------------------------- 1 | # Knapsack algorithm to find shots having high score 2 | def knapSack(W, wt, val, n): 3 | K = [[0 for _ in range(W + 1)] for _ in range(n + 1)] 4 | 5 | for i in range(n + 1): 6 | for w in range(W + 1): 7 | if i == 0 or w == 0: 8 | K[i][w] = 0 9 | elif wt[i - 1] <= w: 10 | K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]], K[i - 1][w]) 11 | else: 12 | K[i][w] = K[i - 1][w] 13 | 14 | selected = [] 15 | w = W 16 | for i in range(n, 0, -1): 17 | if K[i][w] != K[i - 1][w]: 18 | selected.insert(0, i - 1) 19 | w -= wt[i - 1] 20 | 21 | return selected -------------------------------------------------------------------------------- /kts/README.md: -------------------------------------------------------------------------------- 1 | This code is from [DS-Net](https://github.com/li-plus/DSNet.git).
2 | 3 | Kernel temporal segmentation 4 | ============================ 5 | 6 |
7 | #### DISCLAIMER 8 | This kts library was downloaded from: 9 | - http://lear.inrialpes.fr/software 10 | - http://pascal.inrialpes.fr/data2/potapov/med_summaries/kts_ver1.1.tar.gz 11 | 12 | I just modified the original code to remove the weave dependency. Please follow the 13 | original LICENSE from LEAR if you are using kts. 14 |
15 | 16 | #### Original documentation 17 | This archive contains the following files: 18 | * cpd_nonlin.py - kernel temporal segmentation with fixed number of segments 19 | * cpd_auto.py - kernel temporal segmentation with autocalibration 20 | * demo.py - demo on synthetic examples 21 | 22 | #### Dependencies: 23 | * python + libraries: numpy, scipy, matplotlib (for demo) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 thswodnjs3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /splits/SumMe_splits.txt: -------------------------------------------------------------------------------- 1 | SumMe/split0/video_1,video_3,video_4,video_5,video_6,video_7,video_8,video_9,video_10,video_12,video_15,video_16,video_17,video_18,video_19,video_20,video_21,video_22,video_23,video_25/video_2,video_11,video_13,video_14,video_24 2 | SumMe/split1/video_1,video_2,video_3,video_4,video_5,video_6,video_7,video_8,video_11,video_13,video_14,video_15,video_16,video_17,video_18,video_19,video_20,video_22,video_24,video_25/video_9,video_10,video_12,video_21,video_23 3 | SumMe/split2/video_1,video_2,video_3,video_4,video_6,video_9,video_10,video_11,video_12,video_13,video_14,video_15,video_16,video_17,video_18,video_19,video_21,video_22,video_23,video_24/video_5,video_7,video_8,video_20,video_25 4 | SumMe/split3/video_2,video_3,video_5,video_6,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_16,video_19,video_20,video_21,video_22,video_23,video_24,video_25/video_1,video_4,video_15,video_17,video_18 5 | SumMe/split4/video_1,video_2,video_4,video_5,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_15,video_17,video_18,video_20,video_21,video_23,video_24,video_25/video_3,video_6,video_16,video_19,video_22 6 | -------------------------------------------------------------------------------- /kts/cpd_auto.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from kts.cpd_nonlin import cpd_nonlin 4 | 5 | 6 | def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs): 7 | """Detect change points automatically selecting their number 8 | 9 | :param K: Kernel between each pair of frames in video 10 | :param ncp: Maximum number of change points 11 | :param vmax: Special parameter 12 | :param desc_rate: Rate of descriptor sampling, vmax always corresponds to 1x 13 | :param kwargs: Extra 
parameters for ``cpd_nonlin`` 14 | :return: Tuple (cps, costs) 15 | - cps - best selected change-points 16 | - costs - costs for 0,1,2,...,m change-points 17 | """ 18 | m = ncp 19 | _, scores = cpd_nonlin(K, m, backtrack=False, **kwargs) 20 | 21 | N = K.shape[0] 22 | N2 = N * desc_rate # length of the video before down-sampling 23 | 24 | penalties = np.zeros(m + 1) 25 | # Prevent division by zero (in case of 0 changes) 26 | ncp = np.arange(1, m + 1) 27 | penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1) 28 | 29 | costs = scores / float(N) + penalties 30 | m_best = np.argmin(costs) 31 | cps, scores2 = cpd_nonlin(K, m_best, **kwargs) 32 | 33 | return cps, scores2 34 | -------------------------------------------------------------------------------- /evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import spearmanr, kendalltau, rankdata 3 | 4 | # Calculate Kendall's and Spearman's coefficients 5 | def get_corr_coeff(pred_imp_scores, videos, dataset, user_scores=None): 6 | rho_coeff, tau_coeff = [], [] 7 | if dataset=='SumMe': 8 | for pred_imp_score,video in zip(pred_imp_scores,videos): 9 | true = np.mean(user_scores,axis=0) 10 | rho_coeff.append(spearmanr(pred_imp_score,true)[0]) 11 | tau_coeff.append(kendalltau(rankdata(pred_imp_score),rankdata(true))[0]) 12 | elif dataset=='TVSum': 13 | for pred_imp_score,video in zip(pred_imp_scores,videos): 14 | pred_imp_score = np.squeeze(pred_imp_score).tolist() 15 | user = int(video.split("_")[-1]) 16 | 17 | curr_user_score = user_scores[user-1] 18 | 19 | tmp_rho_coeff, tmp_tau_coeff = [], [] 20 | for annotation in range(len(curr_user_score)): 21 | true_user_score = curr_user_score[annotation] 22 | curr_rho_coeff, _ = spearmanr(pred_imp_score, true_user_score) 23 | curr_tau_coeff, _ = kendalltau(rankdata(pred_imp_score), rankdata(true_user_score)) 24 | tmp_rho_coeff.append(curr_rho_coeff) 25 | 
tmp_tau_coeff.append(curr_tau_coeff) 26 | rho_coeff.append(np.mean(tmp_rho_coeff)) 27 | tau_coeff.append(np.mean(tmp_tau_coeff)) 28 | rho_coeff = np.array(rho_coeff).mean() 29 | tau_coeff = np.array(tau_coeff).mean() 30 | return rho_coeff, tau_coeff -------------------------------------------------------------------------------- /generate_summary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from knapsack_implementation import knapSack 3 | 4 | # Generate summary videos 5 | def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions): 6 | all_summaries = [] 7 | for video_index in range(len(all_scores)): 8 | shot_bound = all_shot_bound[video_index] 9 | frame_init_scores = all_scores[video_index] 10 | n_frames = all_nframes[video_index] 11 | positions = all_positions[video_index] 12 | 13 | frame_scores = np.zeros(n_frames, dtype=np.float32) 14 | if positions.dtype != int: 15 | positions = positions.astype(np.int32) 16 | if positions[-1] != n_frames: 17 | positions = np.concatenate([positions, [n_frames]]) 18 | for i in range(len(positions) - 1): 19 | pos_left, pos_right = positions[i], positions[i + 1] 20 | if i == len(frame_init_scores): 21 | frame_scores[pos_left:pos_right] = 0 22 | else: 23 | frame_scores[pos_left:pos_right] = frame_init_scores[i] 24 | 25 | shot_imp_scores = [] 26 | shot_lengths = [] 27 | for shot in shot_bound: 28 | shot_lengths.append(shot[1] - shot[0] + 1) 29 | shot_imp_scores.append((frame_scores[shot[0]:shot[1] + 1].mean()).item()) 30 | 31 | final_shot = shot_bound[-1] 32 | final_max_length = int((final_shot[1] + 1) * 0.15) 33 | 34 | selected = knapSack(final_max_length, shot_lengths, shot_imp_scores, len(shot_lengths)) 35 | 36 | summary = np.zeros(final_shot[1] + 1, dtype=np.int8) 37 | for shot in selected: 38 | summary[shot_bound[shot][0]:shot_bound[shot][1] + 1] = 1 39 | 40 | all_summaries.append(summary) 41 | 42 | return all_summaries 43 | 
-------------------------------------------------------------------------------- /splits/TVSum_splits.txt: -------------------------------------------------------------------------------- 1 | TVSum/split0/video_1,video_2,video_3,video_4,video_5,video_6,video_8,video_9,video_10,video_11,video_13,video_14,video_15,video_16,video_17,video_18,video_19,video_20,video_22,video_23,video_25,video_26,video_27,video_28,video_29,video_30,video_32,video_33,video_35,video_36,video_37,video_39,video_40,video_41,video_42,video_43,video_45,video_46,video_47,video_50/video_7,video_12,video_21,video_24,video_31,video_34,video_38,video_44,video_48,video_49 2 | TVSum/split1/video_2,video_3,video_5,video_6,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_16,video_18,video_21,video_22,video_23,video_24,video_25,video_26,video_27,video_28,video_29,video_30,video_31,video_32,video_33,video_34,video_35,video_37,video_38,video_39,video_40,video_42,video_44,video_45,video_47,video_48,video_49,video_50/video_1,video_4,video_15,video_17,video_19,video_20,video_36,video_41,video_43,video_46 3 | TVSum/split2/video_1,video_3,video_4,video_5,video_6,video_7,video_10,video_11,video_12,video_13,video_14,video_15,video_16,video_17,video_19,video_20,video_21,video_22,video_23,video_24,video_25,video_26,video_28,video_30,video_31,video_32,video_33,video_34,video_36,video_37,video_38,video_41,video_42,video_43,video_44,video_45,video_46,video_48,video_49,video_50/video_2,video_8,video_9,video_18,video_27,video_29,video_35,video_39,video_40,video_47 4 | 
TVSum/split3/video_1,video_2,video_3,video_4,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_15,video_17,video_18,video_19,video_20,video_21,video_24,video_26,video_27,video_29,video_30,video_31,video_32,video_34,video_35,video_36,video_37,video_38,video_39,video_40,video_41,video_42,video_43,video_44,video_46,video_47,video_48,video_49/video_5,video_6,video_16,video_22,video_23,video_25,video_28,video_33,video_45,video_50 5 | TVSum/split4/video_1,video_2,video_4,video_5,video_6,video_7,video_8,video_9,video_12,video_15,video_16,video_17,video_18,video_19,video_20,video_21,video_22,video_23,video_24,video_25,video_27,video_28,video_29,video_31,video_33,video_34,video_35,video_36,video_38,video_39,video_40,video_41,video_43,video_44,video_45,video_46,video_47,video_48,video_49,video_50/video_3,video_10,video_11,video_13,video_14,video_26,video_30,video_32,video_37,video_42 6 | -------------------------------------------------------------------------------- /kts/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | from kts.cpd_auto import cpd_auto 5 | from kts.cpd_nonlin import cpd_nonlin 6 | 7 | 8 | def gen_data(n, m, d=1): 9 | """Generates data with change points 10 | 11 | .. 
warning:: 12 | sigma is proportional to m 13 | 14 | :param n: Number of samples 15 | :param m: Number of change-points 16 | :param d: 17 | :return: Tuple (X, cps) 18 | - X - data array (n X d) 19 | - cps - change-points array, including 0 and n 20 | """ 21 | np.random.seed(1) 22 | # Select changes at some distance from the boundaries 23 | cps = np.random.permutation(n * 3 // 4 - 1)[0:m] + 1 + n // 8 24 | cps = np.sort(cps) 25 | cps = [0] + list(cps) + [n] 26 | mus = np.random.rand(m + 1, d) * (m / 2) # make sigma = m/2 27 | X = np.zeros((n, d)) 28 | for k in range(m + 1): 29 | X[cps[k]:cps[k + 1], :] = mus[k, :][np.newaxis, :] + np.random.rand( 30 | cps[k + 1] - cps[k], d) 31 | return X, np.array(cps) 32 | 33 | 34 | if __name__ == '__main__': 35 | plt.ioff() 36 | 37 | print('Test 1: 1-dimensional signal') 38 | plt.figure('Test 1: 1-dimensional signal') 39 | n = 1000 40 | m = 10 41 | (X, cps_gt) = gen_data(n, m) 42 | print('Ground truth:', cps_gt) 43 | plt.plot(X) 44 | K = np.dot(X, X.T) 45 | cps, scores = cpd_nonlin(K, m, lmin=1, lmax=10000) 46 | print('Estimated:', cps) 47 | mi = np.min(X) 48 | ma = np.max(X) 49 | for cp in cps: 50 | plt.plot([cp, cp], [mi, ma], 'r') 51 | plt.show() 52 | print('=' * 79) 53 | 54 | print('Test 2: multidimensional signal') 55 | plt.figure('Test 2: multidimensional signal') 56 | n = 1000 57 | m = 20 58 | (X, cps_gt) = gen_data(n, m, d=50) 59 | print('Ground truth:', cps_gt) 60 | plt.plot(X) 61 | K = np.dot(X, X.T) 62 | cps, scores = cpd_nonlin(K, m, lmin=1, lmax=10000) 63 | print('Estimated:', cps) 64 | mi = np.min(X) 65 | ma = np.max(X) 66 | for cp in cps: 67 | plt.plot([cp, cp], [mi, ma], 'r') 68 | plt.show() 69 | print('=' * 79) 70 | 71 | print('Test 3: automatic selection of the number of change-points') 72 | plt.figure('Test 3: automatic selection of the number of change-points') 73 | (X, cps_gt) = gen_data(n, m) 74 | print('Ground truth: (m=%d)' % m, cps_gt) 75 | plt.plot(X) 76 | K = np.dot(X, X.T) 77 | cps, scores = cpd_auto(K, 
2 * m, 1) 78 | print('Estimated: (m=%d)' % len(cps), cps) 79 | mi = np.min(X) 80 | ma = np.max(X) 81 | for cp in cps: 82 | plt.plot([cp, cp], [mi, ma], 'r') 83 | plt.show() 84 | print('=' * 79) 85 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | 4 | from torch.utils.data import Dataset,DataLoader 5 | 6 | # Load split 7 | def load_split(dataset): 8 | outputs = [] 9 | with open(f'./splits/{dataset}_splits.txt','r') as f: 10 | lines = f.readlines() 11 | for line in lines: 12 | _,_,train_videos,test_videos = line.split('/') 13 | train_videos = train_videos.split(',') 14 | test_videos = test_videos.split(',') 15 | test_videos[-1] = test_videos[-1].replace('\n','') 16 | outputs.append((train_videos,test_videos)) 17 | return outputs 18 | 19 | # Create input,ground truth pair 20 | def load_h5(videos,data_path,dataset_name): 21 | features = [] 22 | gtscores = [] 23 | dataset_names = [] 24 | 25 | with h5py.File(data_path,'r') as hdf: 26 | for video in videos: 27 | feature = hdf[video]['features'][()] 28 | gtscore = hdf[video]['gtscore'][()] 29 | 30 | features.append(feature) 31 | gtscores.append(gtscore) 32 | dataset_names.append(dataset_name) 33 | return features,gtscores,dataset_names 34 | 35 | # Create Dataset 36 | class VSdataset(Dataset): 37 | def __init__(self,data,video_nums,transform=None): 38 | features,gtscores,dataset_names = data 39 | self.features = features 40 | self.gtscores = gtscores 41 | self.dataset_names = dataset_names 42 | self.video_nums = video_nums 43 | self.transform = transform 44 | def __len__(self): 45 | return len(self.video_nums) 46 | def __getitem__(self,idx): 47 | output_feature = torch.from_numpy(self.features[idx]).float() 48 | output_feature = output_feature.unsqueeze(0).expand(3,-1,-1) 49 | if self.transform is not None: 50 | output_feature= self.transform(output_feature) 51 | 
return torch.unsqueeze(output_feature,0),torch.from_numpy(self.gtscores[idx]).float(),self.dataset_names[idx],self.video_nums[idx] 52 | 53 | def collate_fn(sample): 54 | return sample[0] 55 | 56 | # Create Dataloader 57 | def create_dataloader(dataset): 58 | loaders = [] 59 | 60 | splits = load_split(dataset=dataset) 61 | data_path = f'./data/eccv16_dataset_{dataset.lower()}_google_pool5.h5' 62 | 63 | for train_videos,test_videos in splits: 64 | train_data = load_h5(videos=train_videos,data_path=data_path,dataset_name=dataset) 65 | test_data = load_h5(videos=test_videos,data_path=data_path,dataset_name=dataset) 66 | 67 | train_dataset = VSdataset(data=train_data,video_nums=train_videos) 68 | test_dataset = VSdataset(data=test_data,video_nums=test_videos) 69 | train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True,collate_fn=collate_fn) 70 | test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False,collate_fn=collate_fn) 71 | loaders.append((train_loader,test_loader)) 72 | return loaders -------------------------------------------------------------------------------- /kts/cpd_nonlin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tqdm import tqdm 4 | 5 | def calc_scatters(K): 6 | """Calculate scatter matrix: scatters[i,j] = {scatter of the sequence with 7 | starting frame i and ending frame j} 8 | """ 9 | n = K.shape[0] 10 | K1 = np.cumsum([0] + list(np.diag(K))) 11 | K2 = np.zeros((n + 1, n + 1)) 12 | # TODO: use the fact that K - symmetric 13 | K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) 14 | 15 | diagK2 = np.diag(K2) 16 | 17 | i = np.arange(n).reshape((-1, 1)) 18 | j = np.arange(n).reshape((1, -1)) 19 | scatters = ( 20 | K1[1:].reshape((1, -1)) - K1[:-1].reshape((-1, 1)) - 21 | (diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape((-1, 1)) - 22 | K2[1:, :-1].T - K2[:-1, 1:]) / 23 | ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32)) 24 | ) 25 | scatters[j < i] = 
0 26 | 27 | return scatters 28 | 29 | 30 | def cpd_nonlin(K, ncp, lmin=1, lmax=100000, backtrack=True, verbose=True, 31 | out_scatters=None): 32 | """Change point detection with dynamic programming 33 | 34 | :param K: Square kernel matrix 35 | :param ncp: Number of change points to detect (ncp >= 0) 36 | :param lmin: Minimal length of a segment 37 | :param lmax: Maximal length of a segment 38 | :param backtrack: If False - only evaluate objective scores (to save memory) 39 | :param verbose: If true, print verbose message 40 | :param out_scatters: Output scatters 41 | :return: Tuple (cps, obj_vals) 42 | - cps - detected array of change points: mean is thought to be constant 43 | on [ cps[i], cps[i+1] ) 44 | - obj_vals - values of the objective function for 0..m changepoints 45 | """ 46 | m = int(ncp) # prevent numpy.int64 47 | 48 | n, n1 = K.shape 49 | assert n == n1, 'Kernel matrix awaited.' 50 | assert (m + 1) * lmin <= n <= (m + 1) * lmax 51 | assert 1 <= lmin <= lmax 52 | 53 | if verbose: 54 | print('Precomputing scatters...') 55 | J = calc_scatters(K) 56 | 57 | if out_scatters is not None: 58 | out_scatters[0] = J 59 | 60 | if verbose: 61 | print('Inferring best change points...') 62 | # I[k, l] - value of the objective for k change-points and l first frames 63 | I = 1e101 * np.ones((m + 1, n + 1)) 64 | I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1] 65 | 66 | if backtrack: 67 | # p[k, l] --- 'previous change' --- best t[k] when t[k+1] equals l 68 | p = np.zeros((m + 1, n + 1), dtype=int) 69 | else: 70 | p = np.zeros((1, 1), dtype=int) 71 | 72 | for k in tqdm(range(1, m + 1), ncols = 90, total = m, desc = 'KTS outer', leave = False): 73 | for l in tqdm(range((k + 1) * lmin, n + 1), ncols = 90, total = n + 1 - ((k + 1) * lmin), desc = 'KTS inner', leave = False): 74 | tmin = max(k * lmin, l - lmax) 75 | tmax = l - lmin + 1 76 | c = J[tmin:tmax, l - 1].reshape(-1) + \ 77 | I[k - 1, tmin:tmax].reshape(-1) 78 | I[k, l] = np.min(c) 79 | if backtrack: 80 | p[k, l] = 
np.argmin(c) + tmin 81 | 82 | # Collect change points 83 | cps = np.zeros(m, dtype=int) 84 | 85 | if backtrack: 86 | cur = n 87 | for k in tqdm(range(m, 0, -1), ncols = 90, total = m, desc = 'KTS final', leave = False): 88 | cps[k - 1] = p[k, cur] 89 | cur = cps[k - 1] 90 | 91 | scores = I[:, n].copy() 92 | scores[scores > 1e99] = np.inf 93 | return cps, scores 94 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import random 4 | import torch 5 | 6 | # Process bool argument 7 | def str2bool(v): 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 11 | return False 12 | else: 13 | raise 14 | 15 | # Process none argument 16 | def str2none(v): 17 | if v.lower()=='none': 18 | return None 19 | else: 20 | return v 21 | 22 | # Define configuration class 23 | class Config(object): 24 | def __init__(self, **kwargs): 25 | self.kwargs = kwargs 26 | for k, v in kwargs.items(): 27 | setattr(self, k, v) 28 | 29 | self.datasets = ['SumMe','TVSum'] 30 | self.SumMe_len = 25 31 | self.TVSum_len = 50 32 | 33 | # Set device 34 | if self.device!='cpu': 35 | torch.cuda.set_device(self.device) 36 | 37 | # Set seed 38 | self.set_seed() 39 | 40 | # Set the seed 41 | def set_seed(self): 42 | random.seed(self.seed) 43 | np.random.seed(self.seed) 44 | torch.manual_seed(self.seed) 45 | if self.device!='cpu': 46 | torch.cuda.manual_seed(self.seed) 47 | torch.cuda.manual_seed_all(self.seed) 48 | torch.backends.cudnn.benchmark = False 49 | torch.backends.cudnn.deterministic = True 50 | 51 | # Define all configurations 52 | def get_config(parse=True, **optional_kwargs): 53 | parser = argparse.ArgumentParser() 54 | 55 | parser.add_argument('--seed', type=int, default=123456) 56 | parser.add_argument('--device', type=str, default='cuda:0') 57 | 
parser.add_argument('--epochs', type=int, default=100) 58 | parser.add_argument('--batch_size', default='1') 59 | parser.add_argument('--learning_rate', default='1e-3') 60 | parser.add_argument('--weight_decay', default='1e-7') 61 | 62 | parser.add_argument('--model_name', type=str, default='GoogleNet_Attention') 63 | parser.add_argument('--Scale', type=str2none, default=None) 64 | parser.add_argument('--Softmax_axis', type=str2none, default='TD') 65 | parser.add_argument('--Balance', type=str2none, default=None) 66 | 67 | parser.add_argument('--Positional_encoding', type=str2none, default='FPE') 68 | parser.add_argument('--Positional_encoding_shape', type=str2none, default='TD') 69 | parser.add_argument('--Positional_encoding_way', type=str2none, default='PGL_SUM') 70 | parser.add_argument('--Dropout_on', type=str2bool, default=True) 71 | parser.add_argument('--Dropout_ratio', default='0.6') 72 | 73 | parser.add_argument('--Classifier_on', type=str2bool, default=True) 74 | parser.add_argument('--CLS_on', type=str2bool, default=True) 75 | parser.add_argument('--CLS_mix', type=str2none, default='Final') 76 | 77 | parser.add_argument('--key_value_emb', type=str2none, default='kv') 78 | parser.add_argument('--Skip_connection', type=str2none, default='KC') 79 | parser.add_argument('--Layernorm', type=str2bool, default=True) 80 | 81 | # Generate summary videos 82 | parser.add_argument('--input_is_file', type=str2bool, default='true') 83 | parser.add_argument('--file_path', type=str, default='./SumMe/Jumps.mp4') 84 | parser.add_argument('--dir_path', type=str, default='./SumMe') 85 | parser.add_argument('--ext', type=str, default='mp4') 86 | parser.add_argument('--sample_rate', type=int, default=15) 87 | parser.add_argument('--save_path', type=str, default='./summary_videos') 88 | parser.add_argument('--weight_path', type=str, default='./weights/SumMe/split4.pt') 89 | 90 | kwargs = vars(parser.parse_args()) 91 | kwargs.update(optional_kwargs) 92 | 93 | return 
Config(**kwargs) 94 | -------------------------------------------------------------------------------- /video_helper.py: -------------------------------------------------------------------------------- 1 | # Reference code: https://github.com/li-plus/DSNet/blob/1804176e2e8b57846beb063667448982273fca89/src/helpers/video_helper.py 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from os import PathLike 8 | from pathlib import Path 9 | from PIL import Image 10 | from torchvision import transforms, models 11 | from torchvision.models import GoogLeNet_Weights 12 | from tqdm import tqdm 13 | 14 | from kts.cpd_auto import cpd_auto 15 | 16 | class FeatureExtractor(object): 17 | def __init__(self, device): 18 | self.device = device 19 | self.transforms = GoogLeNet_Weights.IMAGENET1K_V1.transforms() 20 | weights = GoogLeNet_Weights.IMAGENET1K_V1 21 | self.model = models.googlenet(weights=weights) 22 | self.model = nn.Sequential(*list(self.model.children())[:-2]) 23 | self.model.to(self.device) 24 | self.model.eval() 25 | 26 | def run(self, img: np.ndarray): 27 | img = Image.fromarray(img) 28 | img = self.transforms(img) 29 | batch = img.unsqueeze(0) 30 | with torch.no_grad(): 31 | batch = batch.to(self.device) 32 | feat = self.model(batch) 33 | feat = feat.squeeze() 34 | 35 | assert feat.shape == (1024,), f'Invalid feature shape {feat.shape}: expected 1024' 36 | # normalize frame features 37 | feat = feat / (torch.norm(feat) + 1e-10) 38 | return feat 39 | 40 | class VideoPreprocessor(object): 41 | def __init__(self, sample_rate: int, device: str): 42 | self.model = FeatureExtractor(device) 43 | self.sample_rate = sample_rate 44 | 45 | def get_features(self, video_path: PathLike): 46 | video_path = Path(video_path) 47 | cap = cv2.VideoCapture(str(video_path)) 48 | assert cap is not None, f'Cannot open video: {video_path}' 49 | 50 | self.fps = cap.get(cv2.CAP_PROP_FPS) 51 | self.frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 52 | 
self.frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 53 | 54 | features = [] 55 | n_frames = 0 56 | 57 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 58 | with tqdm(total = total_frames, ncols=90, desc = "getting features", unit='frame', leave=False) as pbar: 59 | while True: 60 | ret, frame = cap.read() 61 | 62 | if not ret: 63 | break 64 | 65 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 66 | feat = self.model.run(frame) 67 | features.append(feat) 68 | 69 | n_frames += 1 70 | pbar.update(1) 71 | 72 | cap.release() 73 | features = torch.stack(features) 74 | return n_frames, features 75 | 76 | def kts(self, n_frames, features): 77 | seq_len = len(features) 78 | picks = np.arange(0, seq_len) 79 | # compute change points using KTS 80 | kernel = np.matmul(features.clone().detach().cpu().numpy(), features.clone().detach().cpu().numpy().T) 81 | change_points, _ = cpd_auto(kernel, seq_len - 1, 1, verbose=False) 82 | change_points = np.hstack((0, change_points, n_frames)) 83 | begin_frames = change_points[:-1] 84 | end_frames = change_points[1:] 85 | change_points = np.vstack((begin_frames, end_frames - 1)).T 86 | return change_points, picks 87 | 88 | def run(self, video_path: PathLike): 89 | n_frames, features = self.get_features(video_path) 90 | cps, picks = self.kts(n_frames, features) 91 | return n_frames, features[::self.sample_rate,:], cps, picks[::self.sample_rate] 92 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import torch 4 | 5 | from config import get_config 6 | from dataset import create_dataloader 7 | from evaluation_metrics import get_corr_coeff 8 | from generate_summary import generate_summary 9 | from model import set_model 10 | from utils import report_params,get_gt 11 | 12 | # Load configurations 13 | config = get_config() 14 | 15 | # Print the number of 
parameters 16 | report_params( 17 | model_name=config.model_name, 18 | Scale=config.Scale, 19 | Softmax_axis=config.Softmax_axis, 20 | Balance=config.Balance, 21 | Positional_encoding=config.Positional_encoding, 22 | Positional_encoding_shape=config.Positional_encoding_shape, 23 | Positional_encoding_way=config.Positional_encoding_way, 24 | Dropout_on=config.Dropout_on, 25 | Dropout_ratio=config.Dropout_ratio, 26 | Classifier_on=config.Classifier_on, 27 | CLS_on=config.CLS_on, 28 | CLS_mix=config.CLS_mix, 29 | key_value_emb=config.key_value_emb, 30 | Skip_connection=config.Skip_connection, 31 | Layernorm=config.Layernorm 32 | ) 33 | 34 | # Start testing 35 | for dataset in config.datasets: 36 | user_scores = get_gt(dataset) 37 | split_kendalls = [] 38 | split_spears = [] 39 | 40 | for split_id,(train_loader,test_loader) in enumerate(create_dataloader(dataset)): 41 | model = set_model( 42 | model_name=config.model_name, 43 | Scale=config.Scale, 44 | Softmax_axis=config.Softmax_axis, 45 | Balance=config.Balance, 46 | Positional_encoding=config.Positional_encoding, 47 | Positional_encoding_shape=config.Positional_encoding_shape, 48 | Positional_encoding_way=config.Positional_encoding_way, 49 | Dropout_on=config.Dropout_on, 50 | Dropout_ratio=config.Dropout_ratio, 51 | Classifier_on=config.Classifier_on, 52 | CLS_on=config.CLS_on, 53 | CLS_mix=config.CLS_mix, 54 | key_value_emb=config.key_value_emb, 55 | Skip_connection=config.Skip_connection, 56 | Layernorm=config.Layernorm 57 | ) 58 | model.load_state_dict(torch.load(f'./weights/{dataset}/split{split_id+1}.pt', map_location='cpu')) 59 | model.to(config.device) 60 | model.eval() 61 | 62 | kendalls = [] 63 | spears = [] 64 | with torch.no_grad(): 65 | for feature,_,dataset_name,video_num in test_loader: 66 | feature = feature.to(config.device) 67 | output = model(feature) 68 | 69 | with h5py.File(f'./data/eccv16_dataset_{dataset_name.lower()}_google_pool5.h5','r') as hdf: 70 | user_summary = 
np.array(hdf[video_num]['user_summary']) 71 | sb = np.array(hdf[f"{video_num}/change_points"]) 72 | n_frames = np.array(hdf[f"{video_num}/n_frames"]) 73 | positions = np.array(hdf[f"{video_num}/picks"]) 74 | scores = output.squeeze().clone().detach().cpu().numpy().tolist() 75 | summary = generate_summary([sb], [scores], [n_frames], [positions])[0] 76 | 77 | if dataset_name=='SumMe': 78 | spear,kendall = get_corr_coeff([summary],[video_num],dataset_name,user_summary) 79 | elif dataset_name=='TVSum': 80 | spear,kendall = get_corr_coeff([scores],[video_num],dataset_name,user_scores) 81 | 82 | spears.append(spear) 83 | kendalls.append(kendall) 84 | split_kendalls.append(np.mean(kendalls)) 85 | split_spears.append(np.mean(spears)) 86 | print("[Split{}]Kendall:{:.3f}, Spear:{:.3f}".format( 87 | split_id,split_kendalls[split_id],split_spears[split_id] 88 | )) 89 | print("[FINAL - {}]Kendall:{:.3f}, Spear:{:.3f}".format( 90 | dataset,np.mean(split_kendalls),np.mean(split_spears) 91 | )) 92 | print() 93 | -------------------------------------------------------------------------------- /generate_video.py: -------------------------------------------------------------------------------- 1 | # Reference code: https://github.com/li-plus/DSNet/blob/1804176e2e8b57846beb063667448982273fca89/src/make_dataset.py#L4 2 | # Reference code: https://github.com/e-apostolidis/PGL-SUM/blob/81d0d6d0ee0470775ad759087deebbce1ceffec3/model/configs.py#L10 3 | import cv2 4 | import torch 5 | 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | 9 | from config import get_config 10 | from generate_summary import generate_summary 11 | from model import set_model 12 | from video_helper import VideoPreprocessor 13 | 14 | def pick_frames(video_path, selections): 15 | cap = cv2.VideoCapture(str(video_path)) 16 | frames = [] 17 | n_frames = 0 18 | 19 | with tqdm(total = len(selections), ncols=90, desc = "selecting frames", unit='frame', leave = False) as pbar: 20 | while True: 21 | ret, frame = 
cap.read() 22 | 23 | if not ret: 24 | break 25 | 26 | if selections[n_frames]: 27 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 28 | frames.append(frame) 29 | n_frames += 1 30 | 31 | pbar.update(1) 32 | 33 | cap.release() 34 | 35 | return frames 36 | 37 | def produce_video(save_path, frames, fps, frame_size): 38 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 39 | out = cv2.VideoWriter(save_path, fourcc, fps, frame_size) 40 | for frame in tqdm(frames, total = len(frames), ncols=90, desc = "generating videos", leave = False): 41 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 42 | out.write(frame) 43 | out.release() 44 | 45 | def main(): 46 | # Load config 47 | config = get_config() 48 | 49 | # create output directory 50 | out_dir = Path(config.save_path) 51 | out_dir.mkdir(parents=True, exist_ok=True) 52 | 53 | # feature extractor 54 | video_proc = VideoPreprocessor( 55 | sample_rate=config.sample_rate, 56 | device=config.device 57 | ) 58 | 59 | # search all videos with .mp4 suffix 60 | if config.input_is_file: 61 | video_paths = [Path(config.file_path)] 62 | else: 63 | video_paths = sorted(Path(config.dir_path).glob(f'*.{config.ext}')) 64 | 65 | # Load CSTA weights 66 | model = set_model( 67 | model_name=config.model_name, 68 | Scale=config.Scale, 69 | Softmax_axis=config.Softmax_axis, 70 | Balance=config.Balance, 71 | Positional_encoding=config.Positional_encoding, 72 | Positional_encoding_shape=config.Positional_encoding_shape, 73 | Positional_encoding_way=config.Positional_encoding_way, 74 | Dropout_on=config.Dropout_on, 75 | Dropout_ratio=config.Dropout_ratio, 76 | Classifier_on=config.Classifier_on, 77 | CLS_on=config.CLS_on, 78 | CLS_mix=config.CLS_mix, 79 | key_value_emb=config.key_value_emb, 80 | Skip_connection=config.Skip_connection, 81 | Layernorm=config.Layernorm 82 | ) 83 | model.load_state_dict(torch.load(config.weight_path, map_location='cpu')) 84 | model.to(config.device) 85 | model.eval() 86 | 87 | # Generate summarized videos 88 | with 
torch.no_grad(): 89 | for video_path in tqdm(video_paths,total=len(video_paths),ncols=80,leave=False,desc="Making videos..."): 90 | video_name = video_path.stem 91 | n_frames, features, cps, pick = video_proc.run(video_path) 92 | 93 | inputs = features.to(config.device) 94 | inputs = inputs.unsqueeze(0).expand(3,-1,-1).unsqueeze(0) 95 | outputs = model(inputs) 96 | predictions = outputs.squeeze().clone().detach().cpu().numpy().tolist() 97 | # print(cps.shape, len(predictions), n_frames, pick.shape) 98 | selections = generate_summary([cps], [predictions], [n_frames], [pick])[0] 99 | 100 | frames = pick_frames(video_path=video_path, selections=selections) 101 | produce_video( 102 | save_path=f'{config.save_path}/{video_name}.mp4', 103 | frames=frames, 104 | fps=video_proc.fps, 105 | frame_size=(video_proc.frame_width,video_proc.frame_height) 106 | ) 107 | 108 | if __name__=='__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import EfficientNet_B0_Weights,GoogLeNet_Weights,MobileNet_V2_Weights,ResNet18_Weights 2 | 3 | from models.EfficientNet import CSTA_EfficientNet 4 | from models.GoogleNet import CSTA_GoogleNet 5 | from models.MobileNet import CSTA_MobileNet 6 | from models.ResNet import CSTA_ResNet 7 | 8 | # Load models depending on CNN 9 | def set_model(model_name, 10 | Scale, 11 | Softmax_axis, 12 | Balance, 13 | Positional_encoding, 14 | Positional_encoding_shape, 15 | Positional_encoding_way, 16 | Dropout_on, 17 | Dropout_ratio, 18 | Classifier_on, 19 | CLS_on, 20 | CLS_mix, 21 | key_value_emb, 22 | Skip_connection, 23 | Layernorm): 24 | if model_name in ['EfficientNet','EfficientNet_Attention']: 25 | model = CSTA_EfficientNet( 26 | model_name=model_name, 27 | Scale=Scale, 28 | Softmax_axis=Softmax_axis, 29 | Balance=Balance, 30 | Positional_encoding=Positional_encoding, 31 | 
Positional_encoding_shape=Positional_encoding_shape, 32 | Positional_encoding_way=Positional_encoding_way, 33 | Dropout_on=Dropout_on, 34 | Dropout_ratio=Dropout_ratio, 35 | Classifier_on=Classifier_on, 36 | CLS_on=CLS_on, 37 | CLS_mix=CLS_mix, 38 | key_value_emb=key_value_emb, 39 | Skip_connection=Skip_connection, 40 | Layernorm=Layernorm 41 | ) 42 | state_dict = EfficientNet_B0_Weights.IMAGENET1K_V1.get_state_dict(progress=False) 43 | model.efficientnet.load_state_dict(state_dict) 44 | elif model_name in ['GoogleNet','GoogleNet_Attention']: 45 | model = CSTA_GoogleNet( 46 | model_name=model_name, 47 | Scale=Scale, 48 | Softmax_axis=Softmax_axis, 49 | Balance=Balance, 50 | Positional_encoding=Positional_encoding, 51 | Positional_encoding_shape=Positional_encoding_shape, 52 | Positional_encoding_way=Positional_encoding_way, 53 | Dropout_on=Dropout_on, 54 | Dropout_ratio=Dropout_ratio, 55 | Classifier_on=Classifier_on, 56 | CLS_on=CLS_on, 57 | CLS_mix=CLS_mix, 58 | key_value_emb=key_value_emb, 59 | Skip_connection=Skip_connection, 60 | Layernorm=Layernorm 61 | ) 62 | state_dict = GoogLeNet_Weights.IMAGENET1K_V1.get_state_dict(progress=False) 63 | state_dict = {k: v for k, v in state_dict.items() if not k.startswith('aux')} 64 | new_state_dict = model.googlenet.state_dict() 65 | for name,param in state_dict.items(): 66 | new_state_dict[name] = param 67 | model.googlenet.load_state_dict(new_state_dict) 68 | elif model_name in ['MobileNet','MobileNet_Attention']: 69 | model = CSTA_MobileNet( 70 | model_name=model_name, 71 | Scale=Scale, 72 | Softmax_axis=Softmax_axis, 73 | Balance=Balance, 74 | Positional_encoding=Positional_encoding, 75 | Positional_encoding_shape=Positional_encoding_shape, 76 | Positional_encoding_way=Positional_encoding_way, 77 | Dropout_on=Dropout_on, 78 | Dropout_ratio=Dropout_ratio, 79 | Classifier_on=Classifier_on, 80 | CLS_on=CLS_on, 81 | CLS_mix=CLS_mix, 82 | key_value_emb=key_value_emb, 83 | Skip_connection=Skip_connection, 84 | 
Layernorm=Layernorm 85 | ) 86 | state_dict = MobileNet_V2_Weights.IMAGENET1K_V1.get_state_dict(progress=False) 87 | model.mobilenet.load_state_dict(state_dict) 88 | elif model_name in ['ResNet','ResNet_Attention']: 89 | model = CSTA_ResNet( 90 | model_name=model_name, 91 | Scale=Scale, 92 | Softmax_axis=Softmax_axis, 93 | Balance=Balance, 94 | Positional_encoding=Positional_encoding, 95 | Positional_encoding_shape=Positional_encoding_shape, 96 | Positional_encoding_way=Positional_encoding_way, 97 | Dropout_on=Dropout_on, 98 | Dropout_ratio=Dropout_ratio, 99 | Classifier_on=Classifier_on, 100 | CLS_on=CLS_on, 101 | CLS_mix=CLS_mix, 102 | key_value_emb=key_value_emb, 103 | Skip_connection=Skip_connection, 104 | Layernorm=Layernorm 105 | ) 106 | state_dict = ResNet18_Weights.IMAGENET1K_V1.get_state_dict(progress=False) 107 | model.resnet.load_state_dict(state_dict) 108 | else: 109 | raise 110 | return model -------------------------------------------------------------------------------- /models/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class FixedPositionalEncoding(nn.Module): 6 | def __init__(self, Positional_encoding_shape, dim=1024, max_len=5000, freq=10000.0): 7 | super(FixedPositionalEncoding, self).__init__() 8 | if Positional_encoding_shape=='TD': 9 | position = torch.arange(max_len).unsqueeze(1) 10 | div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(freq) / dim)) 11 | 12 | pe = torch.zeros(max_len, dim) 13 | pe[:, 0::2] = torch.sin(position * div_term) 14 | pe[:, 1::2] = torch.cos(position * div_term) 15 | elif Positional_encoding_shape=='T': 16 | position = torch.arange(max_len).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, 1, 2) * (-math.log(freq) / 1)) 18 | 19 | pe = torch.zeros(max_len,1) 20 | pe[:, 0::2] = torch.sin(position * div_term) 21 | pe[:, 1::2] = torch.cos(position * div_term) 22 | pe = 
pe.repeat_interleave(dim,dim=1) 23 | elif Positional_encoding_shape is None: 24 | pass 25 | else: 26 | raise 27 | 28 | self.register_buffer('pe', pe) 29 | 30 | def forward(self, x): 31 | return x + self.pe[:x.shape[0]] 32 | 33 | class RelativePositionalEncoding(nn.Module): 34 | def __init__(self, Positional_encoding_shape, dim=1024, max_len=5000, freq=10000.0): 35 | super(RelativePositionalEncoding, self).__init__() 36 | self.Positional_encoding_shape = Positional_encoding_shape 37 | self.dim = dim 38 | self.max_len = max_len 39 | self.freq = freq 40 | 41 | def forward(self, x): 42 | T = x.shape[0] 43 | min_rpos = -(T - 1) 44 | i = torch.tensor([k for k in range(T)]) 45 | i = i.reshape(i.shape[0], 1) 46 | if self.Positional_encoding_shape=='TD': 47 | d = T + self.dim 48 | j = torch.tensor([k for k in range(self.dim)]) 49 | 50 | i = i.repeat_interleave(j.shape[0], dim=1) 51 | j = j.repeat(i.shape[0], 1) 52 | 53 | r_pos = j - i - min_rpos 54 | 55 | pe = torch.zeros(T, self.dim) 56 | idx = torch.tensor([k for k in range(T//2)],dtype=torch.int64) 57 | 58 | pe[2*idx, :] = torch.sin(r_pos[2*idx, :] / self.freq ** ((i[2*idx, :] + j[2*idx, :]) / d)) 59 | pe[2*idx+1, :] = torch.cos(r_pos[2*idx+1, :] / self.freq ** ((i[2*idx+1, :] + j[2*idx+1, :]) / d)) 60 | elif self.Positional_encoding_shape=='T': 61 | d = T + 1 62 | j = torch.tensor([k for k in range(1)]) 63 | 64 | i = i.repeat_interleave(j.shape[0], dim=1) 65 | j = j.repeat(i.shape[0], 1) 66 | 67 | r_pos = j - i - min_rpos 68 | 69 | pe = torch.zeros(T, 1) 70 | idx = torch.tensor([k for k in range(T//2)],dtype=torch.int64) 71 | 72 | pe[2*idx, :] = torch.sin(r_pos[2*idx, :] / self.freq ** ((i[2*idx, :] + j[2*idx, :]) / d)) 73 | pe[2*idx+1, :] = torch.cos(r_pos[2*idx+1, :] / self.freq ** ((i[2*idx+1, :] + j[2*idx+1, :]) / d)) 74 | pe = pe.repeat_interleave(self.dim,dim=1) 75 | elif self.Positional_encoding_shape is None: 76 | pass 77 | else: 78 | raise 79 | return x + pe[:x.shape[0]].to(x.device) 80 | 81 | class 
LearnablePositionalEncoding(nn.Module): 82 | def __init__(self, Positional_encoding_shape, dim=1024, max_len=5000): 83 | super(LearnablePositionalEncoding, self).__init__() 84 | if Positional_encoding_shape=='TD': 85 | self.pe = nn.Parameter(torch.randn((max_len,dim))) 86 | elif Positional_encoding_shape=='T': 87 | self.pe = nn.Parameter(torch.randn((max_len,1))) 88 | elif Positional_encoding_shape is None: 89 | pass 90 | else: 91 | raise 92 | 93 | def forward(self, x): 94 | return x + self.pe[:x.shape[0]] 95 | 96 | class ConditionalPositionalEncoding(nn.Module): 97 | def __init__(self, Positional_encoding_shape, Positional_encoding_way, dim=1024, kernel_size=3, stride=1, padding=1): 98 | super(ConditionalPositionalEncoding, self).__init__() 99 | self.Positional_encoding_way = Positional_encoding_way 100 | if Positional_encoding_shape=='TD': 101 | self.pe = nn.Conv1d(in_channels=dim,out_channels=dim,kernel_size=kernel_size,stride=stride,padding=padding) 102 | elif Positional_encoding_shape=='T': 103 | self.pe = nn.Conv1d(in_channels=dim,out_channels=dim,kernel_size=kernel_size,stride=stride,padding=padding,groups=dim) 104 | else: 105 | raise 106 | 107 | def forward(self, x): 108 | if self.Positional_encoding_way=='Transformer': 109 | return x + self.pe(x[0].permute(0,2,1)).permute(0,2,1).unsqueeze(0) 110 | elif self.Positional_encoding_way=='PGL_SUM': 111 | return x + self.pe(x.unsqueeze(0).permute(0,2,1)).permute(0,2,1).squeeze(0) 112 | elif self.Positional_encoding_way is None: 113 | pass 114 | else: 115 | raise 116 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import torch 4 | 5 | from collections import Counter 6 | 7 | from models.EfficientNet import CSTA_EfficientNet 8 | from models.GoogleNet import CSTA_GoogleNet 9 | from models.MobileNet import CSTA_MobileNet 10 | from models.ResNet import 
CSTA_ResNet 11 | 12 | # Count the number of parameters 13 | def count_parameters(model,model_name): 14 | if model_name in ['GoogleNet','GoogleNet_Attention','ResNet','ResNet_Attention']: 15 | x = [param.numel() for name,param in model.named_parameters() if param.requires_grad and 'fc' not in name] 16 | elif model_name in ['EfficientNet','EfficientNet_Attention','MobileNet','MobileNet_Attention']: 17 | x = [param.numel() for name,param in model.named_parameters() if param.requires_grad and 'classifier' not in name] 18 | return sum(x) / (1024 * 1024) 19 | 20 | # Funtion printing the number of parameters of models 21 | def report_params(model_name, 22 | Scale, 23 | Softmax_axis, 24 | Balance, 25 | Positional_encoding, 26 | Positional_encoding_shape, 27 | Positional_encoding_way, 28 | Dropout_on, 29 | Dropout_ratio, 30 | Classifier_on, 31 | CLS_on, 32 | CLS_mix, 33 | key_value_emb, 34 | Skip_connection, 35 | Layernorm): 36 | if model_name in ['EfficientNet','EfficientNet_Attention']: 37 | model = CSTA_EfficientNet( 38 | model_name=model_name, 39 | Scale=Scale, 40 | Softmax_axis=Softmax_axis, 41 | Balance=Balance, 42 | Positional_encoding=Positional_encoding, 43 | Positional_encoding_shape=Positional_encoding_shape, 44 | Positional_encoding_way=Positional_encoding_way, 45 | Dropout_on=Dropout_on, 46 | Dropout_ratio=Dropout_ratio, 47 | Classifier_on=Classifier_on, 48 | CLS_on=CLS_on, 49 | CLS_mix=CLS_mix, 50 | key_value_emb=key_value_emb, 51 | Skip_connection=Skip_connection, 52 | Layernorm=Layernorm 53 | ) 54 | elif model_name in ['GoogleNet','GoogleNet_Attention']: 55 | model = CSTA_GoogleNet( 56 | model_name=model_name, 57 | Scale=Scale, 58 | Softmax_axis=Softmax_axis, 59 | Balance=Balance, 60 | Positional_encoding=Positional_encoding, 61 | Positional_encoding_shape=Positional_encoding_shape, 62 | Positional_encoding_way=Positional_encoding_way, 63 | Dropout_on=Dropout_on, 64 | Dropout_ratio=Dropout_ratio, 65 | Classifier_on=Classifier_on, 66 | CLS_on=CLS_on, 67 | 
CLS_mix=CLS_mix, 68 | key_value_emb=key_value_emb, 69 | Skip_connection=Skip_connection, 70 | Layernorm=Layernorm 71 | ) 72 | elif model_name in ['MobileNet','MobileNet_Attention']: 73 | model = CSTA_MobileNet( 74 | model_name=model_name, 75 | Scale=Scale, 76 | Softmax_axis=Softmax_axis, 77 | Balance=Balance, 78 | Positional_encoding=Positional_encoding, 79 | Positional_encoding_shape=Positional_encoding_shape, 80 | Positional_encoding_way=Positional_encoding_way, 81 | Dropout_on=Dropout_on, 82 | Dropout_ratio=Dropout_ratio, 83 | Classifier_on=Classifier_on, 84 | CLS_on=CLS_on, 85 | CLS_mix=CLS_mix, 86 | key_value_emb=key_value_emb, 87 | Skip_connection=Skip_connection, 88 | Layernorm=Layernorm 89 | ) 90 | elif model_name in ['ResNet','ResNet_Attention']: 91 | model = CSTA_ResNet( 92 | model_name=model_name, 93 | Scale=Scale, 94 | Softmax_axis=Softmax_axis, 95 | Balance=Balance, 96 | Positional_encoding=Positional_encoding, 97 | Positional_encoding_shape=Positional_encoding_shape, 98 | Positional_encoding_way=Positional_encoding_way, 99 | Dropout_on=Dropout_on, 100 | Dropout_ratio=Dropout_ratio, 101 | Classifier_on=Classifier_on, 102 | CLS_on=CLS_on, 103 | CLS_mix=CLS_mix, 104 | key_value_emb=key_value_emb, 105 | Skip_connection=Skip_connection, 106 | Layernorm=Layernorm 107 | ) 108 | print(f"PARAMS: {count_parameters(model,model_name):.2f}M") 109 | 110 | # Print all arguments and GPU setting 111 | def print_args(args): 112 | print(args.kwargs) 113 | print(f"CUDA: {torch.version.cuda}") 114 | print(f"cuDNN: {torch.backends.cudnn.version()}") 115 | if 'cuda' in args.device: 116 | print(f"GPU: {torch.cuda.is_available()}") 117 | print(f"GPU count: {torch.cuda.device_count()}") 118 | print(f"GPU name: {torch.cuda.get_device_name(0)}") 119 | 120 | # Load ground truth for TVSum 121 | def get_gt(dataset): 122 | if dataset=='TVSum': 123 | annot_path = f"./data/ydata-anno.tsv" 124 | with open(annot_path) as annot_file: 125 | annot = list(csv.reader(annot_file, 
delimiter="\t")) 126 | annotation_length = list(Counter(np.array(annot)[:, 0]).values()) 127 | user_scores = [] 128 | for idx in range(1,51): 129 | init = (idx - 1) * annotation_length[idx-1] 130 | till = idx * annotation_length[idx-1] 131 | user_score = [] 132 | for row in annot[init:till]: 133 | curr_user_score = row[2].split(",") 134 | curr_user_score = np.array([float(num) for num in curr_user_score]) 135 | curr_user_score = curr_user_score / curr_user_score.max(initial=-1) 136 | curr_user_score = curr_user_score[::15] 137 | 138 | user_score.append(curr_user_score) 139 | user_scores.append(user_score) 140 | return user_scores 141 | elif dataset=='SumMe': 142 | return None 143 | else: 144 | raise 145 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import shutil 4 | import torch 5 | 6 | from tqdm import tqdm 7 | 8 | from config import get_config 9 | from dataset import create_dataloader 10 | from evaluation_metrics import get_corr_coeff 11 | from generate_summary import generate_summary 12 | from model import set_model 13 | from utils import report_params, print_args, get_gt 14 | 15 | # Load configurations 16 | config = get_config() 17 | 18 | # Print information of setting 19 | print_args(config) 20 | 21 | # Print the number of parameters 22 | report_params( 23 | model_name=config.model_name, 24 | Scale=config.Scale, 25 | Softmax_axis=config.Softmax_axis, 26 | Balance=config.Balance, 27 | Positional_encoding=config.Positional_encoding, 28 | Positional_encoding_shape=config.Positional_encoding_shape, 29 | Positional_encoding_way=config.Positional_encoding_way, 30 | Dropout_on=config.Dropout_on, 31 | Dropout_ratio=config.Dropout_ratio, 32 | Classifier_on=config.Classifier_on, 33 | CLS_on=config.CLS_on, 34 | CLS_mix=config.CLS_mix, 35 | key_value_emb=config.key_value_emb, 36 | 
Skip_connection=config.Skip_connection, 37 | Layernorm=config.Layernorm 38 | ) 39 | 40 | # Start training 41 | for dataset in tqdm(config.datasets,total=len(config.datasets),ncols=70,leave=True,position=0): 42 | user_scores = get_gt(dataset) 43 | 44 | if dataset=='SumMe': 45 | batch_size = 1 if config.batch_size=='1' else int(config.SumMe_len*0.8*float(config.batch_size)) 46 | elif dataset=='TVSum': 47 | batch_size = 1 if config.batch_size=='1' else int(config.TVSum_len*0.8*float(config.batch_size)) 48 | 49 | for split_id,(train_loader,test_loader) in tqdm(enumerate(create_dataloader(dataset)),total=5,ncols=70,leave=False,position=1,desc=dataset): 50 | model = set_model( 51 | model_name=config.model_name, 52 | Scale=config.Scale, 53 | Softmax_axis=config.Softmax_axis, 54 | Balance=config.Balance, 55 | Positional_encoding=config.Positional_encoding, 56 | Positional_encoding_shape=config.Positional_encoding_shape, 57 | Positional_encoding_way=config.Positional_encoding_way, 58 | Dropout_on=config.Dropout_on, 59 | Dropout_ratio=config.Dropout_ratio, 60 | Classifier_on=config.Classifier_on, 61 | CLS_on=config.CLS_on, 62 | CLS_mix=config.CLS_mix, 63 | key_value_emb=config.key_value_emb, 64 | Skip_connection=config.Skip_connection, 65 | Layernorm=config.Layernorm 66 | ) 67 | model.to(config.device) 68 | criterion = torch.nn.MSELoss() 69 | optimizer = torch.optim.Adam(model.parameters(),lr=float(config.learning_rate),weight_decay=float(config.weight_decay)) 70 | 71 | model_selection_kendall = -1 72 | model_selection_spear = -1 73 | 74 | for epoch in tqdm(range(config.epochs),total=config.epochs,ncols=70,leave=False,position=2,desc=f'Split{split_id+1}'): 75 | model.train() 76 | update_loss = 0.0 77 | batch = 0 78 | 79 | for feature,gtscore,dataset_name,video_num in tqdm(train_loader,ncols=70,leave=False,position=3,desc=f'Epoch{epoch+1}_TRAIN'): 80 | feature = feature.to(config.device) 81 | gtscore = gtscore.to(config.device) 82 | output = model(feature) 83 | 84 | loss = 
criterion(output,gtscore) 85 | loss.requires_grad_(True) 86 | 87 | update_loss += loss 88 | batch += 1 89 | 90 | if batch==batch_size: 91 | optimizer.zero_grad() 92 | update_loss = update_loss / batch 93 | update_loss.backward() 94 | optimizer.step() 95 | update_loss = 0.0 96 | batch = 0 97 | 98 | if batch>0: 99 | optimizer.zero_grad() 100 | update_loss = update_loss / batch 101 | update_loss.backward() 102 | optimizer.step() 103 | update_loss = 0.0 104 | batch = 0 105 | 106 | val_spears = [] 107 | val_kendalls = [] 108 | model.eval() 109 | with torch.no_grad(): 110 | for feature,gtscore,dataset_name,video_num in tqdm(test_loader,ncols=70,leave=False,position=3,desc=f'Epoch{epoch+1}_TEST'): 111 | feature = feature.to(config.device) 112 | gtscore = gtscore.to(config.device) 113 | output = model(feature) 114 | 115 | if dataset_name in ['SumMe','TVSum']: 116 | with h5py.File(f'./data/eccv16_dataset_{dataset_name.lower()}_google_pool5.h5','r') as hdf: 117 | user_summary = np.array(hdf[video_num]['user_summary']) 118 | sb = np.array(hdf[f"{video_num}/change_points"]) 119 | n_frames = np.array(hdf[f"{video_num}/n_frames"]) 120 | positions = np.array(hdf[f"{video_num}/picks"]) 121 | scores = output.squeeze().clone().detach().cpu().numpy().tolist() 122 | summary = generate_summary([sb], [scores], [n_frames], [positions])[0] 123 | if dataset_name=='SumMe': 124 | spear,kendall = get_corr_coeff([summary],[video_num],dataset_name,user_summary) 125 | elif dataset_name=='TVSum': 126 | spear,kendall = get_corr_coeff([scores],[video_num],dataset_name,user_scores) 127 | 128 | val_spears.append(spear) 129 | val_kendalls.append(kendall) 130 | 131 | if np.mean(val_kendalls) > model_selection_kendall and np.mean(val_spears) > model_selection_spear: 132 | model_selection_kendall = np.mean(val_kendalls) 133 | model_selection_spear = np.mean(val_spears) 134 | torch.save(model.state_dict(), './tmp/weight.pt') 135 | shutil.move('./tmp/weight.pt', f'./weights/{dataset}/split{split_id+1}.pt') 
136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSTA: CNN-based Spatiotemporal Attention for Video Summarization (CVPR 2024 paper) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/csta-cnn-based-spatiotemporal-attention-for/supervised-video-summarization-on-summe)](https://paperswithcode.com/sota/supervised-video-summarization-on-summe?p=csta-cnn-based-spatiotemporal-attention-for)
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/csta-cnn-based-spatiotemporal-attention-for/supervised-video-summarization-on-tvsum)](https://paperswithcode.com/sota/supervised-video-summarization-on-tvsum?p=csta-cnn-based-spatiotemporal-attention-for)
4 | 5 | The official code of "CSTA: CNN-based Spatiotemporal Attention for Video Summarization" [[paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Son_CSTA_CNN-based_Spatiotemporal_Attention_for_Video_Summarization_CVPR_2024_paper.pdf)] [[arXiv](https://arxiv.org/pdf/2405.11905)]
6 | ![image](https://github.com/thswodnjs3/CSTA/assets/93433004/aa0dff4d-9b29-49a2-989a-5b6a12dba5fe) 7 | 8 | * [Model overview](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#model-overview) 9 | * [Updates](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#updates) 10 | * [Requirements](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#requirements) 11 | * [Data](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#data) 12 | * [Pre-trained models](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#pre-trained-models) 13 | * [Training](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#training) 14 | * [Inference](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#inference) 15 | * [Generate summary videos](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#generate-summary-videos) 16 | * [Citation](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#citation) 17 | * [Acknowledgement](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#acknowledgement) 18 | 19 | # Model overview 20 | ![image](https://github.com/thswodnjs3/CSTA/assets/93433004/537b7375-10d7-4d7d-8de0-0b69631ac635)
21 |
22 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 23 | 24 | # Updates 25 | * [2024.03.24] Create a repository. 26 | * [2024.05.21] Update the code and pre-trained models. 27 | * [2024.07.18] Upload the code to generate summary videos, including custom videos. 28 | * [2024.07.21] Update the KTS code to compute change points on all frames of a video. 29 | * [2024.07.23] Update the code so it can run using only the CPU. 30 | * [2024.12.30] Add tqdm progress bars for generating summary videos. 31 | * (Planned) [2025.01.??] Add detailed explanations and comments for the code. 32 | 33 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 34 | 35 | # Requirements 36 | |Ubuntu|GPU|CUDA|cuDNN|conda|python| 37 | |:---:|:---:|:---:|:---:|:---:|:---:| 38 | |20.04.6 LTS|NVIDIA GeForce RTX 4090|12.1|8902|4.9.2|3.8.5| 39 | 40 | |h5py|numpy|scipy|torch|torchvision|tqdm| 41 | |:---:|:---:|:---:|:---:|:---:|:---:| 42 | |3.1.0|1.19.5|1.5.2|2.2.1|0.17.1|4.61.0| 43 | 44 | ``` 45 | conda create -n CSTA python=3.8.5 46 | conda activate CSTA 47 | git clone https://github.com/thswodnjs3/CSTA.git 48 | cd CSTA 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 53 | 54 | # Data 55 | ~~Link: [Dataset](https://drive.google.com/drive/folders/1iGfKZxexQfOxyIaOWhfU0P687dJq_KWF?usp=drive_link)
~~ 56 | - I'm very sorry, but the dataset link is no longer available. You can download the dataset [here (PGL-SUM)](https://github.com/e-apostolidis/PGL-SUM).

57 | Preprocessed versions of two benchmark video summarization datasets (SumMe and TVSum), stored in HDF5 (h5py) format.
58 | Download the datasets and put them in the ```data/``` directory.
59 | The directory structure must look like this:
60 | ``` 61 | ├── data 62 | └── eccv16_dataset_summe_google_pool5.h5 63 | └── eccv16_dataset_tvsum_google_pool5.h5 64 | ``` 65 | You can see the details of both datasets below.
66 | 67 | [SumMe](https://link.springer.com/chapter/10.1007/978-3-319-10584-0_33)
68 | [TVSum](https://openaccess.thecvf.com/content_cvpr_2015/papers/Song_TVSum_Summarizing_Web_2015_CVPR_paper.pdf)
69 |
70 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 71 | 72 | # Pre-trained models 73 | ~~Link: [Weights](https://drive.google.com/drive/folders/1Z0WV_IJAHXV16sAGW7TmC9J_iFZQ9NSs?usp=drive_link)~~
74 | - I'm very sorry, but the pre-trained weights are no longer available.
75 | I accidentally deleted them, but they can be reproduced by training with the same seed (123456) given below.
76 | If you cannot obtain similar performance with the same seed, please contact me.

77 | 78 | Our pre-trained CSTA weights were previously available for download at the link above.
79 | There are 5 weights for the SumMe dataset and another 5 for the TVSum dataset (1 weight for each split).
80 | As described in the paper, we ran every experiment 10 times (without fixing the seed), but uploaded only a single representative model for your convenience.
81 | The uploaded weights were obtained with seed 123456, and their results are almost identical to those reported in our paper.
82 | Put the 5 SumMe weights in ```weights/SumMe``` and the 5 TVSum weights in ```weights/TVSum```.
83 | The directory structure must look like this:
84 | ``` 85 | ├── weights 86 | └── SumMe 87 | ├── split1.pt 88 | ├── split2.pt 89 | ├── split3.pt 90 | ├── split4.pt 91 | ├── split5.pt 92 | └── TVSum 93 | ├── split1.pt 94 | ├── split2.pt 95 | ├── split3.pt 96 | ├── split4.pt 97 | ├── split5.pt 98 | ``` 99 | 100 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 101 | 102 | # Training 103 | You can train the final version of our models with the command below.
104 | ``` 105 | python train.py 106 | ``` 107 | Detailed explanations of all configuration options will be added later.
108 | 109 | ## You may not be able to reproduce our results exactly. 110 | As described in the paper, we ran every experiment 10 times without fixing the seed, so we cannot be sure which seeds reproduce the same results.
111 | Even if you set the seed to 123456, the same seed used for our pre-trained models, you may get different results due to the non-deterministic behavior of the [Adaptive Average Pooling layer](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms).
112 | To my knowledge, non-deterministic operations can produce different results even with the same seed. [You can see details here.](https://pytorch.org/docs/stable/notes/randomness.html)
113 | However, setting the seed to 123456 should give results similar to those of the pre-trained models, so I hope this is helpful.
114 |
115 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 116 | 117 | # Inference 118 | You can evaluate the final performance of the models with the command below.
119 | ``` 120 | python inference.py 121 | ``` 122 | All weight files must be placed in the locations described above.
123 |
124 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 125 | 126 | # Generate summary videos 127 | You can generate summary videos using our models.
128 | You can use either videos from public datasets or custom videos.
129 | With the code below, you can apply our pre-trained models to raw videos to produce summary videos.
130 | ``` 131 | python generate_video.py --input_is_file True or False 132 | --file_path 'path to input video' 133 | --dir_path 'directory of input videos' 134 | --ext 'video file extension' 135 | --save_path 'path to save summary video' 136 | --weight_path 'path to loaded weights' 137 | 138 | e.g. 139 | 1)Using a directory 140 | python generate_video.py --input_is_file False --dir_path './videos' --ext 'mp4' --save_path './summary_videos' --weight_path './weights/SumMe/split4.pt' 141 | 142 | 2)Using a single video file 143 | python generate_video.py --input_is_file True --file_path './videos/Jumps.mp4' --save_path './summary_videos' --weight_path './weights/SumMe/split4.pt' 144 | ``` 145 | The explanation of the arguments is as follows.
146 | If you change the 'ext' argument and input a directory of videos, you must modify the ['fourcc'](https://github.com/thswodnjs3/CSTA/blob/7227ee36a460b0bdc4aa83cb446223779365df45/generate_video.py#L34) variable in the 'produce_video' function within the 'generate_video.py' file.
147 | Additionally, you must update it when the input is a single video file with an extension other than 'mp4'. 148 | ``` 149 | 1. input_is_file (bool): True or False 150 | Indicates whether the input is a file or a directory. 151 | If this is True, the 'file_path' argument is required. 152 | If this is False, the 'dir_path' and 'ext' arguments are required. 153 | 154 | 2. file_path (str) e.g. './SumMe/Jumps.mp4' 155 | The path of the video file. 156 | This is only used when 'input_is_file' is True. 157 | 158 | 3. dir_path (str) e.g. './SumMe' 159 | The path of the directory where video files are located. 160 | This is only used when 'input_is_file' is False. 161 | 162 | 4. ext (str) e.g. 'mp4' 163 | The file extension of the video files. 164 | This is only used when 'input_is_file' is False. 165 | 166 | 5. sample_rate (int) e.g. 15 167 | The interval between selected frames in a video. 168 | For example, if the video has 30 fps, it will become 2 fps with a sample_rate of 15. 169 | 170 | 6. save_path (str) e.g. './summary_videos' 171 | The path where the summary videos are saved. 172 | 173 | 7. weight_path (str) e.g. './weights/SumMe/split4.pt' 174 | The path where the model weights are loaded from. 175 | ``` 176 | We referenced the KTS code from [DSNet](https://github.com/li-plus/DSNet).
177 | However, they applied KTS to downsampled videos (2 fps), which can shift the detected shot change points and sometimes makes it impossible to summarize a video.
178 | We revised it to calculate change points using all frames of the original video instead.
179 |
180 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 181 | 182 | # Citation 183 | If you find our code or our paper useful, please [★star] this repo and [cite] the following paper: 184 | ``` 185 | @inproceedings{son2024csta, 186 | title={CSTA: CNN-based Spatiotemporal Attention for Video Summarization}, 187 | author={Son, Jaewon and Park, Jaehun and Kim, Kwangsu}, 188 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 189 | pages={18847--18856}, 190 | year={2024} 191 | } 192 | ``` 193 | 194 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 195 | 196 | # Acknowledgement 197 | We especially and sincerely thank the authors of PosENet and RR-STG, who responded to our requests very kindly.
198 | Below are the papers we referenced for the code.
199 | 200 | A2Summ - [paper](https://arxiv.org/pdf/2303.07284), [code](https://github.com/boheumd/A2Summ)
201 | CA-SUM - [paper](https://www.iti.gr/~bmezaris/publications/icmr2022_preprint.pdf), [code](https://github.com/e-apostolidis/CA-SUM)
202 | DSNet - [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9275314), [code](https://github.com/li-plus/DSNet)
203 | iPTNet - [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Jiang_Joint_Video_Summarization_and_Moment_Localization_by_Cross-Task_Sample_Transfer_CVPR_2022_paper.pdf)
204 | MSVA - [paper](https://arxiv.org/pdf/2104.11530), [code](https://github.com/TIBHannover/MSVA)
205 | PGL-SUM - [paper](https://www.iti.gr/~bmezaris/publications/ism2021a_preprint.pdf), [code](https://github.com/e-apostolidis/PGL-SUM)
206 | PosENet - [paper](https://arxiv.org/pdf/2001.08248), [code](https://github.com/islamamirul/position_information)
207 | RR-STG - [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9750933&tag=1)
208 | SSPVS - [paper](https://arxiv.org/pdf/2201.02494), [code](https://github.com/HopLee6/SSPVS-PyTorch)
209 | STVT - [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10124837), [code](https://github.com/nchucvml/STVT)
210 | VASNet - [paper](https://arxiv.org/pdf/1812.01969), [code](https://github.com/ok1zjf/VASNet)
211 | VJMHT - [paper](https://arxiv.org/pdf/2112.13478), [code](https://github.com/HopLee6/VJMHT-PyTorch)
212 | 213 | ``` 214 | @inproceedings{he2023a2summ, 215 | title = {Align and Attend: Multimodal Summarization with Dual Contrastive Losses}, 216 | author={He, Bo and Wang, Jun and Qiu, Jielin and Bui, Trung and Shrivastava, Abhinav and Wang, Zhaowen}, 217 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 218 | year = {2023} 219 | } 220 | ``` 221 | ``` 222 | @inproceedings{10.1145/3512527.3531404, 223 | author = {Apostolidis, Evlampios and Balaouras, Georgios and Mezaris, Vasileios and Patras, Ioannis}, 224 | title = {Summarizing Videos Using Concentrated Attention and Considering the Uniqueness and Diversity of the Video Frames}, 225 | year = {2022}, 226 | isbn = {9781450392389}, 227 | publisher = {Association for Computing Machinery}, 228 | address = {New York, NY, USA}, 229 | url = {https://doi.org/10.1145/3512527.3531404}, 230 | doi = {10.1145/3512527.3531404}, 231 | pages = {407-415}, 232 | numpages = {9}, 233 | keywords = {frame diversity, frame uniqueness, concentrated attention, unsupervised learning, video summarization}, 234 | location = {Newark, NJ, USA}, 235 | series = {ICMR '22} 236 | } 237 | ``` 238 | ``` 239 | @article{zhu2020dsnet, 240 | title={DSNet: A Flexible Detect-to-Summarize Network for Video Summarization}, 241 | author={Zhu, Wencheng and Lu, Jiwen and Li, Jiahao and Zhou, Jie}, 242 | journal={IEEE Transactions on Image Processing}, 243 | volume={30}, 244 | pages={948--962}, 245 | year={2020} 246 | } 247 | ``` 248 | ``` 249 | @inproceedings{jiang2022joint, 250 | title={Joint video summarization and moment localization by cross-task sample transfer}, 251 | author={Jiang, Hao and Mu, Yadong}, 252 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 253 | pages={16388--16398}, 254 | year={2022} 255 | } 256 | ``` 257 | ``` 258 | @article{ghauri2021MSVA, 259 | title={SUPERVISED VIDEO SUMMARIZATION VIA MULTIPLE FEATURE SETS WITH PARALLEL ATTENTION}, 
260 | author={Ghauri, Junaid Ahmed and Hakimov, Sherzod and Ewerth, Ralph}, 261 | Conference={IEEE International Conference on Multimedia and Expo (ICME)}, 262 | year={2021} 263 | } 264 | ``` 265 | ``` 266 | @INPROCEEDINGS{9666088, 267 | author = {Apostolidis, Evlampios and Balaouras, Georgios and Mezaris, Vasileios and Patras, Ioannis}, 268 | title = {Combining Global and Local Attention with Positional Encoding for Video Summarization}, 269 | booktitle = {2021 IEEE International Symposium on Multimedia (ISM)}, 270 | month = {December}, 271 | year = {2021}, 272 | pages = {226-234} 273 | } 274 | ``` 275 | ``` 276 | @InProceedings{islam2020position, 277 | title={How much Position Information Do Convolutional Neural Networks Encode?}, 278 | author={Islam, Md Amirul and Jia, Sen and Bruce, Neil}, 279 | booktitle={International Conference on Learning Representations}, 280 | year={2020} 281 | } 282 | ``` 283 | ``` 284 | @article{zhu2022relational, 285 | title={Relational reasoning over spatial-temporal graphs for video summarization}, 286 | author={Zhu, Wencheng and Han, Yucheng and Lu, Jiwen and Zhou, Jie}, 287 | journal={IEEE Transactions on Image Processing}, 288 | volume={31}, 289 | pages={3017--3031}, 290 | year={2022}, 291 | publisher={IEEE} 292 | } 293 | ``` 294 | ``` 295 | @inproceedings{li2023progressive, 296 | title={Progressive Video Summarization via Multimodal Self-supervised Learning}, 297 | author={Li, Haopeng and Ke, Qiuhong and Gong, Mingming and Drummond, Tom}, 298 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 299 | pages={5584--5593}, 300 | year={2023} 301 | } 302 | ``` 303 | ``` 304 | @article{hsu2023video, 305 | title={Video summarization with spatiotemporal vision transformer}, 306 | author={Hsu, Tzu-Chun and Liao, Yi-Sheng and Huang, Chun-Rong}, 307 | journal={IEEE Transactions on Image Processing}, 308 | year={2023}, 309 | publisher={IEEE} 310 | } 311 | ``` 312 | ``` 313 | 
@misc{fajtl2018summarizing, 314 | title={Summarizing Videos with Attention}, 315 | author={Jiri Fajtl and Hajar Sadeghi Sokeh and Vasileios Argyriou and Dorothy Monekosso and Paolo Remagnino}, 316 | year={2018}, 317 | eprint={1812.01969}, 318 | archivePrefix={arXiv}, 319 | primaryClass={cs.CV} 320 | } 321 | ``` 322 | ``` 323 | @article{li2022video, 324 | title={Video Joint Modelling Based on Hierarchical Transformer for Co-summarization}, 325 | author={Li, Haopeng and Ke, Qiuhong and Gong, Mingming and Zhang, Rui}, 326 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 327 | year={2022}, 328 | publisher={IEEE} 329 | } 330 | ``` 331 | 332 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑ 333 | -------------------------------------------------------------------------------- /models/GoogleNet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch import Tensor 7 | from typing import Any, Callable, List, Optional, Tuple 8 | 9 | from models.positional_encoding import FixedPositionalEncoding,LearnablePositionalEncoding,RelativePositionalEncoding,ConditionalPositionalEncoding 10 | 11 | # Edit GoogleNet by replacing last parts with adaptive average pooling layers 12 | class GoogleNet_Att(nn.Module): 13 | __constants__ = ["aux_logits", "transform_input"] 14 | 15 | def __init__( 16 | self, 17 | num_classes: int = 1000, 18 | init_weights: Optional[bool] = None 19 | ) -> None: 20 | super().__init__() 21 | conv_block = BasicConv2d 22 | inception_block = Inception 23 | 24 | self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3) 25 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 26 | self.conv2 = conv_block(64, 64, kernel_size=1) 27 | self.conv3 = conv_block(64, 192, kernel_size=3, 
padding=1) 28 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 29 | 30 | self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32) 31 | self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64) 32 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 33 | 34 | self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64) 35 | self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64) 36 | self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64) 37 | self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64) 38 | self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128) 39 | self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 40 | 41 | self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128) 42 | self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128) 43 | 44 | self.fc = nn.Linear(1024, num_classes) 45 | 46 | if init_weights: 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 49 | torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2) 50 | elif isinstance(m, nn.BatchNorm2d): 51 | nn.init.constant_(m.weight, 1) 52 | nn.init.constant_(m.bias, 0) 53 | 54 | def _forward(self, x: Tensor, n_frame) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 55 | x = self.conv1(x) 56 | x = self.maxpool1(x) 57 | x = self.conv2(x) 58 | x = self.conv3(x) 59 | x = self.maxpool2(x) 60 | 61 | x = self.inception3a(x) 62 | x = self.inception3b(x) 63 | x = self.maxpool3(x) 64 | x = self.inception4a(x) 65 | 66 | x = self.inception4b(x) 67 | x = self.inception4c(x) 68 | x = self.inception4d(x) 69 | 70 | x = self.inception4e(x) 71 | x = self.maxpool4(x) 72 | x = self.inception5a(x) 73 | x = self.inception5b(x) 74 | 75 | ############################################################################## 76 | # The place I edit to resize feature maps, and to handle various lengths of input videos 77 | 
############################################################################## 78 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1)) 79 | x = self.avgpool(x) 80 | x = torch.squeeze(x) 81 | x = x.permute(1,0) 82 | return x 83 | 84 | def forward(self, x: Tensor): 85 | ############################################################################## 86 | # Takes the number of frames to handle various lengths of input videos 87 | ############################################################################## 88 | n_frame = x.shape[2] 89 | x = self._forward(x,n_frame) 90 | return x 91 | 92 | class Inception(nn.Module): 93 | def __init__( 94 | self, 95 | in_channels: int, 96 | ch1x1: int, 97 | ch3x3red: int, 98 | ch3x3: int, 99 | ch5x5red: int, 100 | ch5x5: int, 101 | pool_proj: int, 102 | conv_block: Optional[Callable[..., nn.Module]] = None, 103 | ) -> None: 104 | super().__init__() 105 | if conv_block is None: 106 | conv_block = BasicConv2d 107 | self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1) 108 | 109 | self.branch2 = nn.Sequential( 110 | conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1) 111 | ) 112 | 113 | self.branch3 = nn.Sequential( 114 | conv_block(in_channels, ch5x5red, kernel_size=1), 115 | conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1), 116 | ) 117 | 118 | self.branch4 = nn.Sequential( 119 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 120 | conv_block(in_channels, pool_proj, kernel_size=1), 121 | ) 122 | 123 | def _forward(self, x: Tensor) -> List[Tensor]: 124 | branch1 = self.branch1(x) 125 | branch2 = self.branch2(x) 126 | branch3 = self.branch3(x) 127 | branch4 = self.branch4(x) 128 | 129 | outputs = [branch1, branch2, branch3, branch4] 130 | return outputs 131 | 132 | def forward(self, x: Tensor) -> Tensor: 133 | outputs = self._forward(x) 134 | return torch.cat(outputs, 1) 135 | 136 | class BasicConv2d(nn.Module): 137 | def __init__(self, in_channels: int, 
out_channels: int, **kwargs: Any) -> None: 138 | super().__init__() 139 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 140 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 141 | 142 | def forward(self, x: Tensor) -> Tensor: 143 | x = self.conv(x) 144 | x = self.bn(x) 145 | return F.relu(x, inplace=True) 146 | 147 | ############################################################################## 148 | # Define our proposed model 149 | ############################################################################## 150 | class CSTA_GoogleNet(nn.Module): 151 | def __init__(self, 152 | model_name, 153 | Scale, 154 | Softmax_axis, 155 | Balance, 156 | Positional_encoding, 157 | Positional_encoding_shape, 158 | Positional_encoding_way, 159 | Dropout_on, 160 | Dropout_ratio, 161 | Classifier_on, 162 | CLS_on, 163 | CLS_mix, 164 | key_value_emb, 165 | Skip_connection, 166 | Layernorm, 167 | dim=1024): 168 | super().__init__() 169 | self.googlenet = GoogleNet_Att() 170 | 171 | self.model_name = model_name 172 | self.Scale = Scale 173 | self.Softmax_axis = Softmax_axis 174 | self.Balance = Balance 175 | 176 | self.Positional_encoding = Positional_encoding 177 | self.Positional_encoding_shape = Positional_encoding_shape 178 | self.Positional_encoding_way = Positional_encoding_way 179 | self.Dropout_on = Dropout_on 180 | self.Dropout_ratio = Dropout_ratio 181 | 182 | self.Classifier_on = Classifier_on 183 | self.CLS_on = CLS_on 184 | self.CLS_mix = CLS_mix 185 | 186 | self.key_value_emb = key_value_emb 187 | self.Skip_connection = Skip_connection 188 | self.Layernorm = Layernorm 189 | 190 | self.dim = dim 191 | 192 | if self.Positional_encoding is not None: 193 | if self.Positional_encoding=='FPE': 194 | self.Positional_encoding_op = FixedPositionalEncoding( 195 | Positional_encoding_shape=self.Positional_encoding_shape, 196 | dim=self.dim 197 | ) 198 | elif self.Positional_encoding=='RPE': 199 | self.Positional_encoding_op = 
RelativePositionalEncoding( 200 | Positional_encoding_shape=self.Positional_encoding_shape, 201 | dim=self.dim 202 | ) 203 | elif self.Positional_encoding=='LPE': 204 | self.Positional_encoding_op = LearnablePositionalEncoding( 205 | Positional_encoding_shape=self.Positional_encoding_shape, 206 | dim=self.dim 207 | ) 208 | elif self.Positional_encoding=='CPE': 209 | self.Positional_encoding_op = ConditionalPositionalEncoding( 210 | Positional_encoding_shape=self.Positional_encoding_shape, 211 | Positional_encoding_way=self.Positional_encoding_way, 212 | dim=self.dim 213 | ) 214 | elif self.Positional_encoding is None: 215 | pass 216 | else: 217 | raise 218 | 219 | if self.Positional_encoding_way=='Transformer': 220 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim) 221 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 222 | pass 223 | else: 224 | raise 225 | 226 | if self.Dropout_on: 227 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio)) 228 | 229 | if self.Classifier_on: 230 | self.linear1 = nn.Sequential( 231 | nn.Linear(in_features=self.dim, out_features=self.dim), 232 | nn.ReLU(), 233 | nn.Dropout(p=0.5), 234 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6) 235 | ) 236 | self.linear2 = nn.Sequential( 237 | nn.Linear(in_features=self.dim, out_features=1), 238 | nn.Sigmoid() 239 | ) 240 | 241 | for name,param in self.named_parameters(): 242 | if name in ['linear1.0.weight','linear2.0.weight']: 243 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0)) 244 | elif name in ['linear1.0.bias','linear2.0.bias']: 245 | nn.init.constant_(param, 0.1) 246 | else: 247 | self.gap = nn.AdaptiveAvgPool1d(1) 248 | 249 | if self.CLS_on: 250 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024)) 251 | 252 | if self.key_value_emb is not None: 253 | if self.key_value_emb.lower()=='k': 254 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 255 | elif 
self.key_value_emb.lower()=='v': 256 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim) 257 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv': 258 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 259 | if self.model_name=='GoogleNet_Attention': 260 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim) 261 | else: 262 | raise 263 | 264 | if self.Layernorm: 265 | if self.Skip_connection=='KC': 266 | self.layernorm1 = nn.BatchNorm2d(num_features=1) 267 | elif self.Skip_connection=='CF': 268 | self.layernorm2 = nn.BatchNorm2d(num_features=1) 269 | elif self.Skip_connection=='IF': 270 | self.layernorm3 = nn.BatchNorm2d(num_features=1) 271 | elif self.Skip_connection is None: 272 | pass 273 | else: 274 | raise 275 | 276 | def forward(self, x): 277 | # Take the number of frames 278 | n_frame = x.shape[2] 279 | 280 | # Linear projection if using CLS token as transformer ways 281 | if self.Positional_encoding_way=='Transformer': 282 | x = self.Positional_encoding_embedding(x) 283 | # Stack CLS token 284 | if self.CLS_on: 285 | x = torch.cat((self.CLS,x),dim=2) 286 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim)) 287 | 288 | # Positional encoding (Transformer ways) 289 | if self.Positional_encoding_way=='Transformer': 290 | if self.Positional_encoding is not None: 291 | x = self.Positional_encoding_op(x) 292 | # Dropout (Transformer ways) 293 | if self.Dropout_on: 294 | x = self.dropout(x) 295 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 296 | pass 297 | else: 298 | raise 299 | 300 | # Key Embedding 301 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']: 302 | key = self.key_embedding(x) 303 | elif self.key_value_emb is None: 304 | key = x 305 | else: 306 | raise 307 | 308 | # CNN as attention algorithm 309 | x_att = self.googlenet(key) 310 | 311 | # Skip connection (KC) 312 | if self.Skip_connection is not None: 313 | 
if self.Skip_connection=='KC': 314 | x_att = x_att + key.squeeze(0)[0] 315 | if self.Layernorm: 316 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 317 | elif self.Skip_connection in ['CF','IF']: 318 | pass 319 | else: 320 | raise 321 | elif self.Skip_connection is None: 322 | pass 323 | else: 324 | raise 325 | 326 | # Combine CLS token (CNN) 327 | if self.CLS_on: 328 | if self.CLS_mix=='CNN': 329 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0) 330 | x = CT_adjust(x.squeeze(0)).unsqueeze(0) 331 | elif self.CLS_mix in ['SM','Final']: 332 | pass 333 | else: 334 | raise 335 | else: 336 | pass 337 | 338 | # Scaling factor 339 | if self.Scale is not None: 340 | if self.Scale=='D': 341 | scaling_factor = x_att.shape[1] 342 | elif self.Scale=='T': 343 | scaling_factor = x_att.shape[0] 344 | elif self.Scale=='T_D': 345 | scaling_factor = x_att.shape[0] * x_att.shape[1] 346 | else: 347 | raise 348 | scaling_factor = scaling_factor ** 0.5 349 | x_att = x_att / scaling_factor 350 | elif self.Scale is None: 351 | pass 352 | 353 | # Positional encoding (PGL-SUM ways) 354 | if self.Positional_encoding_way=='PGL_SUM': 355 | if self.Positional_encoding is not None: 356 | x_att = self.Positional_encoding_op(x_att) 357 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None: 358 | pass 359 | else: 360 | raise 361 | 362 | # softmax_axis 363 | x = x.squeeze(0)[0] 364 | if self.Softmax_axis=='T': 365 | temporal_attention = F.softmax(x_att,dim=0) 366 | elif self.Softmax_axis=='D': 367 | spatial_attention = F.softmax(x_att,dim=1) 368 | elif self.Softmax_axis=='TD': 369 | temporal_attention = F.softmax(x_att,dim=0) 370 | spatial_attention = F.softmax(x_att,dim=1) 371 | elif self.Softmax_axis is None: 372 | pass 373 | else: 374 | raise 375 | 376 | # Combine CLS token for softmax outputs (SM) 377 | if self.CLS_on: 378 | if self.CLS_mix=='SM': 379 | if self.Softmax_axis=='T': 380 | temporal_attention = 
CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 381 | elif self.Softmax_axis=='D': 382 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 383 | elif self.Softmax_axis=='TD': 384 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 385 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 386 | elif self.Softmax_axis is None: 387 | pass 388 | else: 389 | raise 390 | elif self.CLS_mix in ['CNN','Final']: 391 | pass 392 | else: 393 | raise 394 | else: 395 | pass 396 | 397 | # Dropout (PGL-SUM ways) 398 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM': 399 | if self.Softmax_axis=='T': 400 | temporal_attention = self.dropout(temporal_attention) 401 | elif self.Softmax_axis=='D': 402 | spatial_attention = self.dropout(spatial_attention) 403 | elif self.Softmax_axis=='TD': 404 | temporal_attention = self.dropout(temporal_attention) 405 | spatial_attention = self.dropout(spatial_attention) 406 | elif self.Softmax_axis is None: 407 | pass 408 | else: 409 | raise 410 | 411 | # Value Embedding 412 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']: 413 | if self.model_name=='GoogleNet_Attention': 414 | x_out = self.value_embedding(x) 415 | elif self.model_name=='GoogleNet': 416 | x_out = x_att 417 | else: 418 | raise 419 | elif self.key_value_emb is None: 420 | if self.model_name=='GoogleNet': 421 | x_out = x_att 422 | elif self.model_name=='GoogleNet_Attention': 423 | x_out = x 424 | else: 425 | raise 426 | else: 427 | raise 428 | 429 | # Combine CLS token for CNN outputs (SM) 430 | if self.CLS_on: 431 | if self.CLS_mix=='SM': 432 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0) 433 | 434 | # Apply Attention maps to input frame features 435 | if self.Softmax_axis=='T': 436 | x_out = x_out * temporal_attention 437 | elif self.Softmax_axis=='D': 438 | x_out = x_out * spatial_attention 439 | elif self.Softmax_axis=='TD': 440 | T,D = x_out.shape 441 | 
adjust_frame = T/D 442 | adjust_dimension = D/T 443 | if self.Balance=='T': 444 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 445 | elif self.Balance=='D': 446 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 447 | elif self.Balance=='BD': 448 | if T>D: 449 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 450 | elif TD: 456 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 457 | elif T None: 22 | super().__init__() 23 | self.stride = stride 24 | if stride not in [1, 2]: 25 | raise ValueError(f"stride should be 1 or 2 instead of {stride}") 26 | 27 | if norm_layer is None: 28 | norm_layer = nn.BatchNorm2d 29 | 30 | hidden_dim = int(round(inp * expand_ratio)) 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | 33 | layers: List[nn.Module] = [] 34 | if expand_ratio != 1: 35 | # pw 36 | layers.append( 37 | Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) 38 | ) 39 | layers.extend( 40 | [ 41 | # dw 42 | Conv2dNormActivation( 43 | hidden_dim, 44 | hidden_dim, 45 | stride=stride, 46 | groups=hidden_dim, 47 | norm_layer=norm_layer, 48 | activation_layer=nn.ReLU6, 49 | ), 50 | # pw-linear 51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 52 | norm_layer(oup), 53 | ] 54 | ) 55 | self.conv = nn.Sequential(*layers) 56 | self.out_channels = oup 57 | self._is_cn = stride > 1 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | if self.use_res_connect: 61 | return x + self.conv(x) 62 | else: 63 | return self.conv(x) 64 | 65 | # MobileNet as attention 66 | class MobileNet_Att(nn.Module): 67 | def __init__( 68 | self, 69 | num_classes: int = 1000, 70 | width_mult: float = 1.0, 71 | inverted_residual_setting: Optional[List[List[int]]] = None, 72 | round_nearest: int = 8, 73 | block: Optional[Callable[..., nn.Module]] = None, 74 | norm_layer: Optional[Callable[..., nn.Module]] = 
None, 75 | dropout: float = 0.2, 76 | ) -> None: 77 | """ 78 | MobileNet V2 main class 79 | 80 | Args: 81 | num_classes (int): Number of classes 82 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 83 | inverted_residual_setting: Network structure 84 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 85 | Set to 1 to turn off rounding 86 | block: Module specifying inverted residual building block for mobilenet 87 | norm_layer: Module specifying the normalization layer to use 88 | dropout (float): The droupout probability 89 | 90 | """ 91 | super().__init__() 92 | _log_api_usage_once(self) 93 | 94 | if block is None: 95 | block = InvertedResidual 96 | 97 | if norm_layer is None: 98 | norm_layer = nn.BatchNorm2d 99 | 100 | input_channel = 32 101 | last_channel = 1280 102 | 103 | if inverted_residual_setting is None: 104 | inverted_residual_setting = [ 105 | # t, c, n, s 106 | [1, 16, 1, 1], 107 | [6, 24, 2, 2], 108 | [6, 32, 3, 2], 109 | [6, 64, 4, 2], 110 | [6, 96, 3, 1], 111 | [6, 160, 3, 2], 112 | [6, 320, 1, 1], 113 | ] 114 | 115 | # only check the first element, assuming user knows t,c,n,s are required 116 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 117 | raise ValueError( 118 | f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" 119 | ) 120 | 121 | # building first layer 122 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 123 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 124 | features: List[nn.Module] = [ 125 | Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) 126 | ] 127 | # building inverted residual blocks 128 | for t, c, n, s in inverted_residual_setting: 129 | output_channel = _make_divisible(c * width_mult, round_nearest) 130 | for i in range(n): 131 | 
stride = s if i == 0 else 1 132 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 133 | input_channel = output_channel 134 | # building last several layers 135 | features.append( 136 | Conv2dNormActivation( 137 | input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 138 | ) 139 | ) 140 | # make it nn.Sequential 141 | self.features = nn.Sequential(*features) 142 | 143 | # building classifier 144 | self.classifier = nn.Sequential( 145 | nn.Dropout(p=dropout), 146 | nn.Linear(self.last_channel, num_classes), 147 | ) 148 | 149 | # weight initialization 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 153 | if m.bias is not None: 154 | nn.init.zeros_(m.bias) 155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 156 | nn.init.ones_(m.weight) 157 | nn.init.zeros_(m.bias) 158 | elif isinstance(m, nn.Linear): 159 | nn.init.normal_(m.weight, 0, 0.01) 160 | nn.init.zeros_(m.bias) 161 | 162 | def _forward_impl(self, x: Tensor, n_frame) -> Tensor: 163 | x = self.features(x) 164 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1)) 165 | x = self.avgpool(x) 166 | x = torch.squeeze(x) 167 | x = x.permute(1,0) 168 | return x 169 | 170 | def forward(self, x: Tensor) -> Tensor: 171 | n_frame = x.shape[2] 172 | return self._forward_impl(x, n_frame) 173 | 174 | _COMMON_META = { 175 | "num_params": 3504872, 176 | "min_size": (1, 1), 177 | "categories": _IMAGENET_CATEGORIES, 178 | } 179 | 180 | class MobileNet_V2_Weights(WeightsEnum): 181 | IMAGENET1K_V1 = Weights( 182 | url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", 183 | transforms=partial(ImageClassification, crop_size=224), 184 | meta={ 185 | **_COMMON_META, 186 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", 187 | "_metrics": { 188 | "ImageNet-1K": { 189 | "acc@1": 71.878, 190 | "acc@5": 
90.286, 191 | } 192 | }, 193 | "_ops": 0.301, 194 | "_file_size": 13.555, 195 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", 196 | }, 197 | ) 198 | IMAGENET1K_V2 = Weights( 199 | url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", 200 | transforms=partial(ImageClassification, crop_size=224, resize_size=232), 201 | meta={ 202 | **_COMMON_META, 203 | "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", 204 | "_metrics": { 205 | "ImageNet-1K": { 206 | "acc@1": 72.154, 207 | "acc@5": 90.822, 208 | } 209 | }, 210 | "_ops": 0.301, 211 | "_file_size": 13.598, 212 | "_docs": """ 213 | These weights improve upon the results of the original paper by using a modified version of TorchVision's 214 | `new training recipe 215 | `_. 216 | """, 217 | }, 218 | ) 219 | DEFAULT = IMAGENET1K_V2 220 | 221 | @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) 222 | def mobilenet_v2( 223 | *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any 224 | ) -> MobileNet_Att: 225 | """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear 226 | Bottlenecks `_ paper. 227 | 228 | Args: 229 | weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The 230 | pretrained weights to use. See 231 | :class:`~torchvision.models.MobileNet_V2_Weights` below for 232 | more details, and possible values. By default, no pre-trained 233 | weights are used. 234 | progress (bool, optional): If True, displays a progress bar of the 235 | download to stderr. Default is True. 236 | **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2`` 237 | base class. Please refer to the `source code 238 | `_ 239 | for more details about this class. 240 | 241 | .. 
autoclass:: torchvision.models.MobileNet_V2_Weights 242 | :members: 243 | """ 244 | weights = MobileNet_V2_Weights.verify(weights) 245 | 246 | if weights is not None: 247 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 248 | 249 | model = MobileNet_Att(**kwargs) 250 | 251 | if weights is not None: 252 | model.load_state_dict(weights.get_state_dict(progress=progress)) 253 | 254 | return model 255 | 256 | # MobileNet-based CSTA 257 | class CSTA_MobileNet(nn.Module): 258 | def __init__(self, 259 | model_name, 260 | Scale, 261 | Softmax_axis, 262 | Balance, 263 | Positional_encoding, 264 | Positional_encoding_shape, 265 | Positional_encoding_way, 266 | Dropout_on, 267 | Dropout_ratio, 268 | Classifier_on, 269 | CLS_on, 270 | CLS_mix, 271 | key_value_emb, 272 | Skip_connection, 273 | Layernorm, 274 | dim=1280): 275 | super().__init__() 276 | self.mobilenet = MobileNet_Att() 277 | 278 | self.model_name = model_name 279 | self.Scale = Scale 280 | self.Softmax_axis = Softmax_axis 281 | self.Balance = Balance 282 | 283 | self.Positional_encoding = Positional_encoding 284 | self.Positional_encoding_shape = Positional_encoding_shape 285 | self.Positional_encoding_way = Positional_encoding_way 286 | self.Dropout_on = Dropout_on 287 | self.Dropout_ratio = Dropout_ratio 288 | 289 | self.Classifier_on = Classifier_on 290 | self.CLS_on = CLS_on 291 | self.CLS_mix = CLS_mix 292 | 293 | self.key_value_emb = key_value_emb 294 | self.Skip_connection = Skip_connection 295 | self.Layernorm = Layernorm 296 | 297 | self.dim = dim 298 | 299 | if self.Positional_encoding is not None: 300 | if self.Positional_encoding=='FPE': 301 | self.Positional_encoding_op = FixedPositionalEncoding( 302 | Positional_encoding_shape=self.Positional_encoding_shape, 303 | dim=self.dim 304 | ) 305 | elif self.Positional_encoding=='RPE': 306 | self.Positional_encoding_op = RelativePositionalEncoding( 307 | Positional_encoding_shape=self.Positional_encoding_shape, 308 | 
dim=self.dim 309 | ) 310 | elif self.Positional_encoding=='LPE': 311 | self.Positional_encoding_op = LearnablePositionalEncoding( 312 | Positional_encoding_shape=self.Positional_encoding_shape, 313 | dim=self.dim 314 | ) 315 | elif self.Positional_encoding=='CPE': 316 | self.Positional_encoding_op = ConditionalPositionalEncoding( 317 | Positional_encoding_shape=self.Positional_encoding_shape, 318 | Positional_encoding_way=self.Positional_encoding_way, 319 | dim=self.dim 320 | ) 321 | elif self.Positional_encoding is None: 322 | pass 323 | else: 324 | raise 325 | 326 | if self.Positional_encoding_way=='Transformer': 327 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim) 328 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 329 | pass 330 | else: 331 | raise 332 | 333 | if self.Dropout_on: 334 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio)) 335 | 336 | if self.Classifier_on: 337 | self.linear1 = nn.Sequential( 338 | nn.Linear(in_features=self.dim, out_features=self.dim), 339 | nn.ReLU(), 340 | nn.Dropout(p=0.5), 341 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6) 342 | ) 343 | self.linear2 = nn.Sequential( 344 | nn.Linear(in_features=self.dim, out_features=1), 345 | nn.Sigmoid() 346 | ) 347 | 348 | for name,param in self.named_parameters(): 349 | if name in ['linear1.0.weight','linear2.0.weight']: 350 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0)) 351 | elif name in ['linear1.0.bias','linear2.0.bias']: 352 | nn.init.constant_(param, 0.1) 353 | else: 354 | self.gap = nn.AdaptiveAvgPool1d(1) 355 | 356 | if self.CLS_on: 357 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024)) 358 | 359 | if self.key_value_emb is not None: 360 | if self.key_value_emb.lower()=='k': 361 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 362 | elif self.key_value_emb.lower()=='v': 363 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim) 364 | elif 
''.join(sorted(self.key_value_emb.lower()))=='kv': 365 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 366 | if self.model_name=='MobileNet_Attention': 367 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim) 368 | else: 369 | raise 370 | 371 | if self.Layernorm: 372 | if self.Skip_connection=='KC': 373 | self.layernorm1 = nn.BatchNorm2d(num_features=1) 374 | elif self.Skip_connection=='CF': 375 | self.layernorm2 = nn.BatchNorm2d(num_features=1) 376 | elif self.Skip_connection=='IF': 377 | self.layernorm3 = nn.BatchNorm2d(num_features=1) 378 | elif self.Skip_connection is None: 379 | pass 380 | else: 381 | raise 382 | 383 | def forward(self, x): 384 | n_frame = x.shape[2] 385 | 386 | if self.Positional_encoding_way=='Transformer': 387 | x = self.Positional_encoding_embedding(x) 388 | if self.CLS_on: 389 | x = torch.cat((self.CLS,x),dim=2) 390 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim)) 391 | 392 | if self.Positional_encoding_way=='Transformer': 393 | if self.Positional_encoding is not None: 394 | x = self.Positional_encoding_op(x) 395 | if self.Dropout_on: 396 | x = self.dropout(x) 397 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 398 | pass 399 | else: 400 | raise 401 | 402 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']: 403 | key = self.key_embedding(x) 404 | elif self.key_value_emb is None: 405 | key = x 406 | else: 407 | raise 408 | 409 | x_att = self.mobilenet(key) 410 | 411 | if self.Skip_connection is not None: 412 | if self.Skip_connection=='KC': 413 | x_att = x_att + key.squeeze(0)[0] 414 | if self.Layernorm: 415 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 416 | elif self.Skip_connection in ['CF','IF']: 417 | pass 418 | else: 419 | raise 420 | elif self.Skip_connection is None: 421 | pass 422 | else: 423 | raise 424 | 425 | if self.CLS_on: 426 | if self.CLS_mix=='CNN': 427 | x_att = 
CT_adjust(x_att.unsqueeze(0)).squeeze(0) 428 | x = CT_adjust(x.squeeze(0)).unsqueeze(0) 429 | elif self.CLS_mix in ['SM','Final']: 430 | pass 431 | else: 432 | raise 433 | else: 434 | pass 435 | 436 | if self.Scale is not None: 437 | if self.Scale=='D': 438 | scaling_factor = x_att.shape[1] 439 | elif self.Scale=='T': 440 | scaling_factor = x_att.shape[0] 441 | elif self.Scale=='T_D': 442 | scaling_factor = x_att.shape[0] * x_att.shape[1] 443 | else: 444 | raise 445 | scaling_factor = scaling_factor ** 0.5 446 | x_att = x_att / scaling_factor 447 | elif self.Scale is None: 448 | pass 449 | 450 | if self.Positional_encoding_way=='PGL_SUM': 451 | if self.Positional_encoding is not None: 452 | x_att = self.Positional_encoding_op(x_att) 453 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None: 454 | pass 455 | else: 456 | raise 457 | 458 | x = x.squeeze(0)[0] 459 | if self.Softmax_axis=='T': 460 | temporal_attention = F.softmax(x_att,dim=0) 461 | elif self.Softmax_axis=='D': 462 | spatial_attention = F.softmax(x_att,dim=1) 463 | elif self.Softmax_axis=='TD': 464 | temporal_attention = F.softmax(x_att,dim=0) 465 | spatial_attention = F.softmax(x_att,dim=1) 466 | elif self.Softmax_axis is None: 467 | pass 468 | else: 469 | raise 470 | 471 | if self.CLS_on: 472 | if self.CLS_mix=='SM': 473 | if self.Softmax_axis=='T': 474 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 475 | elif self.Softmax_axis=='D': 476 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 477 | elif self.Softmax_axis=='TD': 478 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 479 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 480 | elif self.Softmax_axis is None: 481 | pass 482 | else: 483 | raise 484 | elif self.CLS_mix in ['CNN','Final']: 485 | pass 486 | else: 487 | raise 488 | else: 489 | pass 490 | 491 | if self.Dropout_on and 
self.Positional_encoding_way=='PGL_SUM': 492 | if self.Softmax_axis=='T': 493 | temporal_attention = self.dropout(temporal_attention) 494 | elif self.Softmax_axis=='D': 495 | spatial_attention = self.dropout(spatial_attention) 496 | elif self.Softmax_axis=='TD': 497 | temporal_attention = self.dropout(temporal_attention) 498 | spatial_attention = self.dropout(spatial_attention) 499 | elif self.Softmax_axis is None: 500 | pass 501 | else: 502 | raise 503 | 504 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']: 505 | if self.model_name=='MobileNet_Attention': 506 | x_out = self.value_embedding(x) 507 | elif self.model_name=='MobileNet': 508 | x_out = x_att 509 | else: 510 | raise 511 | elif self.key_value_emb is None: 512 | if self.model_name=='MobileNet': 513 | x_out = x_att 514 | elif self.model_name=='MobileNet_Attention': 515 | x_out = x 516 | else: 517 | raise 518 | else: 519 | raise 520 | 521 | if self.CLS_on: 522 | if self.CLS_mix=='SM': 523 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0) 524 | 525 | if self.Softmax_axis=='T': 526 | x_out = x_out * temporal_attention 527 | elif self.Softmax_axis=='D': 528 | x_out = x_out * spatial_attention 529 | elif self.Softmax_axis=='TD': 530 | T,D = x_out.shape 531 | adjust_frame = T/D 532 | adjust_dimension = D/T 533 | if self.Balance=='T': 534 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 535 | elif self.Balance=='D': 536 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 537 | elif self.Balance=='BD': 538 | if T>D: 539 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 540 | elif TD: 546 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 547 | elif T Callable[[Callable[..., M]], Callable[..., M]]: 23 | def wrapper(fn: Callable[..., M]) -> Callable[..., M]: 24 | key = name if name is not None else fn.__name__ 25 | if key in BUILTIN_MODELS: 26 | raise 
ValueError(f"An entry is already registered under the name '{key}'.") 27 | BUILTIN_MODELS[key] = fn 28 | return fn 29 | 30 | return wrapper 31 | 32 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d( 35 | in_planes, 36 | out_planes, 37 | kernel_size=3, 38 | stride=stride, 39 | padding=dilation, 40 | groups=groups, 41 | bias=False, 42 | dilation=dilation, 43 | ) 44 | 45 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 46 | """1x1 convolution""" 47 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 48 | 49 | class BasicBlock(nn.Module): 50 | expansion: int = 1 51 | 52 | def __init__( 53 | self, 54 | inplanes: int, 55 | planes: int, 56 | stride: int = 1, 57 | downsample: Optional[nn.Module] = None, 58 | groups: int = 1, 59 | base_width: int = 64, 60 | dilation: int = 1, 61 | norm_layer: Optional[Callable[..., nn.Module]] = None, 62 | ) -> None: 63 | super().__init__() 64 | if norm_layer is None: 65 | norm_layer = nn.BatchNorm2d 66 | if groups != 1 or base_width != 64: 67 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 68 | if dilation > 1: 69 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 70 | self.conv1 = conv3x3(inplanes, planes, stride) 71 | self.bn1 = norm_layer(planes) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.conv2 = conv3x3(planes, planes) 74 | self.bn2 = norm_layer(planes) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x: Tensor) -> Tensor: 79 | identity = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | 88 | if self.downsample is not None: 89 | identity = self.downsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | class Bottleneck(nn.Module): 
97 | expansion: int = 4 98 | 99 | def __init__( 100 | self, 101 | inplanes: int, 102 | planes: int, 103 | stride: int = 1, 104 | downsample: Optional[nn.Module] = None, 105 | groups: int = 1, 106 | base_width: int = 64, 107 | dilation: int = 1, 108 | norm_layer: Optional[Callable[..., nn.Module]] = None, 109 | ) -> None: 110 | super().__init__() 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | width = int(planes * (base_width / 64.0)) * groups 114 | self.conv1 = conv1x1(inplanes, width) 115 | self.bn1 = norm_layer(width) 116 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 117 | self.bn2 = norm_layer(width) 118 | self.conv3 = conv1x1(width, planes * self.expansion) 119 | self.bn3 = norm_layer(planes * self.expansion) 120 | self.relu = nn.ReLU(inplace=True) 121 | self.downsample = downsample 122 | self.stride = stride 123 | 124 | def forward(self, x: Tensor) -> Tensor: 125 | identity = x 126 | 127 | out = self.conv1(x) 128 | out = self.bn1(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | out = self.relu(out) 134 | 135 | out = self.conv3(out) 136 | out = self.bn3(out) 137 | 138 | if self.downsample is not None: 139 | identity = self.downsample(x) 140 | 141 | out += identity 142 | out = self.relu(out) 143 | 144 | return out 145 | 146 | # ResNet as attention 147 | class ResNet_Att(nn.Module): 148 | def __init__( 149 | self, 150 | block: Type[Union[BasicBlock, Bottleneck]] = BasicBlock, 151 | layers: List[int] = [2, 2, 2, 2], 152 | num_classes: int = 1000, 153 | zero_init_residual: bool = False, 154 | groups: int = 1, 155 | width_per_group: int = 64, 156 | replace_stride_with_dilation: Optional[List[bool]] = None, 157 | norm_layer: Optional[Callable[..., nn.Module]] = None, 158 | ) -> None: 159 | super().__init__() 160 | _log_api_usage_once(self) 161 | if norm_layer is None: 162 | norm_layer = nn.BatchNorm2d 163 | self._norm_layer = norm_layer 164 | 165 | self.inplanes = 64 166 | 
self.dilation = 1 167 | if replace_stride_with_dilation is None: 168 | replace_stride_with_dilation = [False, False, False] 169 | if len(replace_stride_with_dilation) != 3: 170 | raise ValueError( 171 | "replace_stride_with_dilation should be None " 172 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 173 | ) 174 | self.groups = groups 175 | self.base_width = width_per_group 176 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 177 | self.bn1 = norm_layer(self.inplanes) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 180 | self.layer1 = self._make_layer(block, 64, layers[0]) 181 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 184 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 185 | self.fc = nn.Linear(512 * block.expansion, num_classes) 186 | 187 | for m in self.modules(): 188 | if isinstance(m, nn.Conv2d): 189 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 190 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 191 | nn.init.constant_(m.weight, 1) 192 | nn.init.constant_(m.bias, 0) 193 | 194 | if zero_init_residual: 195 | for m in self.modules(): 196 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 197 | nn.init.constant_(m.bn3.weight, 0) 198 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 199 | nn.init.constant_(m.bn2.weight, 0) 200 | 201 | def _make_layer( 202 | self, 203 | block: Type[Union[BasicBlock, Bottleneck]], 204 | planes: int, 205 | blocks: int, 206 | stride: int = 1, 207 | dilate: bool = False, 208 | ) -> nn.Sequential: 209 | norm_layer = self._norm_layer 210 | downsample = None 211 | previous_dilation = 
self.dilation 212 | if dilate: 213 | self.dilation *= stride 214 | stride = 1 215 | if stride != 1 or self.inplanes != planes * block.expansion: 216 | downsample = nn.Sequential( 217 | conv1x1(self.inplanes, planes * block.expansion, stride), 218 | norm_layer(planes * block.expansion), 219 | ) 220 | 221 | layers = [] 222 | layers.append( 223 | block( 224 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 225 | ) 226 | ) 227 | self.inplanes = planes * block.expansion 228 | for _ in range(1, blocks): 229 | layers.append( 230 | block( 231 | self.inplanes, 232 | planes, 233 | groups=self.groups, 234 | base_width=self.base_width, 235 | dilation=self.dilation, 236 | norm_layer=norm_layer, 237 | ) 238 | ) 239 | 240 | return nn.Sequential(*layers) 241 | 242 | def _forward_impl(self, x: Tensor,n_frame) -> Tensor: 243 | x = self.conv1(x) 244 | x = self.bn1(x) 245 | x = self.relu(x) 246 | x = self.maxpool(x) 247 | 248 | x = self.layer1(x) 249 | x = self.layer2(x) 250 | x = self.layer3(x) 251 | x = self.layer4(x) 252 | 253 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1)) 254 | x = self.avgpool(x) 255 | x = torch.squeeze(x) 256 | x = x.permute(1,0) 257 | 258 | return x 259 | 260 | def forward(self, x: Tensor) -> Tensor: 261 | n_frame = x.shape[2] 262 | return self._forward_impl(x,n_frame) 263 | 264 | def _resnet( 265 | block: Type[Union[BasicBlock, Bottleneck]], 266 | layers: List[int], 267 | weights: Optional[WeightsEnum], 268 | progress: bool, 269 | **kwargs: Any, 270 | ) -> ResNet_Att: 271 | if weights is not None: 272 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 273 | 274 | model = ResNet_Att(block, layers, **kwargs) 275 | 276 | if weights is not None: 277 | model.load_state_dict(weights.get_state_dict(progress=progress)) 278 | 279 | return model 280 | 281 | _COMMON_META = { 282 | "min_size": (1, 1), 283 | "categories": _IMAGENET_CATEGORIES, 284 | } 285 | 286 | class 
ResNet18_Weights(WeightsEnum): 287 | IMAGENET1K_V1 = Weights( 288 | url="https://download.pytorch.org/models/resnet18-f37072fd.pth", 289 | transforms=partial(ImageClassification, crop_size=224), 290 | meta={ 291 | **_COMMON_META, 292 | "num_params": 11689512, 293 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", 294 | "_metrics": { 295 | "ImageNet-1K": { 296 | "acc@1": 69.758, 297 | "acc@5": 89.078, 298 | } 299 | }, 300 | "_ops": 1.814, 301 | "_file_size": 44.661, 302 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", 303 | }, 304 | ) 305 | DEFAULT = IMAGENET1K_V1 306 | 307 | @register_model() 308 | @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) 309 | def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet_Att: 310 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 311 | 312 | Args: 313 | weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The 314 | pretrained weights to use. See 315 | :class:`~torchvision.models.ResNet18_Weights` below for 316 | more details, and possible values. By default, no pre-trained 317 | weights are used. 318 | progress (bool, optional): If True, displays a progress bar of the 319 | download to stderr. Default is True. 320 | **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` 321 | base class. Please refer to the `source code 322 | `_ 323 | for more details about this class. 324 | 325 | .. 
autoclass:: torchvision.models.ResNet18_Weights 326 | :members: 327 | """ 328 | weights = ResNet18_Weights.verify(weights) 329 | 330 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) 331 | 332 | # ResNet-based CSTA 333 | class CSTA_ResNet(nn.Module): 334 | def __init__(self, 335 | model_name, 336 | Scale, 337 | Softmax_axis, 338 | Balance, 339 | Positional_encoding, 340 | Positional_encoding_shape, 341 | Positional_encoding_way, 342 | Dropout_on, 343 | Dropout_ratio, 344 | Classifier_on, 345 | CLS_on, 346 | CLS_mix, 347 | key_value_emb, 348 | Skip_connection, 349 | Layernorm,dim=512): 350 | super().__init__() 351 | self.resnet = ResNet_Att() 352 | 353 | self.model_name = model_name 354 | self.Scale = Scale 355 | self.Softmax_axis = Softmax_axis 356 | self.Balance = Balance 357 | 358 | self.Positional_encoding = Positional_encoding 359 | self.Positional_encoding_shape = Positional_encoding_shape 360 | self.Positional_encoding_way = Positional_encoding_way 361 | self.Dropout_on = Dropout_on 362 | self.Dropout_ratio = Dropout_ratio 363 | 364 | self.Classifier_on = Classifier_on 365 | self.CLS_on = CLS_on 366 | self.CLS_mix = CLS_mix 367 | 368 | self.key_value_emb = key_value_emb 369 | self.Skip_connection = Skip_connection 370 | self.Layernorm = Layernorm 371 | 372 | self.dim = dim 373 | 374 | if self.Positional_encoding is not None: 375 | if self.Positional_encoding=='FPE': 376 | self.Positional_encoding_op = FixedPositionalEncoding( 377 | Positional_encoding_shape=self.Positional_encoding_shape, 378 | dim=self.dim 379 | ) 380 | elif self.Positional_encoding=='RPE': 381 | self.Positional_encoding_op = RelativePositionalEncoding( 382 | Positional_encoding_shape=self.Positional_encoding_shape, 383 | dim=self.dim 384 | ) 385 | elif self.Positional_encoding=='LPE': 386 | self.Positional_encoding_op = LearnablePositionalEncoding( 387 | Positional_encoding_shape=self.Positional_encoding_shape, 388 | dim=self.dim 389 | ) 390 | elif 
self.Positional_encoding=='CPE': 391 | self.Positional_encoding_op = ConditionalPositionalEncoding( 392 | Positional_encoding_shape=self.Positional_encoding_shape, 393 | Positional_encoding_way=self.Positional_encoding_way, 394 | dim=self.dim 395 | ) 396 | elif self.Positional_encoding is None: 397 | pass 398 | else: 399 | raise 400 | 401 | if self.Positional_encoding_way=='Transformer': 402 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim) 403 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 404 | pass 405 | else: 406 | raise 407 | 408 | if self.Dropout_on: 409 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio)) 410 | 411 | if self.Classifier_on: 412 | self.linear1 = nn.Sequential( 413 | nn.Linear(in_features=self.dim, out_features=self.dim), 414 | nn.ReLU(), 415 | nn.Dropout(p=0.5), 416 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6) 417 | ) 418 | self.linear2 = nn.Sequential( 419 | nn.Linear(in_features=self.dim, out_features=1), 420 | nn.Sigmoid() 421 | ) 422 | 423 | for name,param in self.named_parameters(): 424 | if name in ['linear1.0.weight','linear2.0.weight']: 425 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0)) 426 | elif name in ['linear1.0.bias','linear2.0.bias']: 427 | nn.init.constant_(param, 0.1) 428 | else: 429 | self.gap = nn.AdaptiveAvgPool1d(1) 430 | 431 | if self.CLS_on: 432 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024)) 433 | 434 | if self.key_value_emb is not None: 435 | if self.key_value_emb.lower()=='k': 436 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 437 | elif self.key_value_emb.lower()=='v': 438 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim) 439 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv': 440 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 441 | if self.model_name=='ResNet_Attention': 442 | self.value_embedding = 
nn.Linear(in_features=1024,out_features=self.dim) 443 | else: 444 | raise 445 | 446 | if self.Layernorm: 447 | if self.Skip_connection=='KC': 448 | self.layernorm1 = nn.BatchNorm2d(num_features=1) 449 | elif self.Skip_connection=='CF': 450 | self.layernorm2 = nn.BatchNorm2d(num_features=1) 451 | elif self.Skip_connection=='IF': 452 | self.layernorm3 = nn.BatchNorm2d(num_features=1) 453 | elif self.Skip_connection is None: 454 | pass 455 | else: 456 | raise 457 | 458 | def forward(self, x): 459 | n_frame = x.shape[2] 460 | 461 | if self.Positional_encoding_way=='Transformer': 462 | x = self.Positional_encoding_embedding(x) 463 | if self.CLS_on: 464 | x = torch.cat((self.CLS,x),dim=2) 465 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim)) 466 | 467 | if self.Positional_encoding_way=='Transformer': 468 | if self.Positional_encoding is not None: 469 | x = self.Positional_encoding_op(x) 470 | if self.Dropout_on: 471 | x = self.dropout(x) 472 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 473 | pass 474 | else: 475 | raise 476 | 477 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']: 478 | key = self.key_embedding(x) 479 | elif self.key_value_emb is None: 480 | key = x 481 | else: 482 | raise 483 | 484 | x_att = self.resnet(key) 485 | 486 | if self.Skip_connection is not None: 487 | if self.Skip_connection=='KC': 488 | x_att = x_att + key.squeeze(0)[0] 489 | if self.Layernorm: 490 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 491 | elif self.Skip_connection in ['CF','IF']: 492 | pass 493 | else: 494 | raise 495 | elif self.Skip_connection is None: 496 | pass 497 | else: 498 | raise 499 | 500 | if self.CLS_on: 501 | if self.CLS_mix=='CNN': 502 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0) 503 | x = CT_adjust(x.squeeze(0)).unsqueeze(0) 504 | elif self.CLS_mix in ['SM','Final']: 505 | pass 506 | else: 507 | raise 508 | else: 509 | pass 510 | 511 | if self.Scale 
is not None: 512 | if self.Scale=='D': 513 | scaling_factor = x_att.shape[1] 514 | elif self.Scale=='T': 515 | scaling_factor = x_att.shape[0] 516 | elif self.Scale=='T_D': 517 | scaling_factor = x_att.shape[0] * x_att.shape[1] 518 | else: 519 | raise 520 | scaling_factor = scaling_factor ** 0.5 521 | x_att = x_att / scaling_factor 522 | elif self.Scale is None: 523 | pass 524 | 525 | if self.Positional_encoding_way=='PGL_SUM': 526 | if self.Positional_encoding is not None: 527 | x_att = self.Positional_encoding_op(x_att) 528 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None: 529 | pass 530 | else: 531 | raise 532 | 533 | x = x.squeeze(0)[0] 534 | if self.Softmax_axis=='T': 535 | temporal_attention = F.softmax(x_att,dim=0) 536 | elif self.Softmax_axis=='D': 537 | spatial_attention = F.softmax(x_att,dim=1) 538 | elif self.Softmax_axis=='TD': 539 | temporal_attention = F.softmax(x_att,dim=0) 540 | spatial_attention = F.softmax(x_att,dim=1) 541 | elif self.Softmax_axis is None: 542 | pass 543 | else: 544 | raise 545 | 546 | if self.CLS_on: 547 | if self.CLS_mix=='SM': 548 | if self.Softmax_axis=='T': 549 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 550 | elif self.Softmax_axis=='D': 551 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 552 | elif self.Softmax_axis=='TD': 553 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 554 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 555 | elif self.Softmax_axis is None: 556 | pass 557 | else: 558 | raise 559 | elif self.CLS_mix in ['CNN','Final']: 560 | pass 561 | else: 562 | raise 563 | else: 564 | pass 565 | 566 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM': 567 | if self.Softmax_axis=='T': 568 | temporal_attention = self.dropout(temporal_attention) 569 | elif self.Softmax_axis=='D': 570 | spatial_attention = self.dropout(spatial_attention) 571 | elif 
self.Softmax_axis=='TD': 572 | temporal_attention = self.dropout(temporal_attention) 573 | spatial_attention = self.dropout(spatial_attention) 574 | elif self.Softmax_axis is None: 575 | pass 576 | else: 577 | raise 578 | 579 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']: 580 | if self.model_name=='ResNet_Attention': 581 | x_out = self.value_embedding(x) 582 | elif self.model_name=='ResNet': 583 | x_out = x_att 584 | else: 585 | raise 586 | elif self.key_value_emb is None: 587 | if self.model_name=='ResNet': 588 | x_out = x_att 589 | elif self.model_name=='ResNet_Attention': 590 | x_out = x 591 | else: 592 | raise 593 | else: 594 | raise 595 | 596 | if self.CLS_on: 597 | if self.CLS_mix=='SM': 598 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0) 599 | 600 | if self.Softmax_axis=='T': 601 | x_out = x_out * temporal_attention 602 | elif self.Softmax_axis=='D': 603 | x_out = x_out * spatial_attention 604 | elif self.Softmax_axis=='TD': 605 | T,D = x_out.shape 606 | adjust_frame = T/D 607 | adjust_dimension = D/T 608 | if self.Balance=='T': 609 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 610 | elif self.Balance=='D': 611 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 612 | elif self.Balance=='BD': 613 | if T>D: 614 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 615 | elif TD: 621 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 622 | elif T int: 31 | return _make_divisible(channels * width_mult, 8, min_value) 32 | 33 | class MBConvConfig(_MBConvConfig): 34 | def __init__( 35 | self, 36 | expand_ratio: float, 37 | kernel: int, 38 | stride: int, 39 | input_channels: int, 40 | out_channels: int, 41 | num_layers: int, 42 | width_mult: float = 1.0, 43 | depth_mult: float = 1.0, 44 | block: Optional[Callable[..., nn.Module]] = None, 45 | ) -> None: 46 | input_channels = 
self.adjust_channels(input_channels, width_mult) 47 | out_channels = self.adjust_channels(out_channels, width_mult) 48 | num_layers = self.adjust_depth(num_layers, depth_mult) 49 | if block is None: 50 | block = MBConv 51 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) 52 | 53 | @staticmethod 54 | def adjust_depth(num_layers: int, depth_mult: float): 55 | return int(math.ceil(num_layers * depth_mult)) 56 | 57 | class MBConv(nn.Module): 58 | def __init__( 59 | self, 60 | cnf: MBConvConfig, 61 | stochastic_depth_prob: float, 62 | norm_layer: Callable[..., nn.Module], 63 | se_layer: Callable[..., nn.Module] = SqueezeExcitation, 64 | ) -> None: 65 | super().__init__() 66 | 67 | if not (1 <= cnf.stride <= 2): 68 | raise ValueError("illegal stride value") 69 | 70 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 71 | 72 | layers: List[nn.Module] = [] 73 | activation_layer = nn.SiLU 74 | 75 | expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) 76 | if expanded_channels != cnf.input_channels: 77 | layers.append( 78 | Conv2dNormActivation( 79 | cnf.input_channels, 80 | expanded_channels, 81 | kernel_size=1, 82 | norm_layer=norm_layer, 83 | activation_layer=activation_layer, 84 | ) 85 | ) 86 | 87 | layers.append( 88 | Conv2dNormActivation( 89 | expanded_channels, 90 | expanded_channels, 91 | kernel_size=cnf.kernel, 92 | stride=cnf.stride, 93 | groups=expanded_channels, 94 | norm_layer=norm_layer, 95 | activation_layer=activation_layer, 96 | ) 97 | ) 98 | 99 | squeeze_channels = max(1, cnf.input_channels // 4) 100 | layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True))) 101 | 102 | layers.append( 103 | Conv2dNormActivation( 104 | expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None 105 | ) 106 | ) 107 | 108 | self.block = nn.Sequential(*layers) 109 | self.stochastic_depth = 
StochasticDepth(stochastic_depth_prob, "row") 110 | self.out_channels = cnf.out_channels 111 | 112 | def forward(self, input: Tensor) -> Tensor: 113 | result = self.block(input) 114 | if self.use_res_connect: 115 | result = self.stochastic_depth(result) 116 | result += input 117 | return result 118 | 119 | class MBConvConfig(_MBConvConfig): 120 | def __init__( 121 | self, 122 | expand_ratio: float, 123 | kernel: int, 124 | stride: int, 125 | input_channels: int, 126 | out_channels: int, 127 | num_layers: int, 128 | width_mult: float = 1.0, 129 | depth_mult: float = 1.0, 130 | block: Optional[Callable[..., nn.Module]] = None, 131 | ) -> None: 132 | input_channels = self.adjust_channels(input_channels, width_mult) 133 | out_channels = self.adjust_channels(out_channels, width_mult) 134 | num_layers = self.adjust_depth(num_layers, depth_mult) 135 | if block is None: 136 | block = MBConv 137 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) 138 | 139 | @staticmethod 140 | def adjust_depth(num_layers: int, depth_mult: float): 141 | return int(math.ceil(num_layers * depth_mult)) 142 | 143 | class FusedMBConvConfig(_MBConvConfig): 144 | def __init__( 145 | self, 146 | expand_ratio: float, 147 | kernel: int, 148 | stride: int, 149 | input_channels: int, 150 | out_channels: int, 151 | num_layers: int, 152 | block: Optional[Callable[..., nn.Module]] = None, 153 | ) -> None: 154 | if block is None: 155 | block = FusedMBConv 156 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) 157 | 158 | class FusedMBConv(nn.Module): 159 | def __init__( 160 | self, 161 | cnf: FusedMBConvConfig, 162 | stochastic_depth_prob: float, 163 | norm_layer: Callable[..., nn.Module], 164 | ) -> None: 165 | super().__init__() 166 | 167 | if not (1 <= cnf.stride <= 2): 168 | raise ValueError("illegal stride value") 169 | 170 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 171 
| 172 | layers: List[nn.Module] = [] 173 | activation_layer = nn.SiLU 174 | 175 | expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) 176 | if expanded_channels != cnf.input_channels: 177 | layers.append( 178 | Conv2dNormActivation( 179 | cnf.input_channels, 180 | expanded_channels, 181 | kernel_size=cnf.kernel, 182 | stride=cnf.stride, 183 | norm_layer=norm_layer, 184 | activation_layer=activation_layer, 185 | ) 186 | ) 187 | 188 | layers.append( 189 | Conv2dNormActivation( 190 | expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None 191 | ) 192 | ) 193 | else: 194 | layers.append( 195 | Conv2dNormActivation( 196 | cnf.input_channels, 197 | cnf.out_channels, 198 | kernel_size=cnf.kernel, 199 | stride=cnf.stride, 200 | norm_layer=norm_layer, 201 | activation_layer=activation_layer, 202 | ) 203 | ) 204 | 205 | self.block = nn.Sequential(*layers) 206 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") 207 | self.out_channels = cnf.out_channels 208 | 209 | def forward(self, input: Tensor) -> Tensor: 210 | result = self.block(input) 211 | if self.use_res_connect: 212 | result = self.stochastic_depth(result) 213 | result += input 214 | return result 215 | 216 | class FusedMBConvConfig(_MBConvConfig): 217 | def __init__( 218 | self, 219 | expand_ratio: float, 220 | kernel: int, 221 | stride: int, 222 | input_channels: int, 223 | out_channels: int, 224 | num_layers: int, 225 | block: Optional[Callable[..., nn.Module]] = None, 226 | ) -> None: 227 | if block is None: 228 | block = FusedMBConv 229 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) 230 | 231 | def _efficientnet_conf( 232 | arch: str, 233 | **kwargs: Any, 234 | ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: 235 | inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] 236 | if arch.startswith("efficientnet_b"): 237 | bneck_conf 
= partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult")) 238 | inverted_residual_setting = [ 239 | bneck_conf(1, 3, 1, 32, 16, 1), 240 | bneck_conf(6, 3, 2, 16, 24, 2), 241 | bneck_conf(6, 5, 2, 24, 40, 2), 242 | bneck_conf(6, 3, 2, 40, 80, 3), 243 | bneck_conf(6, 5, 1, 80, 112, 3), 244 | bneck_conf(6, 5, 2, 112, 192, 4), 245 | bneck_conf(6, 3, 1, 192, 320, 1), 246 | ] 247 | last_channel = None 248 | elif arch.startswith("efficientnet_v2_s"): 249 | inverted_residual_setting = [ 250 | FusedMBConvConfig(1, 3, 1, 24, 24, 2), 251 | FusedMBConvConfig(4, 3, 2, 24, 48, 4), 252 | FusedMBConvConfig(4, 3, 2, 48, 64, 4), 253 | MBConvConfig(4, 3, 2, 64, 128, 6), 254 | MBConvConfig(6, 3, 1, 128, 160, 9), 255 | MBConvConfig(6, 3, 2, 160, 256, 15), 256 | ] 257 | last_channel = 1280 258 | elif arch.startswith("efficientnet_v2_m"): 259 | inverted_residual_setting = [ 260 | FusedMBConvConfig(1, 3, 1, 24, 24, 3), 261 | FusedMBConvConfig(4, 3, 2, 24, 48, 5), 262 | FusedMBConvConfig(4, 3, 2, 48, 80, 5), 263 | MBConvConfig(4, 3, 2, 80, 160, 7), 264 | MBConvConfig(6, 3, 1, 160, 176, 14), 265 | MBConvConfig(6, 3, 2, 176, 304, 18), 266 | MBConvConfig(6, 3, 1, 304, 512, 5), 267 | ] 268 | last_channel = 1280 269 | elif arch.startswith("efficientnet_v2_l"): 270 | inverted_residual_setting = [ 271 | FusedMBConvConfig(1, 3, 1, 32, 32, 4), 272 | FusedMBConvConfig(4, 3, 2, 32, 64, 7), 273 | FusedMBConvConfig(4, 3, 2, 64, 96, 7), 274 | MBConvConfig(4, 3, 2, 96, 192, 10), 275 | MBConvConfig(6, 3, 1, 192, 224, 19), 276 | MBConvConfig(6, 3, 2, 224, 384, 25), 277 | MBConvConfig(6, 3, 1, 384, 640, 7), 278 | ] 279 | last_channel = 1280 280 | else: 281 | raise ValueError(f"Unsupported model type {arch}") 282 | 283 | return inverted_residual_setting, last_channel 284 | 285 | # EfficientNet as attention 286 | class EfficientNet_Att(nn.Module): 287 | def __init__( 288 | self, 289 | inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] = 
_efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)[0], 290 | dropout: float = 0.2, 291 | stochastic_depth_prob: float = 0.2, 292 | num_classes: int = 1000, 293 | norm_layer: Optional[Callable[..., nn.Module]] = None, 294 | last_channel: Optional[int] = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)[1], 295 | ) -> None: 296 | """ 297 | EfficientNet V1 and V2 main class 298 | 299 | Args: 300 | inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure 301 | dropout (float): The droupout probability 302 | stochastic_depth_prob (float): The stochastic depth probability 303 | num_classes (int): Number of classes 304 | norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use 305 | last_channel (int): The number of channels on the penultimate layer 306 | """ 307 | super().__init__() 308 | _log_api_usage_once(self) 309 | 310 | if not inverted_residual_setting: 311 | raise ValueError("The inverted_residual_setting should not be empty") 312 | elif not ( 313 | isinstance(inverted_residual_setting, Sequence) 314 | and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting]) 315 | ): 316 | raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") 317 | 318 | if norm_layer is None: 319 | norm_layer = nn.BatchNorm2d 320 | 321 | layers: List[nn.Module] = [] 322 | 323 | firstconv_output_channels = inverted_residual_setting[0].input_channels 324 | layers.append( 325 | Conv2dNormActivation( 326 | 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU 327 | ) 328 | ) 329 | 330 | total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting) 331 | stage_block_id = 0 332 | for cnf in inverted_residual_setting: 333 | stage: List[nn.Module] = [] 334 | for _ in range(cnf.num_layers): 335 | 336 | block_cnf = copy.copy(cnf) 337 | 338 | if stage: 339 | block_cnf.input_channels = 
block_cnf.out_channels 340 | block_cnf.stride = 1 341 | 342 | sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks 343 | 344 | stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer)) 345 | stage_block_id += 1 346 | 347 | layers.append(nn.Sequential(*stage)) 348 | 349 | lastconv_input_channels = inverted_residual_setting[-1].out_channels 350 | lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels 351 | layers.append( 352 | Conv2dNormActivation( 353 | lastconv_input_channels, 354 | lastconv_output_channels, 355 | kernel_size=1, 356 | norm_layer=norm_layer, 357 | activation_layer=nn.SiLU, 358 | ) 359 | ) 360 | 361 | self.features = nn.Sequential(*layers) 362 | self.avgpool = nn.AdaptiveAvgPool2d(1) 363 | self.classifier = nn.Sequential( 364 | nn.Dropout(p=dropout, inplace=True), 365 | nn.Linear(lastconv_output_channels, num_classes), 366 | ) 367 | 368 | for m in self.modules(): 369 | if isinstance(m, nn.Conv2d): 370 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 371 | if m.bias is not None: 372 | nn.init.zeros_(m.bias) 373 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 374 | nn.init.ones_(m.weight) 375 | nn.init.zeros_(m.bias) 376 | elif isinstance(m, nn.Linear): 377 | init_range = 1.0 / math.sqrt(m.out_features) 378 | nn.init.uniform_(m.weight, -init_range, init_range) 379 | nn.init.zeros_(m.bias) 380 | 381 | def _forward_impl(self, x: Tensor,n_frame) -> Tensor: 382 | x = self.features(x) 383 | 384 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1)) 385 | x = self.avgpool(x) 386 | x = torch.squeeze(x) 387 | x = x.permute(1,0) 388 | 389 | return x 390 | 391 | def forward(self, x: Tensor) -> Tensor: 392 | n_frame = x.shape[2] 393 | return self._forward_impl(x,n_frame) 394 | 395 | # EfficientNet-based CSTA 396 | class CSTA_EfficientNet(nn.Module): 397 | def __init__(self, 398 | model_name, 399 | Scale, 400 | Softmax_axis, 401 | Balance, 402 | Positional_encoding, 403 | 
Positional_encoding_shape, 404 | Positional_encoding_way, 405 | Dropout_on, 406 | Dropout_ratio, 407 | Classifier_on, 408 | CLS_on, 409 | CLS_mix, 410 | key_value_emb, 411 | Skip_connection, 412 | Layernorm, 413 | dim=1280): 414 | super().__init__() 415 | self.efficientnet = EfficientNet_Att() 416 | 417 | self.model_name = model_name 418 | self.Scale = Scale 419 | self.Softmax_axis = Softmax_axis 420 | self.Balance = Balance 421 | 422 | self.Positional_encoding = Positional_encoding 423 | self.Positional_encoding_shape = Positional_encoding_shape 424 | self.Positional_encoding_way = Positional_encoding_way 425 | self.Dropout_on = Dropout_on 426 | self.Dropout_ratio = Dropout_ratio 427 | 428 | self.Classifier_on = Classifier_on 429 | self.CLS_on = CLS_on 430 | self.CLS_mix = CLS_mix 431 | 432 | self.key_value_emb = key_value_emb 433 | self.Skip_connection = Skip_connection 434 | self.Layernorm = Layernorm 435 | 436 | self.dim = dim 437 | 438 | if self.Positional_encoding is not None: 439 | if self.Positional_encoding=='FPE': 440 | self.Positional_encoding_op = FixedPositionalEncoding( 441 | Positional_encoding_shape=self.Positional_encoding_shape, 442 | dim=self.dim 443 | ) 444 | elif self.Positional_encoding=='RPE': 445 | self.Positional_encoding_op = RelativePositionalEncoding( 446 | Positional_encoding_shape=self.Positional_encoding_shape, 447 | dim=self.dim 448 | ) 449 | elif self.Positional_encoding=='LPE': 450 | self.Positional_encoding_op = LearnablePositionalEncoding( 451 | Positional_encoding_shape=self.Positional_encoding_shape, 452 | dim=self.dim 453 | ) 454 | elif self.Positional_encoding=='CPE': 455 | self.Positional_encoding_op = ConditionalPositionalEncoding( 456 | Positional_encoding_shape=self.Positional_encoding_shape, 457 | Positional_encoding_way=self.Positional_encoding_way, 458 | dim=self.dim 459 | ) 460 | elif self.Positional_encoding is None: 461 | pass 462 | else: 463 | raise 464 | 465 | if self.Positional_encoding_way=='Transformer': 466 | 
self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim) 467 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 468 | pass 469 | else: 470 | raise 471 | 472 | if self.Dropout_on: 473 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio)) 474 | 475 | if self.Classifier_on: 476 | self.linear1 = nn.Sequential( 477 | nn.Linear(in_features=self.dim, out_features=self.dim), 478 | nn.ReLU(), 479 | nn.Dropout(p=0.5), 480 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6) 481 | ) 482 | self.linear2 = nn.Sequential( 483 | nn.Linear(in_features=self.dim, out_features=1), 484 | nn.Sigmoid() 485 | ) 486 | 487 | for name,param in self.named_parameters(): 488 | if name in ['linear1.0.weight','linear2.0.weight']: 489 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0)) 490 | elif name in ['linear1.0.bias','linear2.0.bias']: 491 | nn.init.constant_(param, 0.1) 492 | else: 493 | self.gap = nn.AdaptiveAvgPool1d(1) 494 | 495 | if self.CLS_on: 496 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024)) 497 | 498 | if self.key_value_emb is not None: 499 | if self.key_value_emb.lower()=='k': 500 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 501 | elif self.key_value_emb.lower()=='v': 502 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim) 503 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv': 504 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim) 505 | if self.model_name=='EfficientNet_Attention': 506 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim) 507 | else: 508 | raise 509 | 510 | if self.Layernorm: 511 | if self.Skip_connection=='KC': 512 | self.layernorm1 = nn.BatchNorm2d(num_features=1) 513 | elif self.Skip_connection=='CF': 514 | self.layernorm2 = nn.BatchNorm2d(num_features=1) 515 | elif self.Skip_connection=='IF': 516 | self.layernorm3 = nn.BatchNorm2d(num_features=1) 517 | elif self.Skip_connection is 
None: 518 | pass 519 | else: 520 | raise 521 | 522 | def forward(self, x): 523 | n_frame = x.shape[2] 524 | 525 | if self.Positional_encoding_way=='Transformer': 526 | x = self.Positional_encoding_embedding(x) 527 | if self.CLS_on: 528 | x = torch.cat((self.CLS,x),dim=2) 529 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim)) 530 | 531 | if self.Positional_encoding_way=='Transformer': 532 | if self.Positional_encoding is not None: 533 | x = self.Positional_encoding_op(x) 534 | if self.Dropout_on: 535 | x = self.dropout(x) 536 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None: 537 | pass 538 | else: 539 | raise 540 | 541 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']: 542 | key = self.key_embedding(x) 543 | elif self.key_value_emb is None: 544 | key = x 545 | else: 546 | raise 547 | 548 | x_att = self.efficientnet(key) 549 | 550 | if self.Skip_connection is not None: 551 | if self.Skip_connection=='KC': 552 | x_att = x_att + key.squeeze(0)[0] 553 | if self.Layernorm: 554 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 555 | elif self.Skip_connection in ['CF','IF']: 556 | pass 557 | else: 558 | raise 559 | elif self.Skip_connection is None: 560 | pass 561 | else: 562 | raise 563 | 564 | if self.CLS_on: 565 | if self.CLS_mix=='CNN': 566 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0) 567 | x = CT_adjust(x.squeeze(0)).unsqueeze(0) 568 | elif self.CLS_mix in ['SM','Final']: 569 | pass 570 | else: 571 | raise 572 | else: 573 | pass 574 | 575 | if self.Scale is not None: 576 | if self.Scale=='D': 577 | scaling_factor = x_att.shape[1] 578 | elif self.Scale=='T': 579 | scaling_factor = x_att.shape[0] 580 | elif self.Scale=='T_D': 581 | scaling_factor = x_att.shape[0] * x_att.shape[1] 582 | else: 583 | raise 584 | scaling_factor = scaling_factor ** 0.5 585 | x_att = x_att / scaling_factor 586 | elif self.Scale is None: 587 | pass 588 | 589 | if 
self.Positional_encoding_way=='PGL_SUM': 590 | if self.Positional_encoding is not None: 591 | x_att = self.Positional_encoding_op(x_att) 592 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None: 593 | pass 594 | else: 595 | raise 596 | 597 | x = x.squeeze(0)[0] 598 | if self.Softmax_axis=='T': 599 | temporal_attention = F.softmax(x_att,dim=0) 600 | elif self.Softmax_axis=='D': 601 | spatial_attention = F.softmax(x_att,dim=1) 602 | elif self.Softmax_axis=='TD': 603 | temporal_attention = F.softmax(x_att,dim=0) 604 | spatial_attention = F.softmax(x_att,dim=1) 605 | elif self.Softmax_axis is None: 606 | pass 607 | else: 608 | raise 609 | 610 | if self.CLS_on: 611 | if self.CLS_mix=='SM': 612 | if self.Softmax_axis=='T': 613 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 614 | elif self.Softmax_axis=='D': 615 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 616 | elif self.Softmax_axis=='TD': 617 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0) 618 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0) 619 | elif self.Softmax_axis is None: 620 | pass 621 | else: 622 | raise 623 | elif self.CLS_mix in ['CNN','Final']: 624 | pass 625 | else: 626 | raise 627 | else: 628 | pass 629 | 630 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM': 631 | if self.Softmax_axis=='T': 632 | temporal_attention = self.dropout(temporal_attention) 633 | elif self.Softmax_axis=='D': 634 | spatial_attention = self.dropout(spatial_attention) 635 | elif self.Softmax_axis=='TD': 636 | temporal_attention = self.dropout(temporal_attention) 637 | spatial_attention = self.dropout(spatial_attention) 638 | elif self.Softmax_axis is None: 639 | pass 640 | else: 641 | raise 642 | 643 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']: 644 | if self.model_name=='EfficientNet_Attention': 645 | x_out = 
self.value_embedding(x) 646 | elif self.model_name=='EfficientNet': 647 | x_out = x_att 648 | else: 649 | raise 650 | elif self.key_value_emb is None: 651 | if self.model_name=='EfficientNet': 652 | x_out = x_att 653 | elif self.model_name=='EfficientNet_Attention': 654 | x_out = x 655 | else: 656 | raise 657 | else: 658 | raise 659 | 660 | if self.CLS_on: 661 | if self.CLS_mix=='SM': 662 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0) 663 | 664 | if self.Softmax_axis=='T': 665 | x_out = x_out * temporal_attention 666 | elif self.Softmax_axis=='D': 667 | x_out = x_out * spatial_attention 668 | elif self.Softmax_axis=='TD': 669 | T,D = x_out.shape 670 | adjust_frame = T/D 671 | adjust_dimension = D/T 672 | if self.Balance=='T': 673 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 674 | elif self.Balance=='D': 675 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 676 | elif self.Balance=='BD': 677 | if T>D: 678 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension 679 | elif TD: 685 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention 686 | elif T