├── tmp
│   └── txt.txt
├── summary_videos
│   └── txt.txt
├── weights
│   ├── SumMe
│   │   └── txt.txt
│   └── TVSum
│       └── txt.txt
├── requirements.txt
├── kts
│   ├── LICENSE
│   ├── README.md
│   ├── cpd_auto.py
│   ├── demo.py
│   └── cpd_nonlin.py
├── knapsack_implementation.py
├── LICENSE
├── splits
│   ├── SumMe_splits.txt
│   └── TVSum_splits.txt
├── evaluation_metrics.py
├── generate_summary.py
├── dataset.py
├── config.py
├── video_helper.py
├── inference.py
├── generate_video.py
├── model.py
├── models
│   ├── positional_encoding.py
│   ├── GoogleNet.py
│   ├── MobileNet.py
│   ├── ResNet.py
│   └── EfficientNet.py
├── utils.py
├── train.py
└── README.md
/tmp/txt.txt:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/summary_videos/txt.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/weights/SumMe/txt.txt:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/weights/TVSum/txt.txt:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py==3.1.0
2 | numpy==1.19.5
3 | scipy==1.5.2
4 | torch==2.2.1
5 | torchvision==0.17.1
6 | tqdm==4.61.0
--------------------------------------------------------------------------------
/kts/LICENSE:
--------------------------------------------------------------------------------
1 | #### DISCLAIMER
2 | This kts library is downloaded from:
3 | - http://lear.inrialpes.fr/software
4 | - http://pascal.inrialpes.fr/data2/potapov/med_summaries/kts_ver1.1.tar.gz
5 |
6 | I just modified the original code to remove the weave dependency. Please follow the
7 | original LICENSE from LEAR if you are using kts.
8 |
--------------------------------------------------------------------------------
/knapsack_implementation.py:
--------------------------------------------------------------------------------
1 | # Knapsack algorithm to select the shots with high scores under a total-length budget
2 | def knapSack(W, wt, val, n):
3 | K = [[0 for _ in range(W + 1)] for _ in range(n + 1)]
4 |
5 | for i in range(n + 1):
6 | for w in range(W + 1):
7 | if i == 0 or w == 0:
8 | K[i][w] = 0
9 | elif wt[i - 1] <= w:
10 | K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]], K[i - 1][w])
11 | else:
12 | K[i][w] = K[i - 1][w]
13 |
14 | selected = []
15 | w = W
16 | for i in range(n, 0, -1):
17 | if K[i][w] != K[i - 1][w]:
18 | selected.insert(0, i - 1)
19 | w -= wt[i - 1]
20 |
21 | return selected
--------------------------------------------------------------------------------
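A minimal usage sketch (not part of the repository) of how `knapSack` is meant to be called: given per-shot lengths and mean importance scores, it returns the indices of the shots that fit within a frame budget, which is how `generate_summary.py` uses it with a 15% budget. The numbers below are hypothetical.

```python
# Hypothetical example of knapSack: pick shots under a 15% frame budget.
from knapsack_implementation import knapSack

shot_lengths = [40, 25, 60, 30, 45]        # frames per shot
shot_scores = [0.8, 0.3, 0.9, 0.5, 0.7]    # mean importance score per shot
budget = int(sum(shot_lengths) * 0.15)     # 15% of all frames, as in generate_summary.py

selected = knapSack(budget, shot_lengths, shot_scores, len(shot_lengths))
print(selected)  # indices of the chosen shots, e.g. [3] with these toy numbers
```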
/kts/README.md:
--------------------------------------------------------------------------------
1 | This code is from [DS-Net](https://github.com/li-plus/DSNet.git).
2 |
3 | Kernel temporal segmentation
4 | ============================
5 |
6 |
7 | #### DISCLAIMER
8 | This kts library is downloaded from:
9 | - http://lear.inrialpes.fr/software
10 | - http://pascal.inrialpes.fr/data2/potapov/med_summaries/kts_ver1.1.tar.gz
11 |
12 | I just modified the original code to remove the weave dependency. Please follow the
13 | original LICENSE from LEAR if you are using kts.
14 |
15 |
16 | #### Original documentation
17 | This archive contains the following files:
18 | * cpd_nonlin.py - kernel temporal segmentation with fixed number of segments
19 | * cpd_auto.py - kernel temporal segmentation with autocalibration
20 | * demo.py - demo on synthetic examples
21 |
22 | #### Dependencies:
23 | * python + libraries: numpy, scipy, matplotlib (for demo)
24 |
--------------------------------------------------------------------------------
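A minimal sketch of the API described in the documentation above, using synthetic descriptors; `kts/demo.py` below is the full example with plots.

```python
# Sketch of the kts API on synthetic data; see kts/demo.py for the complete demo.
import numpy as np
from kts.cpd_auto import cpd_auto

X = np.random.rand(200, 64)               # 200 frame descriptors of dimension 64
K = np.dot(X, X.T)                        # kernel between every pair of frames
cps, costs = cpd_auto(K, ncp=20, vmax=1)  # auto-select up to 20 change points
print(cps)                                # frame indices of the detected boundaries
```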
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 thswodnjs3
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/splits/SumMe_splits.txt:
--------------------------------------------------------------------------------
1 | SumMe/split0/video_1,video_3,video_4,video_5,video_6,video_7,video_8,video_9,video_10,video_12,video_15,video_16,video_17,video_18,video_19,video_20,video_21,video_22,video_23,video_25/video_2,video_11,video_13,video_14,video_24
2 | SumMe/split1/video_1,video_2,video_3,video_4,video_5,video_6,video_7,video_8,video_11,video_13,video_14,video_15,video_16,video_17,video_18,video_19,video_20,video_22,video_24,video_25/video_9,video_10,video_12,video_21,video_23
3 | SumMe/split2/video_1,video_2,video_3,video_4,video_6,video_9,video_10,video_11,video_12,video_13,video_14,video_15,video_16,video_17,video_18,video_19,video_21,video_22,video_23,video_24/video_5,video_7,video_8,video_20,video_25
4 | SumMe/split3/video_2,video_3,video_5,video_6,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_16,video_19,video_20,video_21,video_22,video_23,video_24,video_25/video_1,video_4,video_15,video_17,video_18
5 | SumMe/split4/video_1,video_2,video_4,video_5,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_15,video_17,video_18,video_20,video_21,video_23,video_24,video_25/video_3,video_6,video_16,video_19,video_22
6 |
--------------------------------------------------------------------------------
/kts/cpd_auto.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from kts.cpd_nonlin import cpd_nonlin
4 |
5 |
6 | def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs):
7 | """Detect change points automatically selecting their number
8 |
9 | :param K: Kernel between each pair of frames in video
10 | :param ncp: Maximum number of change points
11 | :param vmax: Special parameter
12 | :param desc_rate: Rate of descriptor sampling, vmax always corresponds to 1x
13 | :param kwargs: Extra parameters for ``cpd_nonlin``
14 | :return: Tuple (cps, costs)
15 | - cps - best selected change-points
16 | - costs - costs for 0,1,2,...,m change-points
17 | """
18 | m = ncp
19 | _, scores = cpd_nonlin(K, m, backtrack=False, **kwargs)
20 |
21 | N = K.shape[0]
22 | N2 = N * desc_rate # length of the video before down-sampling
23 |
24 | penalties = np.zeros(m + 1)
25 | # Prevent division by zero (in case of 0 changes)
26 | ncp = np.arange(1, m + 1)
27 | penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1)
28 |
29 | costs = scores / float(N) + penalties
30 | m_best = np.argmin(costs)
31 | cps, scores2 = cpd_nonlin(K, m_best, **kwargs)
32 |
33 | return cps, scores2
34 |
--------------------------------------------------------------------------------
/evaluation_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.stats import spearmanr, kendalltau, rankdata
3 |
4 | # Calculate Kendall's and Spearman's coefficients
5 | def get_corr_coeff(pred_imp_scores, videos, dataset, user_scores=None):
6 | rho_coeff, tau_coeff = [], []
7 | if dataset=='SumMe':
8 | for pred_imp_score,video in zip(pred_imp_scores,videos):
9 | true = np.mean(user_scores,axis=0)
10 | rho_coeff.append(spearmanr(pred_imp_score,true)[0])
11 | tau_coeff.append(kendalltau(rankdata(pred_imp_score),rankdata(true))[0])
12 | elif dataset=='TVSum':
13 | for pred_imp_score,video in zip(pred_imp_scores,videos):
14 | pred_imp_score = np.squeeze(pred_imp_score).tolist()
15 | user = int(video.split("_")[-1])
16 |
17 | curr_user_score = user_scores[user-1]
18 |
19 | tmp_rho_coeff, tmp_tau_coeff = [], []
20 | for annotation in range(len(curr_user_score)):
21 | true_user_score = curr_user_score[annotation]
22 | curr_rho_coeff, _ = spearmanr(pred_imp_score, true_user_score)
23 | curr_tau_coeff, _ = kendalltau(rankdata(pred_imp_score), rankdata(true_user_score))
24 | tmp_rho_coeff.append(curr_rho_coeff)
25 | tmp_tau_coeff.append(curr_tau_coeff)
26 | rho_coeff.append(np.mean(tmp_rho_coeff))
27 | tau_coeff.append(np.mean(tmp_tau_coeff))
28 | rho_coeff = np.array(rho_coeff).mean()
29 | tau_coeff = np.array(tau_coeff).mean()
30 | return rho_coeff, tau_coeff
--------------------------------------------------------------------------------
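A dummy-data sketch of how `get_corr_coeff` is consumed (mirroring the calls in `inference.py` and `train.py`): for SumMe a predicted binary summary is compared with the mean of that video's user summaries, while for TVSum the predicted importance scores are compared with each annotator's scores of the matching video. All arrays below are random stand-ins.

```python
# Dummy-data sketch; the real inputs come from the h5 files, as in inference.py / train.py.
import numpy as np
from evaluation_metrics import get_corr_coeff

# SumMe-style call: one predicted binary summary vs. the video's user summaries.
pred_summary = np.random.randint(0, 2, size=300)         # stand-in for a generate_summary output
user_summary = np.random.randint(0, 2, size=(15, 300))   # annotators x frames
rho, tau = get_corr_coeff([pred_summary], ['video_1'], 'SumMe', user_summary)

# TVSum-style call: predicted frame scores vs. the scores of video_1's annotators.
pred_scores = np.random.rand(120)                        # scores of the sampled frames
user_scores = [[np.random.rand(120) for _ in range(20)] for _ in range(50)]
rho, tau = get_corr_coeff([pred_scores], ['video_1'], 'TVSum', user_scores)
print(rho, tau)
```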
/generate_summary.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from knapsack_implementation import knapSack
3 |
4 | # Generate binary frame-level summaries from shot boundaries and frame scores
5 | def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions):
6 | all_summaries = []
7 | for video_index in range(len(all_scores)):
8 | shot_bound = all_shot_bound[video_index]
9 | frame_init_scores = all_scores[video_index]
10 | n_frames = all_nframes[video_index]
11 | positions = all_positions[video_index]
12 |
13 | frame_scores = np.zeros(n_frames, dtype=np.float32)
14 | if positions.dtype != int:
15 | positions = positions.astype(np.int32)
16 | if positions[-1] != n_frames:
17 | positions = np.concatenate([positions, [n_frames]])
18 | for i in range(len(positions) - 1):
19 | pos_left, pos_right = positions[i], positions[i + 1]
20 | if i == len(frame_init_scores):
21 | frame_scores[pos_left:pos_right] = 0
22 | else:
23 | frame_scores[pos_left:pos_right] = frame_init_scores[i]
24 |
25 | shot_imp_scores = []
26 | shot_lengths = []
27 | for shot in shot_bound:
28 | shot_lengths.append(shot[1] - shot[0] + 1)
29 | shot_imp_scores.append((frame_scores[shot[0]:shot[1] + 1].mean()).item())
30 |
31 | final_shot = shot_bound[-1]
32 | final_max_length = int((final_shot[1] + 1) * 0.15)
33 |
34 | selected = knapSack(final_max_length, shot_lengths, shot_imp_scores, len(shot_lengths))
35 |
36 | summary = np.zeros(final_shot[1] + 1, dtype=np.int8)
37 | for shot in selected:
38 | summary[shot_bound[shot][0]:shot_bound[shot][1] + 1] = 1
39 |
40 | all_summaries.append(summary)
41 |
42 | return all_summaries
43 |
--------------------------------------------------------------------------------
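A toy, self-contained sketch of the inputs `generate_summary` expects: KTS shot boundaries, one predicted score per sampled frame, the total frame count, and the pick positions. The shapes follow the calls in `inference.py` and `generate_video.py`; the values are made up.

```python
# Toy-input sketch of generate_summary; shapes mirror the calls in inference.py.
import numpy as np
from generate_summary import generate_summary

shot_bounds = np.array([[0, 9], [10, 24], [25, 59], [60, 119]])  # (begin, end) per KTS shot
n_frames = 120
positions = np.arange(0, n_frames, 15)             # "picks": indices of the sampled frames
scores = np.random.rand(len(positions)).tolist()   # one predicted score per sampled frame

summary = generate_summary([shot_bounds], [scores], [n_frames], [positions])[0]
print(summary.shape, int(summary.sum()))  # per-frame 0/1 selection within the 15% budget
```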
/splits/TVSum_splits.txt:
--------------------------------------------------------------------------------
1 | TVSum/split0/video_1,video_2,video_3,video_4,video_5,video_6,video_8,video_9,video_10,video_11,video_13,video_14,video_15,video_16,video_17,video_18,video_19,video_20,video_22,video_23,video_25,video_26,video_27,video_28,video_29,video_30,video_32,video_33,video_35,video_36,video_37,video_39,video_40,video_41,video_42,video_43,video_45,video_46,video_47,video_50/video_7,video_12,video_21,video_24,video_31,video_34,video_38,video_44,video_48,video_49
2 | TVSum/split1/video_2,video_3,video_5,video_6,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_16,video_18,video_21,video_22,video_23,video_24,video_25,video_26,video_27,video_28,video_29,video_30,video_31,video_32,video_33,video_34,video_35,video_37,video_38,video_39,video_40,video_42,video_44,video_45,video_47,video_48,video_49,video_50/video_1,video_4,video_15,video_17,video_19,video_20,video_36,video_41,video_43,video_46
3 | TVSum/split2/video_1,video_3,video_4,video_5,video_6,video_7,video_10,video_11,video_12,video_13,video_14,video_15,video_16,video_17,video_19,video_20,video_21,video_22,video_23,video_24,video_25,video_26,video_28,video_30,video_31,video_32,video_33,video_34,video_36,video_37,video_38,video_41,video_42,video_43,video_44,video_45,video_46,video_48,video_49,video_50/video_2,video_8,video_9,video_18,video_27,video_29,video_35,video_39,video_40,video_47
4 | TVSum/split3/video_1,video_2,video_3,video_4,video_7,video_8,video_9,video_10,video_11,video_12,video_13,video_14,video_15,video_17,video_18,video_19,video_20,video_21,video_24,video_26,video_27,video_29,video_30,video_31,video_32,video_34,video_35,video_36,video_37,video_38,video_39,video_40,video_41,video_42,video_43,video_44,video_46,video_47,video_48,video_49/video_5,video_6,video_16,video_22,video_23,video_25,video_28,video_33,video_45,video_50
5 | TVSum/split4/video_1,video_2,video_4,video_5,video_6,video_7,video_8,video_9,video_12,video_15,video_16,video_17,video_18,video_19,video_20,video_21,video_22,video_23,video_24,video_25,video_27,video_28,video_29,video_31,video_33,video_34,video_35,video_36,video_38,video_39,video_40,video_41,video_43,video_44,video_45,video_46,video_47,video_48,video_49,video_50/video_3,video_10,video_11,video_13,video_14,video_26,video_30,video_32,video_37,video_42
6 |
--------------------------------------------------------------------------------
/kts/demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 |
4 | from kts.cpd_auto import cpd_auto
5 | from kts.cpd_nonlin import cpd_nonlin
6 |
7 |
8 | def gen_data(n, m, d=1):
9 | """Generates data with change points
10 |
11 | .. warning::
12 | sigma is proportional to m
13 |
14 | :param n: Number of samples
15 | :param m: Number of change-points
16 | :param d:
17 | :return: Tuple (X, cps)
18 | - X - data array (n X d)
19 | - cps - change-points array, including 0 and n
20 | """
21 | np.random.seed(1)
22 | # Select changes at some distance from the boundaries
23 | cps = np.random.permutation(n * 3 // 4 - 1)[0:m] + 1 + n // 8
24 | cps = np.sort(cps)
25 | cps = [0] + list(cps) + [n]
26 | mus = np.random.rand(m + 1, d) * (m / 2) # make sigma = m/2
27 | X = np.zeros((n, d))
28 | for k in range(m + 1):
29 | X[cps[k]:cps[k + 1], :] = mus[k, :][np.newaxis, :] + np.random.rand(
30 | cps[k + 1] - cps[k], d)
31 | return X, np.array(cps)
32 |
33 |
34 | if __name__ == '__main__':
35 | plt.ioff()
36 |
37 | print('Test 1: 1-dimensional signal')
38 | plt.figure('Test 1: 1-dimensional signal')
39 | n = 1000
40 | m = 10
41 | (X, cps_gt) = gen_data(n, m)
42 | print('Ground truth:', cps_gt)
43 | plt.plot(X)
44 | K = np.dot(X, X.T)
45 | cps, scores = cpd_nonlin(K, m, lmin=1, lmax=10000)
46 | print('Estimated:', cps)
47 | mi = np.min(X)
48 | ma = np.max(X)
49 | for cp in cps:
50 | plt.plot([cp, cp], [mi, ma], 'r')
51 | plt.show()
52 | print('=' * 79)
53 |
54 | print('Test 2: multidimensional signal')
55 | plt.figure('Test 2: multidimensional signal')
56 | n = 1000
57 | m = 20
58 | (X, cps_gt) = gen_data(n, m, d=50)
59 | print('Ground truth:', cps_gt)
60 | plt.plot(X)
61 | K = np.dot(X, X.T)
62 | cps, scores = cpd_nonlin(K, m, lmin=1, lmax=10000)
63 | print('Estimated:', cps)
64 | mi = np.min(X)
65 | ma = np.max(X)
66 | for cp in cps:
67 | plt.plot([cp, cp], [mi, ma], 'r')
68 | plt.show()
69 | print('=' * 79)
70 |
71 | print('Test 3: automatic selection of the number of change-points')
72 | plt.figure('Test 3: automatic selection of the number of change-points')
73 | (X, cps_gt) = gen_data(n, m)
74 | print('Ground truth: (m=%d)' % m, cps_gt)
75 | plt.plot(X)
76 | K = np.dot(X, X.T)
77 | cps, scores = cpd_auto(K, 2 * m, 1)
78 | print('Estimated: (m=%d)' % len(cps), cps)
79 | mi = np.min(X)
80 | ma = np.max(X)
81 | for cp in cps:
82 | plt.plot([cp, cp], [mi, ma], 'r')
83 | plt.show()
84 | print('=' * 79)
85 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 |
4 | from torch.utils.data import Dataset,DataLoader
5 |
6 | # Load split
7 | def load_split(dataset):
8 | outputs = []
9 | with open(f'./splits/{dataset}_splits.txt','r') as f:
10 | lines = f.readlines()
11 | for line in lines:
12 | _,_,train_videos,test_videos = line.split('/')
13 | train_videos = train_videos.split(',')
14 | test_videos = test_videos.split(',')
15 | test_videos[-1] = test_videos[-1].replace('\n','')
16 | outputs.append((train_videos,test_videos))
17 | return outputs
18 |
19 | # Create (input, ground truth) pairs
20 | def load_h5(videos,data_path,dataset_name):
21 | features = []
22 | gtscores = []
23 | dataset_names = []
24 |
25 | with h5py.File(data_path,'r') as hdf:
26 | for video in videos:
27 | feature = hdf[video]['features'][()]
28 | gtscore = hdf[video]['gtscore'][()]
29 |
30 | features.append(feature)
31 | gtscores.append(gtscore)
32 | dataset_names.append(dataset_name)
33 | return features,gtscores,dataset_names
34 |
35 | # Create Dataset
36 | class VSdataset(Dataset):
37 | def __init__(self,data,video_nums,transform=None):
38 | features,gtscores,dataset_names = data
39 | self.features = features
40 | self.gtscores = gtscores
41 | self.dataset_names = dataset_names
42 | self.video_nums = video_nums
43 | self.transform = transform
44 | def __len__(self):
45 | return len(self.video_nums)
46 | def __getitem__(self,idx):
47 | output_feature = torch.from_numpy(self.features[idx]).float()
48 | output_feature = output_feature.unsqueeze(0).expand(3,-1,-1)
49 | if self.transform is not None:
50 | output_feature= self.transform(output_feature)
51 | return torch.unsqueeze(output_feature,0),torch.from_numpy(self.gtscores[idx]).float(),self.dataset_names[idx],self.video_nums[idx]
52 |
53 | def collate_fn(sample):
54 | return sample[0]
55 |
56 | # Create Dataloader
57 | def create_dataloader(dataset):
58 | loaders = []
59 |
60 | splits = load_split(dataset=dataset)
61 | data_path = f'./data/eccv16_dataset_{dataset.lower()}_google_pool5.h5'
62 |
63 | for train_videos,test_videos in splits:
64 | train_data = load_h5(videos=train_videos,data_path=data_path,dataset_name=dataset)
65 | test_data = load_h5(videos=test_videos,data_path=data_path,dataset_name=dataset)
66 |
67 | train_dataset = VSdataset(data=train_data,video_nums=train_videos)
68 | test_dataset = VSdataset(data=test_data,video_nums=test_videos)
69 | train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True,collate_fn=collate_fn)
70 | test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False,collate_fn=collate_fn)
71 | loaders.append((train_loader,test_loader))
72 | return loaders
--------------------------------------------------------------------------------
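A hedged sketch of the loader interface, assuming the eccv16 h5 files from the README's Data section are in `./data/`: because of `collate_fn` and `batch_size=1`, each iteration yields a single video's `(feature, gtscore, dataset_name, video_num)` tuple.

```python
# Sketch of the dataloader interface; assumes ./data/ contains the eccv16 h5 files.
from dataset import create_dataloader

loaders = create_dataloader('SumMe')      # one (train_loader, test_loader) pair per split
train_loader, test_loader = loaders[0]

for feature, gtscore, dataset_name, video_num in train_loader:
    # feature: (1, 3, T, 1024) GoogLeNet features, gtscore: (T,) ground-truth importance
    print(feature.shape, gtscore.shape, dataset_name, video_num)
    break
```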
/kts/cpd_nonlin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tqdm import tqdm
4 |
5 | def calc_scatters(K):
6 | """Calculate scatter matrix: scatters[i,j] = {scatter of the sequence with
7 | starting frame i and ending frame j}
8 | """
9 | n = K.shape[0]
10 | K1 = np.cumsum([0] + list(np.diag(K)))
11 | K2 = np.zeros((n + 1, n + 1))
12 | # TODO: use the fact that K - symmetric
13 | K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1)
14 |
15 | diagK2 = np.diag(K2)
16 |
17 | i = np.arange(n).reshape((-1, 1))
18 | j = np.arange(n).reshape((1, -1))
19 | scatters = (
20 | K1[1:].reshape((1, -1)) - K1[:-1].reshape((-1, 1)) -
21 | (diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape((-1, 1)) -
22 | K2[1:, :-1].T - K2[:-1, 1:]) /
23 | ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32))
24 | )
25 | scatters[j < i] = 0
26 |
27 | return scatters
28 |
29 |
30 | def cpd_nonlin(K, ncp, lmin=1, lmax=100000, backtrack=True, verbose=True,
31 | out_scatters=None):
32 | """Change point detection with dynamic programming
33 |
34 | :param K: Square kernel matrix
35 | :param ncp: Number of change points to detect (ncp >= 0)
36 | :param lmin: Minimal length of a segment
37 | :param lmax: Maximal length of a segment
38 | :param backtrack: If False - only evaluate objective scores (to save memory)
39 | :param verbose: If true, print verbose message
40 | :param out_scatters: Output scatters
41 | :return: Tuple (cps, obj_vals)
42 | - cps - detected array of change points: mean is thought to be constant
43 | on [ cps[i], cps[i+1] )
44 | - obj_vals - values of the objective function for 0..m changepoints
45 | """
46 | m = int(ncp) # prevent numpy.int64
47 |
48 | n, n1 = K.shape
49 |     assert n == n1, 'Square kernel matrix expected.'
50 | assert (m + 1) * lmin <= n <= (m + 1) * lmax
51 | assert 1 <= lmin <= lmax
52 |
53 | if verbose:
54 | print('Precomputing scatters...')
55 | J = calc_scatters(K)
56 |
57 | if out_scatters is not None:
58 | out_scatters[0] = J
59 |
60 | if verbose:
61 | print('Inferring best change points...')
62 | # I[k, l] - value of the objective for k change-points and l first frames
63 | I = 1e101 * np.ones((m + 1, n + 1))
64 | I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1]
65 |
66 | if backtrack:
67 | # p[k, l] --- 'previous change' --- best t[k] when t[k+1] equals l
68 | p = np.zeros((m + 1, n + 1), dtype=int)
69 | else:
70 | p = np.zeros((1, 1), dtype=int)
71 |
72 | for k in tqdm(range(1, m + 1), ncols = 90, total = m, desc = 'KTS outer', leave = False):
73 | for l in tqdm(range((k + 1) * lmin, n + 1), ncols = 90, total = n + 1 - ((k + 1) * lmin), desc = 'KTS inner', leave = False):
74 | tmin = max(k * lmin, l - lmax)
75 | tmax = l - lmin + 1
76 | c = J[tmin:tmax, l - 1].reshape(-1) + \
77 | I[k - 1, tmin:tmax].reshape(-1)
78 | I[k, l] = np.min(c)
79 | if backtrack:
80 | p[k, l] = np.argmin(c) + tmin
81 |
82 | # Collect change points
83 | cps = np.zeros(m, dtype=int)
84 |
85 | if backtrack:
86 | cur = n
87 | for k in tqdm(range(m, 0, -1), ncols = 90, total = m, desc = 'KTS final', leave = False):
88 | cps[k - 1] = p[k, cur]
89 | cur = cps[k - 1]
90 |
91 | scores = I[:, n].copy()
92 | scores[scores > 1e99] = np.inf
93 | return cps, scores
94 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import random
4 | import torch
5 |
6 | # Process bool argument
7 | def str2bool(v):
8 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
9 | return True
10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
11 | return False
12 | else:
13 |         raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}')
14 |
15 | # Process none argument
16 | def str2none(v):
17 | if v.lower()=='none':
18 | return None
19 | else:
20 | return v
21 |
22 | # Define configuration class
23 | class Config(object):
24 | def __init__(self, **kwargs):
25 | self.kwargs = kwargs
26 | for k, v in kwargs.items():
27 | setattr(self, k, v)
28 |
29 | self.datasets = ['SumMe','TVSum']
30 | self.SumMe_len = 25
31 | self.TVSum_len = 50
32 |
33 | # Set device
34 | if self.device!='cpu':
35 | torch.cuda.set_device(self.device)
36 |
37 | # Set seed
38 | self.set_seed()
39 |
40 | # Set the seed
41 | def set_seed(self):
42 | random.seed(self.seed)
43 | np.random.seed(self.seed)
44 | torch.manual_seed(self.seed)
45 | if self.device!='cpu':
46 | torch.cuda.manual_seed(self.seed)
47 | torch.cuda.manual_seed_all(self.seed)
48 | torch.backends.cudnn.benchmark = False
49 | torch.backends.cudnn.deterministic = True
50 |
51 | # Define all configurations
52 | def get_config(parse=True, **optional_kwargs):
53 | parser = argparse.ArgumentParser()
54 |
55 | parser.add_argument('--seed', type=int, default=123456)
56 | parser.add_argument('--device', type=str, default='cuda:0')
57 | parser.add_argument('--epochs', type=int, default=100)
58 | parser.add_argument('--batch_size', default='1')
59 | parser.add_argument('--learning_rate', default='1e-3')
60 | parser.add_argument('--weight_decay', default='1e-7')
61 |
62 | parser.add_argument('--model_name', type=str, default='GoogleNet_Attention')
63 | parser.add_argument('--Scale', type=str2none, default=None)
64 | parser.add_argument('--Softmax_axis', type=str2none, default='TD')
65 | parser.add_argument('--Balance', type=str2none, default=None)
66 |
67 | parser.add_argument('--Positional_encoding', type=str2none, default='FPE')
68 | parser.add_argument('--Positional_encoding_shape', type=str2none, default='TD')
69 | parser.add_argument('--Positional_encoding_way', type=str2none, default='PGL_SUM')
70 | parser.add_argument('--Dropout_on', type=str2bool, default=True)
71 | parser.add_argument('--Dropout_ratio', default='0.6')
72 |
73 | parser.add_argument('--Classifier_on', type=str2bool, default=True)
74 | parser.add_argument('--CLS_on', type=str2bool, default=True)
75 | parser.add_argument('--CLS_mix', type=str2none, default='Final')
76 |
77 | parser.add_argument('--key_value_emb', type=str2none, default='kv')
78 | parser.add_argument('--Skip_connection', type=str2none, default='KC')
79 | parser.add_argument('--Layernorm', type=str2bool, default=True)
80 |
81 | # Generate summary videos
82 | parser.add_argument('--input_is_file', type=str2bool, default='true')
83 | parser.add_argument('--file_path', type=str, default='./SumMe/Jumps.mp4')
84 | parser.add_argument('--dir_path', type=str, default='./SumMe')
85 | parser.add_argument('--ext', type=str, default='mp4')
86 | parser.add_argument('--sample_rate', type=int, default=15)
87 | parser.add_argument('--save_path', type=str, default='./summary_videos')
88 | parser.add_argument('--weight_path', type=str, default='./weights/SumMe/split4.pt')
89 |
90 | kwargs = vars(parser.parse_args())
91 | kwargs.update(optional_kwargs)
92 |
93 | return Config(**kwargs)
94 |
--------------------------------------------------------------------------------
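A short sketch of how the configuration object is consumed (as in `train.py` and `inference.py`): CLI flags populate the parser defaults above, and any `optional_kwargs` passed to `get_config` override them before the `Config` attributes are set.

```python
# Sketch of config usage, mirroring train.py / inference.py.
from config import get_config

config = get_config(device='cpu')         # optional_kwargs override the parsed CLI arguments
print(config.model_name, config.epochs)   # defaults: 'GoogleNet_Attention', 100
print(config.datasets)                    # ['SumMe', 'TVSum'], set inside Config.__init__
```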
/video_helper.py:
--------------------------------------------------------------------------------
1 | # Reference code: https://github.com/li-plus/DSNet/blob/1804176e2e8b57846beb063667448982273fca89/src/helpers/video_helper.py
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from os import PathLike
8 | from pathlib import Path
9 | from PIL import Image
10 | from torchvision import transforms, models
11 | from torchvision.models import GoogLeNet_Weights
12 | from tqdm import tqdm
13 |
14 | from kts.cpd_auto import cpd_auto
15 |
16 | class FeatureExtractor(object):
17 | def __init__(self, device):
18 | self.device = device
19 | self.transforms = GoogLeNet_Weights.IMAGENET1K_V1.transforms()
20 | weights = GoogLeNet_Weights.IMAGENET1K_V1
21 | self.model = models.googlenet(weights=weights)
22 | self.model = nn.Sequential(*list(self.model.children())[:-2])
23 | self.model.to(self.device)
24 | self.model.eval()
25 |
26 | def run(self, img: np.ndarray):
27 | img = Image.fromarray(img)
28 | img = self.transforms(img)
29 | batch = img.unsqueeze(0)
30 | with torch.no_grad():
31 | batch = batch.to(self.device)
32 | feat = self.model(batch)
33 | feat = feat.squeeze()
34 |
35 | assert feat.shape == (1024,), f'Invalid feature shape {feat.shape}: expected 1024'
36 | # normalize frame features
37 | feat = feat / (torch.norm(feat) + 1e-10)
38 | return feat
39 |
40 | class VideoPreprocessor(object):
41 | def __init__(self, sample_rate: int, device: str):
42 | self.model = FeatureExtractor(device)
43 | self.sample_rate = sample_rate
44 |
45 | def get_features(self, video_path: PathLike):
46 | video_path = Path(video_path)
47 | cap = cv2.VideoCapture(str(video_path))
48 |         assert cap.isOpened(), f'Cannot open video: {video_path}'
49 |
50 | self.fps = cap.get(cv2.CAP_PROP_FPS)
51 | self.frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
52 | self.frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
53 |
54 | features = []
55 | n_frames = 0
56 |
57 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
58 | with tqdm(total = total_frames, ncols=90, desc = "getting features", unit='frame', leave=False) as pbar:
59 | while True:
60 | ret, frame = cap.read()
61 |
62 | if not ret:
63 | break
64 |
65 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66 | feat = self.model.run(frame)
67 | features.append(feat)
68 |
69 | n_frames += 1
70 | pbar.update(1)
71 |
72 | cap.release()
73 | features = torch.stack(features)
74 | return n_frames, features
75 |
76 | def kts(self, n_frames, features):
77 | seq_len = len(features)
78 | picks = np.arange(0, seq_len)
79 | # compute change points using KTS
80 | kernel = np.matmul(features.clone().detach().cpu().numpy(), features.clone().detach().cpu().numpy().T)
81 | change_points, _ = cpd_auto(kernel, seq_len - 1, 1, verbose=False)
82 | change_points = np.hstack((0, change_points, n_frames))
83 | begin_frames = change_points[:-1]
84 | end_frames = change_points[1:]
85 | change_points = np.vstack((begin_frames, end_frames - 1)).T
86 | return change_points, picks
87 |
88 | def run(self, video_path: PathLike):
89 | n_frames, features = self.get_features(video_path)
90 | cps, picks = self.kts(n_frames, features)
91 | return n_frames, features[::self.sample_rate,:], cps, picks[::self.sample_rate]
92 |
--------------------------------------------------------------------------------
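A hedged usage sketch of the preprocessor (the video path is the default `file_path` from `config.py`): `run()` decodes the video, extracts one 1024-d GoogLeNet feature per frame, runs KTS on the full-frame kernel, and returns the frame count plus the features and picks sub-sampled by `sample_rate`, matching how `generate_video.py` consumes it.

```python
# Sketch of VideoPreprocessor usage; the path is the default file_path in config.py.
from video_helper import VideoPreprocessor

video_proc = VideoPreprocessor(sample_rate=15, device='cpu')
n_frames, features, cps, picks = video_proc.run('./SumMe/Jumps.mp4')
# n_frames: total decoded frames; features: (~n_frames/15, 1024) sub-sampled features;
# cps: (num_shots, 2) KTS shot boundaries; picks: indices of the sampled frames
print(n_frames, features.shape, cps.shape, picks.shape)
```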
/inference.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import torch
4 |
5 | from config import get_config
6 | from dataset import create_dataloader
7 | from evaluation_metrics import get_corr_coeff
8 | from generate_summary import generate_summary
9 | from model import set_model
10 | from utils import report_params,get_gt
11 |
12 | # Load configurations
13 | config = get_config()
14 |
15 | # Print the number of parameters
16 | report_params(
17 | model_name=config.model_name,
18 | Scale=config.Scale,
19 | Softmax_axis=config.Softmax_axis,
20 | Balance=config.Balance,
21 | Positional_encoding=config.Positional_encoding,
22 | Positional_encoding_shape=config.Positional_encoding_shape,
23 | Positional_encoding_way=config.Positional_encoding_way,
24 | Dropout_on=config.Dropout_on,
25 | Dropout_ratio=config.Dropout_ratio,
26 | Classifier_on=config.Classifier_on,
27 | CLS_on=config.CLS_on,
28 | CLS_mix=config.CLS_mix,
29 | key_value_emb=config.key_value_emb,
30 | Skip_connection=config.Skip_connection,
31 | Layernorm=config.Layernorm
32 | )
33 |
34 | # Start testing
35 | for dataset in config.datasets:
36 | user_scores = get_gt(dataset)
37 | split_kendalls = []
38 | split_spears = []
39 |
40 | for split_id,(train_loader,test_loader) in enumerate(create_dataloader(dataset)):
41 | model = set_model(
42 | model_name=config.model_name,
43 | Scale=config.Scale,
44 | Softmax_axis=config.Softmax_axis,
45 | Balance=config.Balance,
46 | Positional_encoding=config.Positional_encoding,
47 | Positional_encoding_shape=config.Positional_encoding_shape,
48 | Positional_encoding_way=config.Positional_encoding_way,
49 | Dropout_on=config.Dropout_on,
50 | Dropout_ratio=config.Dropout_ratio,
51 | Classifier_on=config.Classifier_on,
52 | CLS_on=config.CLS_on,
53 | CLS_mix=config.CLS_mix,
54 | key_value_emb=config.key_value_emb,
55 | Skip_connection=config.Skip_connection,
56 | Layernorm=config.Layernorm
57 | )
58 | model.load_state_dict(torch.load(f'./weights/{dataset}/split{split_id+1}.pt', map_location='cpu'))
59 | model.to(config.device)
60 | model.eval()
61 |
62 | kendalls = []
63 | spears = []
64 | with torch.no_grad():
65 | for feature,_,dataset_name,video_num in test_loader:
66 | feature = feature.to(config.device)
67 | output = model(feature)
68 |
69 | with h5py.File(f'./data/eccv16_dataset_{dataset_name.lower()}_google_pool5.h5','r') as hdf:
70 | user_summary = np.array(hdf[video_num]['user_summary'])
71 | sb = np.array(hdf[f"{video_num}/change_points"])
72 | n_frames = np.array(hdf[f"{video_num}/n_frames"])
73 | positions = np.array(hdf[f"{video_num}/picks"])
74 | scores = output.squeeze().clone().detach().cpu().numpy().tolist()
75 | summary = generate_summary([sb], [scores], [n_frames], [positions])[0]
76 |
77 | if dataset_name=='SumMe':
78 | spear,kendall = get_corr_coeff([summary],[video_num],dataset_name,user_summary)
79 | elif dataset_name=='TVSum':
80 | spear,kendall = get_corr_coeff([scores],[video_num],dataset_name,user_scores)
81 |
82 | spears.append(spear)
83 | kendalls.append(kendall)
84 | split_kendalls.append(np.mean(kendalls))
85 | split_spears.append(np.mean(spears))
86 | print("[Split{}]Kendall:{:.3f}, Spear:{:.3f}".format(
87 | split_id,split_kendalls[split_id],split_spears[split_id]
88 | ))
89 | print("[FINAL - {}]Kendall:{:.3f}, Spear:{:.3f}".format(
90 | dataset,np.mean(split_kendalls),np.mean(split_spears)
91 | ))
92 | print()
93 |
--------------------------------------------------------------------------------
/generate_video.py:
--------------------------------------------------------------------------------
1 | # Reference code: https://github.com/li-plus/DSNet/blob/1804176e2e8b57846beb063667448982273fca89/src/make_dataset.py#L4
2 | # Reference code: https://github.com/e-apostolidis/PGL-SUM/blob/81d0d6d0ee0470775ad759087deebbce1ceffec3/model/configs.py#L10
3 | import cv2
4 | import torch
5 |
6 | from pathlib import Path
7 | from tqdm import tqdm
8 |
9 | from config import get_config
10 | from generate_summary import generate_summary
11 | from model import set_model
12 | from video_helper import VideoPreprocessor
13 |
14 | def pick_frames(video_path, selections):
15 | cap = cv2.VideoCapture(str(video_path))
16 | frames = []
17 | n_frames = 0
18 |
19 | with tqdm(total = len(selections), ncols=90, desc = "selecting frames", unit='frame', leave = False) as pbar:
20 | while True:
21 | ret, frame = cap.read()
22 |
23 | if not ret:
24 | break
25 |
26 | if selections[n_frames]:
27 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
28 | frames.append(frame)
29 | n_frames += 1
30 |
31 | pbar.update(1)
32 |
33 | cap.release()
34 |
35 | return frames
36 |
37 | def produce_video(save_path, frames, fps, frame_size):
38 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
39 | out = cv2.VideoWriter(save_path, fourcc, fps, frame_size)
40 | for frame in tqdm(frames, total = len(frames), ncols=90, desc = "generating videos", leave = False):
41 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
42 | out.write(frame)
43 | out.release()
44 |
45 | def main():
46 | # Load config
47 | config = get_config()
48 |
49 | # create output directory
50 | out_dir = Path(config.save_path)
51 | out_dir.mkdir(parents=True, exist_ok=True)
52 |
53 | # feature extractor
54 | video_proc = VideoPreprocessor(
55 | sample_rate=config.sample_rate,
56 | device=config.device
57 | )
58 |
59 | # search all videos with .mp4 suffix
60 | if config.input_is_file:
61 | video_paths = [Path(config.file_path)]
62 | else:
63 | video_paths = sorted(Path(config.dir_path).glob(f'*.{config.ext}'))
64 |
65 | # Load CSTA weights
66 | model = set_model(
67 | model_name=config.model_name,
68 | Scale=config.Scale,
69 | Softmax_axis=config.Softmax_axis,
70 | Balance=config.Balance,
71 | Positional_encoding=config.Positional_encoding,
72 | Positional_encoding_shape=config.Positional_encoding_shape,
73 | Positional_encoding_way=config.Positional_encoding_way,
74 | Dropout_on=config.Dropout_on,
75 | Dropout_ratio=config.Dropout_ratio,
76 | Classifier_on=config.Classifier_on,
77 | CLS_on=config.CLS_on,
78 | CLS_mix=config.CLS_mix,
79 | key_value_emb=config.key_value_emb,
80 | Skip_connection=config.Skip_connection,
81 | Layernorm=config.Layernorm
82 | )
83 | model.load_state_dict(torch.load(config.weight_path, map_location='cpu'))
84 | model.to(config.device)
85 | model.eval()
86 |
87 | # Generate summarized videos
88 | with torch.no_grad():
89 | for video_path in tqdm(video_paths,total=len(video_paths),ncols=80,leave=False,desc="Making videos..."):
90 | video_name = video_path.stem
91 | n_frames, features, cps, pick = video_proc.run(video_path)
92 |
93 | inputs = features.to(config.device)
94 | inputs = inputs.unsqueeze(0).expand(3,-1,-1).unsqueeze(0)
95 | outputs = model(inputs)
96 | predictions = outputs.squeeze().clone().detach().cpu().numpy().tolist()
97 | # print(cps.shape, len(predictions), n_frames, pick.shape)
98 | selections = generate_summary([cps], [predictions], [n_frames], [pick])[0]
99 |
100 | frames = pick_frames(video_path=video_path, selections=selections)
101 | produce_video(
102 | save_path=f'{config.save_path}/{video_name}.mp4',
103 | frames=frames,
104 | fps=video_proc.fps,
105 | frame_size=(video_proc.frame_width,video_proc.frame_height)
106 | )
107 |
108 | if __name__=='__main__':
109 | main()
110 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import EfficientNet_B0_Weights,GoogLeNet_Weights,MobileNet_V2_Weights,ResNet18_Weights
2 |
3 | from models.EfficientNet import CSTA_EfficientNet
4 | from models.GoogleNet import CSTA_GoogleNet
5 | from models.MobileNet import CSTA_MobileNet
6 | from models.ResNet import CSTA_ResNet
7 |
8 | # Load models depending on CNN
9 | def set_model(model_name,
10 | Scale,
11 | Softmax_axis,
12 | Balance,
13 | Positional_encoding,
14 | Positional_encoding_shape,
15 | Positional_encoding_way,
16 | Dropout_on,
17 | Dropout_ratio,
18 | Classifier_on,
19 | CLS_on,
20 | CLS_mix,
21 | key_value_emb,
22 | Skip_connection,
23 | Layernorm):
24 | if model_name in ['EfficientNet','EfficientNet_Attention']:
25 | model = CSTA_EfficientNet(
26 | model_name=model_name,
27 | Scale=Scale,
28 | Softmax_axis=Softmax_axis,
29 | Balance=Balance,
30 | Positional_encoding=Positional_encoding,
31 | Positional_encoding_shape=Positional_encoding_shape,
32 | Positional_encoding_way=Positional_encoding_way,
33 | Dropout_on=Dropout_on,
34 | Dropout_ratio=Dropout_ratio,
35 | Classifier_on=Classifier_on,
36 | CLS_on=CLS_on,
37 | CLS_mix=CLS_mix,
38 | key_value_emb=key_value_emb,
39 | Skip_connection=Skip_connection,
40 | Layernorm=Layernorm
41 | )
42 | state_dict = EfficientNet_B0_Weights.IMAGENET1K_V1.get_state_dict(progress=False)
43 | model.efficientnet.load_state_dict(state_dict)
44 | elif model_name in ['GoogleNet','GoogleNet_Attention']:
45 | model = CSTA_GoogleNet(
46 | model_name=model_name,
47 | Scale=Scale,
48 | Softmax_axis=Softmax_axis,
49 | Balance=Balance,
50 | Positional_encoding=Positional_encoding,
51 | Positional_encoding_shape=Positional_encoding_shape,
52 | Positional_encoding_way=Positional_encoding_way,
53 | Dropout_on=Dropout_on,
54 | Dropout_ratio=Dropout_ratio,
55 | Classifier_on=Classifier_on,
56 | CLS_on=CLS_on,
57 | CLS_mix=CLS_mix,
58 | key_value_emb=key_value_emb,
59 | Skip_connection=Skip_connection,
60 | Layernorm=Layernorm
61 | )
62 | state_dict = GoogLeNet_Weights.IMAGENET1K_V1.get_state_dict(progress=False)
63 | state_dict = {k: v for k, v in state_dict.items() if not k.startswith('aux')}
64 | new_state_dict = model.googlenet.state_dict()
65 | for name,param in state_dict.items():
66 | new_state_dict[name] = param
67 | model.googlenet.load_state_dict(new_state_dict)
68 | elif model_name in ['MobileNet','MobileNet_Attention']:
69 | model = CSTA_MobileNet(
70 | model_name=model_name,
71 | Scale=Scale,
72 | Softmax_axis=Softmax_axis,
73 | Balance=Balance,
74 | Positional_encoding=Positional_encoding,
75 | Positional_encoding_shape=Positional_encoding_shape,
76 | Positional_encoding_way=Positional_encoding_way,
77 | Dropout_on=Dropout_on,
78 | Dropout_ratio=Dropout_ratio,
79 | Classifier_on=Classifier_on,
80 | CLS_on=CLS_on,
81 | CLS_mix=CLS_mix,
82 | key_value_emb=key_value_emb,
83 | Skip_connection=Skip_connection,
84 | Layernorm=Layernorm
85 | )
86 | state_dict = MobileNet_V2_Weights.IMAGENET1K_V1.get_state_dict(progress=False)
87 | model.mobilenet.load_state_dict(state_dict)
88 | elif model_name in ['ResNet','ResNet_Attention']:
89 | model = CSTA_ResNet(
90 | model_name=model_name,
91 | Scale=Scale,
92 | Softmax_axis=Softmax_axis,
93 | Balance=Balance,
94 | Positional_encoding=Positional_encoding,
95 | Positional_encoding_shape=Positional_encoding_shape,
96 | Positional_encoding_way=Positional_encoding_way,
97 | Dropout_on=Dropout_on,
98 | Dropout_ratio=Dropout_ratio,
99 | Classifier_on=Classifier_on,
100 | CLS_on=CLS_on,
101 | CLS_mix=CLS_mix,
102 | key_value_emb=key_value_emb,
103 | Skip_connection=Skip_connection,
104 | Layernorm=Layernorm
105 | )
106 | state_dict = ResNet18_Weights.IMAGENET1K_V1.get_state_dict(progress=False)
107 | model.resnet.load_state_dict(state_dict)
108 | else:
109 | raise
110 | return model
--------------------------------------------------------------------------------
/models/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | class FixedPositionalEncoding(nn.Module):
6 | def __init__(self, Positional_encoding_shape, dim=1024, max_len=5000, freq=10000.0):
7 | super(FixedPositionalEncoding, self).__init__()
8 | if Positional_encoding_shape=='TD':
9 | position = torch.arange(max_len).unsqueeze(1)
10 | div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(freq) / dim))
11 |
12 | pe = torch.zeros(max_len, dim)
13 | pe[:, 0::2] = torch.sin(position * div_term)
14 | pe[:, 1::2] = torch.cos(position * div_term)
15 | elif Positional_encoding_shape=='T':
16 | position = torch.arange(max_len).unsqueeze(1)
17 | div_term = torch.exp(torch.arange(0, 1, 2) * (-math.log(freq) / 1))
18 |
19 | pe = torch.zeros(max_len,1)
20 | pe[:, 0::2] = torch.sin(position * div_term)
21 | pe[:, 1::2] = torch.cos(position * div_term)
22 | pe = pe.repeat_interleave(dim,dim=1)
23 | elif Positional_encoding_shape is None:
24 | pass
25 | else:
26 | raise
27 |
28 | self.register_buffer('pe', pe)
29 |
30 | def forward(self, x):
31 | return x + self.pe[:x.shape[0]]
32 |
33 | class RelativePositionalEncoding(nn.Module):
34 | def __init__(self, Positional_encoding_shape, dim=1024, max_len=5000, freq=10000.0):
35 | super(RelativePositionalEncoding, self).__init__()
36 | self.Positional_encoding_shape = Positional_encoding_shape
37 | self.dim = dim
38 | self.max_len = max_len
39 | self.freq = freq
40 |
41 | def forward(self, x):
42 | T = x.shape[0]
43 | min_rpos = -(T - 1)
44 | i = torch.tensor([k for k in range(T)])
45 | i = i.reshape(i.shape[0], 1)
46 | if self.Positional_encoding_shape=='TD':
47 | d = T + self.dim
48 | j = torch.tensor([k for k in range(self.dim)])
49 |
50 | i = i.repeat_interleave(j.shape[0], dim=1)
51 | j = j.repeat(i.shape[0], 1)
52 |
53 | r_pos = j - i - min_rpos
54 |
55 | pe = torch.zeros(T, self.dim)
56 | idx = torch.tensor([k for k in range(T//2)],dtype=torch.int64)
57 |
58 | pe[2*idx, :] = torch.sin(r_pos[2*idx, :] / self.freq ** ((i[2*idx, :] + j[2*idx, :]) / d))
59 | pe[2*idx+1, :] = torch.cos(r_pos[2*idx+1, :] / self.freq ** ((i[2*idx+1, :] + j[2*idx+1, :]) / d))
60 | elif self.Positional_encoding_shape=='T':
61 | d = T + 1
62 | j = torch.tensor([k for k in range(1)])
63 |
64 | i = i.repeat_interleave(j.shape[0], dim=1)
65 | j = j.repeat(i.shape[0], 1)
66 |
67 | r_pos = j - i - min_rpos
68 |
69 | pe = torch.zeros(T, 1)
70 | idx = torch.tensor([k for k in range(T//2)],dtype=torch.int64)
71 |
72 | pe[2*idx, :] = torch.sin(r_pos[2*idx, :] / self.freq ** ((i[2*idx, :] + j[2*idx, :]) / d))
73 | pe[2*idx+1, :] = torch.cos(r_pos[2*idx+1, :] / self.freq ** ((i[2*idx+1, :] + j[2*idx+1, :]) / d))
74 | pe = pe.repeat_interleave(self.dim,dim=1)
75 | elif self.Positional_encoding_shape is None:
76 | pass
77 | else:
78 | raise
79 | return x + pe[:x.shape[0]].to(x.device)
80 |
81 | class LearnablePositionalEncoding(nn.Module):
82 | def __init__(self, Positional_encoding_shape, dim=1024, max_len=5000):
83 | super(LearnablePositionalEncoding, self).__init__()
84 | if Positional_encoding_shape=='TD':
85 | self.pe = nn.Parameter(torch.randn((max_len,dim)))
86 | elif Positional_encoding_shape=='T':
87 | self.pe = nn.Parameter(torch.randn((max_len,1)))
88 | elif Positional_encoding_shape is None:
89 | pass
90 | else:
91 | raise
92 |
93 | def forward(self, x):
94 | return x + self.pe[:x.shape[0]]
95 |
96 | class ConditionalPositionalEncoding(nn.Module):
97 | def __init__(self, Positional_encoding_shape, Positional_encoding_way, dim=1024, kernel_size=3, stride=1, padding=1):
98 | super(ConditionalPositionalEncoding, self).__init__()
99 | self.Positional_encoding_way = Positional_encoding_way
100 | if Positional_encoding_shape=='TD':
101 | self.pe = nn.Conv1d(in_channels=dim,out_channels=dim,kernel_size=kernel_size,stride=stride,padding=padding)
102 | elif Positional_encoding_shape=='T':
103 | self.pe = nn.Conv1d(in_channels=dim,out_channels=dim,kernel_size=kernel_size,stride=stride,padding=padding,groups=dim)
104 | else:
105 | raise
106 |
107 | def forward(self, x):
108 | if self.Positional_encoding_way=='Transformer':
109 | return x + self.pe(x[0].permute(0,2,1)).permute(0,2,1).unsqueeze(0)
110 | elif self.Positional_encoding_way=='PGL_SUM':
111 | return x + self.pe(x.unsqueeze(0).permute(0,2,1)).permute(0,2,1).squeeze(0)
112 | elif self.Positional_encoding_way is None:
113 | pass
114 | else:
115 | raise
116 |
--------------------------------------------------------------------------------
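A dummy-data sketch of how these positional encodings transform a `(T, 1024)` feature sequence; the `'PGL_SUM'` branch of the conditional encoding takes the sequence without a batch dimension, as in the `forward` above.

```python
# Dummy-data sketch of the positional encoding modules defined above.
import torch
from models.positional_encoding import FixedPositionalEncoding, ConditionalPositionalEncoding

x = torch.randn(200, 1024)                          # T=200 frames, 1024-d features

fpe = FixedPositionalEncoding(Positional_encoding_shape='TD')
print(fpe(x).shape)                                 # torch.Size([200, 1024])

cpe = ConditionalPositionalEncoding(Positional_encoding_shape='TD',
                                    Positional_encoding_way='PGL_SUM')
print(cpe(x).shape)                                 # torch.Size([200, 1024])
```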
/utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | import torch
4 |
5 | from collections import Counter
6 |
7 | from models.EfficientNet import CSTA_EfficientNet
8 | from models.GoogleNet import CSTA_GoogleNet
9 | from models.MobileNet import CSTA_MobileNet
10 | from models.ResNet import CSTA_ResNet
11 |
12 | # Count the number of parameters
13 | def count_parameters(model,model_name):
14 | if model_name in ['GoogleNet','GoogleNet_Attention','ResNet','ResNet_Attention']:
15 | x = [param.numel() for name,param in model.named_parameters() if param.requires_grad and 'fc' not in name]
16 | elif model_name in ['EfficientNet','EfficientNet_Attention','MobileNet','MobileNet_Attention']:
17 | x = [param.numel() for name,param in model.named_parameters() if param.requires_grad and 'classifier' not in name]
18 | return sum(x) / (1024 * 1024)
19 |
20 | # Function printing the number of parameters of models
21 | def report_params(model_name,
22 | Scale,
23 | Softmax_axis,
24 | Balance,
25 | Positional_encoding,
26 | Positional_encoding_shape,
27 | Positional_encoding_way,
28 | Dropout_on,
29 | Dropout_ratio,
30 | Classifier_on,
31 | CLS_on,
32 | CLS_mix,
33 | key_value_emb,
34 | Skip_connection,
35 | Layernorm):
36 | if model_name in ['EfficientNet','EfficientNet_Attention']:
37 | model = CSTA_EfficientNet(
38 | model_name=model_name,
39 | Scale=Scale,
40 | Softmax_axis=Softmax_axis,
41 | Balance=Balance,
42 | Positional_encoding=Positional_encoding,
43 | Positional_encoding_shape=Positional_encoding_shape,
44 | Positional_encoding_way=Positional_encoding_way,
45 | Dropout_on=Dropout_on,
46 | Dropout_ratio=Dropout_ratio,
47 | Classifier_on=Classifier_on,
48 | CLS_on=CLS_on,
49 | CLS_mix=CLS_mix,
50 | key_value_emb=key_value_emb,
51 | Skip_connection=Skip_connection,
52 | Layernorm=Layernorm
53 | )
54 | elif model_name in ['GoogleNet','GoogleNet_Attention']:
55 | model = CSTA_GoogleNet(
56 | model_name=model_name,
57 | Scale=Scale,
58 | Softmax_axis=Softmax_axis,
59 | Balance=Balance,
60 | Positional_encoding=Positional_encoding,
61 | Positional_encoding_shape=Positional_encoding_shape,
62 | Positional_encoding_way=Positional_encoding_way,
63 | Dropout_on=Dropout_on,
64 | Dropout_ratio=Dropout_ratio,
65 | Classifier_on=Classifier_on,
66 | CLS_on=CLS_on,
67 | CLS_mix=CLS_mix,
68 | key_value_emb=key_value_emb,
69 | Skip_connection=Skip_connection,
70 | Layernorm=Layernorm
71 | )
72 | elif model_name in ['MobileNet','MobileNet_Attention']:
73 | model = CSTA_MobileNet(
74 | model_name=model_name,
75 | Scale=Scale,
76 | Softmax_axis=Softmax_axis,
77 | Balance=Balance,
78 | Positional_encoding=Positional_encoding,
79 | Positional_encoding_shape=Positional_encoding_shape,
80 | Positional_encoding_way=Positional_encoding_way,
81 | Dropout_on=Dropout_on,
82 | Dropout_ratio=Dropout_ratio,
83 | Classifier_on=Classifier_on,
84 | CLS_on=CLS_on,
85 | CLS_mix=CLS_mix,
86 | key_value_emb=key_value_emb,
87 | Skip_connection=Skip_connection,
88 | Layernorm=Layernorm
89 | )
90 | elif model_name in ['ResNet','ResNet_Attention']:
91 | model = CSTA_ResNet(
92 | model_name=model_name,
93 | Scale=Scale,
94 | Softmax_axis=Softmax_axis,
95 | Balance=Balance,
96 | Positional_encoding=Positional_encoding,
97 | Positional_encoding_shape=Positional_encoding_shape,
98 | Positional_encoding_way=Positional_encoding_way,
99 | Dropout_on=Dropout_on,
100 | Dropout_ratio=Dropout_ratio,
101 | Classifier_on=Classifier_on,
102 | CLS_on=CLS_on,
103 | CLS_mix=CLS_mix,
104 | key_value_emb=key_value_emb,
105 | Skip_connection=Skip_connection,
106 | Layernorm=Layernorm
107 | )
108 | print(f"PARAMS: {count_parameters(model,model_name):.2f}M")
109 |
110 | # Print all arguments and GPU setting
111 | def print_args(args):
112 | print(args.kwargs)
113 | print(f"CUDA: {torch.version.cuda}")
114 | print(f"cuDNN: {torch.backends.cudnn.version()}")
115 | if 'cuda' in args.device:
116 | print(f"GPU: {torch.cuda.is_available()}")
117 | print(f"GPU count: {torch.cuda.device_count()}")
118 | print(f"GPU name: {torch.cuda.get_device_name(0)}")
119 |
120 | # Load ground truth for TVSum
121 | def get_gt(dataset):
122 | if dataset=='TVSum':
123 | annot_path = f"./data/ydata-anno.tsv"
124 | with open(annot_path) as annot_file:
125 | annot = list(csv.reader(annot_file, delimiter="\t"))
126 | annotation_length = list(Counter(np.array(annot)[:, 0]).values())
127 | user_scores = []
128 | for idx in range(1,51):
129 | init = (idx - 1) * annotation_length[idx-1]
130 | till = idx * annotation_length[idx-1]
131 | user_score = []
132 | for row in annot[init:till]:
133 | curr_user_score = row[2].split(",")
134 | curr_user_score = np.array([float(num) for num in curr_user_score])
135 | curr_user_score = curr_user_score / curr_user_score.max(initial=-1)
136 | curr_user_score = curr_user_score[::15]
137 |
138 | user_score.append(curr_user_score)
139 | user_scores.append(user_score)
140 | return user_scores
141 | elif dataset=='SumMe':
142 | return None
143 | else:
144 | raise
145 |
--------------------------------------------------------------------------------
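A hedged sketch of `get_gt`'s output, assuming `./data/ydata-anno.tsv` is present: for TVSum it returns a list of 50 videos, each holding that video's annotator scores down-sampled every 15 frames (the structure `get_corr_coeff` indexes per video), while for SumMe it returns `None` because the user summaries come straight from the h5 file.

```python
# Sketch of the ground-truth loader; assumes ./data/ydata-anno.tsv exists.
from utils import get_gt

user_scores = get_gt('TVSum')
print(len(user_scores))          # 50 videos
print(len(user_scores[0]))       # number of annotations for the first video
print(user_scores[0][0].shape)   # one annotator's scores, down-sampled every 15 frames

print(get_gt('SumMe'))           # None: SumMe evaluation uses user_summary from the h5 file
```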
/train.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import shutil
4 | import torch
5 |
6 | from tqdm import tqdm
7 |
8 | from config import get_config
9 | from dataset import create_dataloader
10 | from evaluation_metrics import get_corr_coeff
11 | from generate_summary import generate_summary
12 | from model import set_model
13 | from utils import report_params, print_args, get_gt
14 |
15 | # Load configurations
16 | config = get_config()
17 |
18 | # Print information of setting
19 | print_args(config)
20 |
21 | # Print the number of parameters
22 | report_params(
23 | model_name=config.model_name,
24 | Scale=config.Scale,
25 | Softmax_axis=config.Softmax_axis,
26 | Balance=config.Balance,
27 | Positional_encoding=config.Positional_encoding,
28 | Positional_encoding_shape=config.Positional_encoding_shape,
29 | Positional_encoding_way=config.Positional_encoding_way,
30 | Dropout_on=config.Dropout_on,
31 | Dropout_ratio=config.Dropout_ratio,
32 | Classifier_on=config.Classifier_on,
33 | CLS_on=config.CLS_on,
34 | CLS_mix=config.CLS_mix,
35 | key_value_emb=config.key_value_emb,
36 | Skip_connection=config.Skip_connection,
37 | Layernorm=config.Layernorm
38 | )
39 |
40 | # Start training
41 | for dataset in tqdm(config.datasets,total=len(config.datasets),ncols=70,leave=True,position=0):
42 | user_scores = get_gt(dataset)
43 |
44 | if dataset=='SumMe':
45 | batch_size = 1 if config.batch_size=='1' else int(config.SumMe_len*0.8*float(config.batch_size))
46 | elif dataset=='TVSum':
47 | batch_size = 1 if config.batch_size=='1' else int(config.TVSum_len*0.8*float(config.batch_size))
48 |
49 | for split_id,(train_loader,test_loader) in tqdm(enumerate(create_dataloader(dataset)),total=5,ncols=70,leave=False,position=1,desc=dataset):
50 | model = set_model(
51 | model_name=config.model_name,
52 | Scale=config.Scale,
53 | Softmax_axis=config.Softmax_axis,
54 | Balance=config.Balance,
55 | Positional_encoding=config.Positional_encoding,
56 | Positional_encoding_shape=config.Positional_encoding_shape,
57 | Positional_encoding_way=config.Positional_encoding_way,
58 | Dropout_on=config.Dropout_on,
59 | Dropout_ratio=config.Dropout_ratio,
60 | Classifier_on=config.Classifier_on,
61 | CLS_on=config.CLS_on,
62 | CLS_mix=config.CLS_mix,
63 | key_value_emb=config.key_value_emb,
64 | Skip_connection=config.Skip_connection,
65 | Layernorm=config.Layernorm
66 | )
67 | model.to(config.device)
68 | criterion = torch.nn.MSELoss()
69 | optimizer = torch.optim.Adam(model.parameters(),lr=float(config.learning_rate),weight_decay=float(config.weight_decay))
70 |
71 | model_selection_kendall = -1
72 | model_selection_spear = -1
73 |
74 | for epoch in tqdm(range(config.epochs),total=config.epochs,ncols=70,leave=False,position=2,desc=f'Split{split_id+1}'):
75 | model.train()
76 | update_loss = 0.0
77 | batch = 0
78 |
79 | for feature,gtscore,dataset_name,video_num in tqdm(train_loader,ncols=70,leave=False,position=3,desc=f'Epoch{epoch+1}_TRAIN'):
80 | feature = feature.to(config.device)
81 | gtscore = gtscore.to(config.device)
82 | output = model(feature)
83 |
84 | loss = criterion(output,gtscore)
85 | loss.requires_grad_(True)
86 |
87 | update_loss += loss
88 | batch += 1
89 |
90 | if batch==batch_size:
91 | optimizer.zero_grad()
92 | update_loss = update_loss / batch
93 | update_loss.backward()
94 | optimizer.step()
95 | update_loss = 0.0
96 | batch = 0
97 |
98 | if batch>0:
99 | optimizer.zero_grad()
100 | update_loss = update_loss / batch
101 | update_loss.backward()
102 | optimizer.step()
103 | update_loss = 0.0
104 | batch = 0
105 |
106 | val_spears = []
107 | val_kendalls = []
108 | model.eval()
109 | with torch.no_grad():
110 | for feature,gtscore,dataset_name,video_num in tqdm(test_loader,ncols=70,leave=False,position=3,desc=f'Epoch{epoch+1}_TEST'):
111 | feature = feature.to(config.device)
112 | gtscore = gtscore.to(config.device)
113 | output = model(feature)
114 |
115 | if dataset_name in ['SumMe','TVSum']:
116 | with h5py.File(f'./data/eccv16_dataset_{dataset_name.lower()}_google_pool5.h5','r') as hdf:
117 | user_summary = np.array(hdf[video_num]['user_summary'])
118 | sb = np.array(hdf[f"{video_num}/change_points"])
119 | n_frames = np.array(hdf[f"{video_num}/n_frames"])
120 | positions = np.array(hdf[f"{video_num}/picks"])
121 | scores = output.squeeze().clone().detach().cpu().numpy().tolist()
122 | summary = generate_summary([sb], [scores], [n_frames], [positions])[0]
123 | if dataset_name=='SumMe':
124 | spear,kendall = get_corr_coeff([summary],[video_num],dataset_name,user_summary)
125 | elif dataset_name=='TVSum':
126 | spear,kendall = get_corr_coeff([scores],[video_num],dataset_name,user_scores)
127 |
128 | val_spears.append(spear)
129 | val_kendalls.append(kendall)
130 |
131 | if np.mean(val_kendalls) > model_selection_kendall and np.mean(val_spears) > model_selection_spear:
132 | model_selection_kendall = np.mean(val_kendalls)
133 | model_selection_spear = np.mean(val_spears)
134 | torch.save(model.state_dict(), './tmp/weight.pt')
135 | shutil.move('./tmp/weight.pt', f'./weights/{dataset}/split{split_id+1}.pt')
136 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CSTA: CNN-based Spatiotemporal Attention for Video Summarization (CVPR 2024 paper)
2 | [](https://paperswithcode.com/sota/supervised-video-summarization-on-summe?p=csta-cnn-based-spatiotemporal-attention-for)
3 | [](https://paperswithcode.com/sota/supervised-video-summarization-on-tvsum?p=csta-cnn-based-spatiotemporal-attention-for)
4 |
5 | The official code of "CSTA: CNN-based Spatiotemporal Attention for Video Summarization" [[paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Son_CSTA_CNN-based_Spatiotemporal_Attention_for_Video_Summarization_CVPR_2024_paper.pdf)] [[arXiv](https://arxiv.org/pdf/2405.11905)]
6 | 
7 |
8 | * [Model overview](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#model-overview)
9 | * [Updates](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#updates)
10 | * [Requirements](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#requirements)
11 | * [Data](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#data)
12 | * [Pre-trained models](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#pre-trained-models)
13 | * [Training](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#training)
14 | * [Inference](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#inference)
15 | * [Generate summary videos](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#generate-summary-videos)
16 | * [Citation](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#citation)
17 | * [Acknowledgement](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#acknowledgement)
18 |
19 | # Model overview
20 | 
21 |
22 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
23 |
24 | # Updates
25 | * [2024.03.24] Create a repository.
26 | * [2024.05.21] Update the code and pre-trained models.
27 | * [2024.07.18] Upload the code to generate summary videos, including custom videos.
28 | * [2024.07.21] Update the KTS code for full frames of videos.
29 | * [2024.07.23] Update the code to use only the CPU.
30 | * [2024.12.30] Add tqdm to show progress while generating summary videos.
31 | * (Planned) [2025.01.??] Add detailed explanations and comments for the code.
32 |
33 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
34 |
35 | # Requirements
36 | |Ubuntu|GPU|CUDA|cuDNN|conda|python|
37 | |:---:|:---:|:---:|:---:|:---:|:---:|
38 | |20.04.6 LTS|NVIDIA GeForce RTX 4090|12.1|8902|4.9.2|3.8.5|
39 |
40 | |h5py|numpy|scipy|torch|torchvision|tqdm|
41 | |:---:|:---:|:---:|:---:|:---:|:---:|
42 | |3.1.0|1.19.5|1.5.2|2.2.1|0.17.1|4.61.0|
43 |
44 | ```
45 | conda create -n CSTA python=3.8.5
46 | conda activate CSTA
47 | git clone https://github.com/thswodnjs3/CSTA.git
48 | cd CSTA
49 | pip install -r requirements.txt
50 | ```
51 |
52 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
53 |
54 | # Data
55 | ~~Link: [Dataset](https://drive.google.com/drive/folders/1iGfKZxexQfOxyIaOWhfU0P687dJq_KWF?usp=drive_link)~~
56 | - I'm very sorry, but the dataset is no longer available. You can download it from [here (PGL-SUM)](https://github.com/e-apostolidis/PGL-SUM).
57 | These are preprocessed versions of the two benchmark video summarization datasets (SumMe, TVSum) in HDF5 (h5py) format.
58 | You should download the datasets and put them in the ```data/``` directory.
59 | The directory structure must look like the one below.
60 | ```
61 | ├── data
62 | └── eccv16_dataset_summe_google_pool5.h5
63 | └── eccv16_dataset_tvsum_google_pool5.h5
64 | ```
65 | You can see the details of both datasets below.
66 |
67 | [SumMe](https://link.springer.com/chapter/10.1007/978-3-319-10584-0_33)
68 | [TVSum](https://openaccess.thecvf.com/content_cvpr_2015/papers/Song_TVSum_Summarizing_Web_2015_CVPR_paper.pdf)
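If you want to check a downloaded file quickly, the short sketch below (not part of this repository) prints the per-video keys that the training and evaluation code reads, such as 'gtscore', 'user_summary', 'change_points', 'n_frames', and 'picks'; the 'features' key is assumed to hold the frame features.
```
import h5py

# Minimal sanity check of the expected HDF5 layout (illustrative only).
with h5py.File('./data/eccv16_dataset_summe_google_pool5.h5', 'r') as hdf:
    video_name = list(hdf.keys())[0]
    print(video_name)
    for key in ['features', 'gtscore', 'user_summary', 'change_points', 'n_frames', 'picks']:
        if key in hdf[video_name]:
            print(f"  {key}: shape={hdf[video_name][key].shape}")
```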
69 |
70 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
71 |
72 | # Pre-trained models
73 | ~~Link: [Weights](https://drive.google.com/drive/folders/1Z0WV_IJAHXV16sAGW7TmC9J_iFZQ9NSs?usp=drive_link)~~
74 | - I'm very sorry, but the pre-trained weights are no longer available.
75 | I accidentally deleted them, but they can be reproduced by training with the same seed number given below (123456).
76 | If you cannot obtain similar performance with the same seed, please contact me.
77 |
78 | You can download our pre-trained weights of CSTA.
79 | There are 5 weights for the SumMe dataset and another 5 for the TVSum dataset (1 weight for each split).
80 | As shown in the paper, we ran every experiment 10 times (without fixing the seed) but uploaded only a single representative model per split for your convenience.
81 | The uploaded weights were obtained with seed 123456, and their results are almost identical to those in our paper.
82 | You should put the 5 SumMe weights in ```weights/SumMe``` and the 5 TVSum weights in ```weights/TVSum```.
83 | The directory structure must look like the one below.
84 | ```
85 | ├── weights
86 | └── SumMe
87 | ├── split1.pt
88 | ├── split2.pt
89 | ├── split3.pt
90 | ├── split4.pt
91 | ├── split5.pt
92 | └── TVSum
93 | ├── split1.pt
94 | ├── split2.pt
95 | ├── split3.pt
96 | ├── split4.pt
97 | ├── split5.pt
98 | ```
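As a quick check that a weight file is in place and readable, the sketch below just loads it as a state_dict on the CPU; building the matching model (see model.py and train.py for the exact configuration) is omitted here.
```
import torch

# Load one split's weights and list a few parameter names (illustrative only).
state_dict = torch.load('./weights/SumMe/split1.pt', map_location='cpu')
print(len(state_dict), 'tensors, e.g.', list(state_dict)[:3])
```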
99 |
100 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
101 |
102 | # Training
103 | You can train the final version of our models with the command below.
104 | ```
105 | python train.py
106 | ```
107 | Detailed explanations for all configurations will be updated later.
108 |
109 | ## You can't reproduce our results perfectly.
110 | As shown in the paper, we ran every experiment 10 times without fixing the seed, so we cannot tell which seeds reproduce the reported results.
111 | Even if you set the seed to 123456, the same seed used for our pre-trained models, you may still get different results due to the non-deterministic behavior of the [Adaptive Average Pooling layer](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms).
112 | To the best of my knowledge, non-deterministic operations can produce different results even with the same seed. [You can see the details here.](https://pytorch.org/docs/stable/notes/randomness.html)
113 | However, you can get results similar to the pre-trained models when you set the seed to 123456, so I hope this will be helpful.
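For reference, a fixed seed is usually set along the lines of the sketch below; the helper name set_seed is only illustrative, and the exact seeding logic in train.py may differ.
```
import random
import numpy as np
import torch

def set_seed(seed=123456):
    # Seed Python, NumPy, and PyTorch (CPU and all CUDA devices).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Non-deterministic ops (e.g. CUDA adaptive average pooling) can still
    # produce slightly different results across runs even with a fixed seed.

set_seed(123456)
```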
114 |
115 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
116 |
117 | # Inference
118 | You can see the final performance of the models with the command below.
119 | ```
120 | python inference.py
121 | ```
122 | All weight files should be located in the directories described above.
123 |
124 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
125 |
126 | # Generate summary videos
127 | You can generate summary videos using our models.
128 | You can use either videos from public datasets or custom videos.
129 | With the code below, you can apply our pre-trained models to raw videos to produce summary videos.
130 | ```
131 | python generate_video.py --input_is_file True or False
132 | --file_path 'path to input video'
133 | --dir_path 'directory of input videos'
134 | --ext 'video file extension'
135 | --save_path 'path to save summary video'
136 | --weight_path 'path to loaded weights'
137 |
138 | e.g.
139 | 1)Using a directory
140 | python generate_video.py --input_is_file False --dir_path './videos' --ext 'mp4' --save_path './summary_videos' --weight_path './weights/SumMe/split4.pt'
141 |
142 | 2)Using a single video file
143 | python generate_video.py --input_is_file True --file_path './videos/Jumps.mp4' --save_path './summary_videos' --weight_path './weights/SumMe/split4.pt'
144 | ```
145 | The arguments are explained below.
146 | If you change the 'ext' argument when inputting a directory of videos, you must also modify the ['fourcc'](https://github.com/thswodnjs3/CSTA/blob/7227ee36a460b0bdc4aa83cb446223779365df45/generate_video.py#L34) variable in the 'produce_video' function within the 'generate_video.py' file.
147 | Likewise, you must update it when inputting a single video file with an extension other than 'mp4' (see the sketch after the argument list below).
148 | ```
149 | 1. input_is_file (bool): True or False
150 | Indicates whether the input is a file or a directory.
151 | If this is True, the 'file_path' argument is required.
152 | If this is False, the 'dir_path' and 'ext' arguments are required.
153 |
154 | 2. file_path (str) e.g. './SumMe/Jumps.mp4'
155 | The path of the video file.
156 | This is only used when 'input_is_file' is True.
157 |
158 | 3. dir_path (str) e.g. './SumMe'
159 | The path of the directory where video files are located.
160 | This is only used when 'input_is_file' is False.
161 |
162 | 4. ext (str) e.g. 'mp4'
163 | The file extension of the video files.
164 | This is only used when 'input_is_file' is False.
165 |
166 | 5. sample_rate (int) e.g. 15
167 | The interval between selected frames in a video.
168 | For example, if the video has 30 fps, it will become 2 fps with a sample_rate of 15.
169 |
170 | 6. save_path (str) e.g. './summary_videos'
171 | The path where the summary videos are saved.
172 |
173 | 7. weight_path (str) e.g. './weights/SumMe/split4.pt'
174 | The path where the model weights are loaded from.
175 | ```
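For the 'fourcc' change mentioned above, the sketch below shows how an OpenCV writer's codec is typically chosen per container; the extension-to-codec mapping is an illustrative assumption, not the exact code in 'generate_video.py'.
```
import cv2

# Illustrative extension-to-FourCC mapping; adjust it to match your container.
FOURCC_BY_EXT = {'mp4': 'mp4v', 'avi': 'XVID', 'mov': 'mp4v'}

def make_writer(save_path, ext, fps, frame_size):
    fourcc = cv2.VideoWriter_fourcc(*FOURCC_BY_EXT.get(ext, 'mp4v'))
    return cv2.VideoWriter(save_path, fourcc, fps, frame_size)
```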
176 | We referenced the KTS code from [DSNet](https://github.com/li-plus/DSNet).
177 | However, they applied KTS to downsampled videos (2 fps), which can result in different shot change points and sometimes make it impossible to summarize videos.
178 | We revised it to calculate change points over all frames of the video.
179 |
180 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
181 |
182 | # Citation
183 | If you find our code or our paper useful, please click [★star] for this repo and [cite] the following paper:
184 | ```
185 | @inproceedings{son2024csta,
186 | title={CSTA: CNN-based Spatiotemporal Attention for Video Summarization},
187 | author={Son, Jaewon and Park, Jaehun and Kim, Kwangsu},
188 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
189 | pages={18847--18856},
190 | year={2024}
191 | }
192 | ```
193 |
194 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
195 |
196 | # Acknowledgement
197 | We especially and sincerely appreciate the authors of PosENet and RR-STG, who responded to our requests very kindly.
198 | Below are the papers we referenced for the code.
199 |
200 | A2Summ - [paper](https://arxiv.org/pdf/2303.07284), [code](https://github.com/boheumd/A2Summ)
201 | CA-SUM - [paper](https://www.iti.gr/~bmezaris/publications/icmr2022_preprint.pdf), [code](https://github.com/e-apostolidis/CA-SUM)
202 | DSNet - [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9275314), [code](https://github.com/li-plus/DSNet)
203 | iPTNet - [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Jiang_Joint_Video_Summarization_and_Moment_Localization_by_Cross-Task_Sample_Transfer_CVPR_2022_paper.pdf)
204 | MSVA - [paper](https://arxiv.org/pdf/2104.11530), [code](https://github.com/TIBHannover/MSVA)
205 | PGL-SUM - [paper](https://www.iti.gr/~bmezaris/publications/ism2021a_preprint.pdf), [code](https://github.com/e-apostolidis/PGL-SUM)
206 | PosENet - [paper](https://arxiv.org/pdf/2001.08248), [code](https://github.com/islamamirul/position_information)
207 | RR-STG - [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9750933&tag=1)
208 | SSPVS - [paper](https://arxiv.org/pdf/2201.02494), [code](https://github.com/HopLee6/SSPVS-PyTorch)
209 | STVT - [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10124837), [code](https://github.com/nchucvml/STVT)
210 | VASNet - [paper](https://arxiv.org/pdf/1812.01969), [code](https://github.com/ok1zjf/VASNet)
211 | VJMHT - [paper](https://arxiv.org/pdf/2112.13478), [code](https://github.com/HopLee6/VJMHT-PyTorch)
212 |
213 | ```
214 | @inproceedings{he2023a2summ,
215 | title = {Align and Attend: Multimodal Summarization with Dual Contrastive Losses},
216 | author={He, Bo and Wang, Jun and Qiu, Jielin and Bui, Trung and Shrivastava, Abhinav and Wang, Zhaowen},
217 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
218 | year = {2023}
219 | }
220 | ```
221 | ```
222 | @inproceedings{10.1145/3512527.3531404,
223 | author = {Apostolidis, Evlampios and Balaouras, Georgios and Mezaris, Vasileios and Patras, Ioannis},
224 | title = {Summarizing Videos Using Concentrated Attention and Considering the Uniqueness and Diversity of the Video Frames},
225 | year = {2022},
226 | isbn = {9781450392389},
227 | publisher = {Association for Computing Machinery},
228 | address = {New York, NY, USA},
229 | url = {https://doi.org/10.1145/3512527.3531404},
230 | doi = {10.1145/3512527.3531404},
231 | pages = {407-415},
232 | numpages = {9},
233 | keywords = {frame diversity, frame uniqueness, concentrated attention, unsupervised learning, video summarization},
234 | location = {Newark, NJ, USA},
235 | series = {ICMR '22}
236 | }
237 | ```
238 | ```
239 | @article{zhu2020dsnet,
240 | title={DSNet: A Flexible Detect-to-Summarize Network for Video Summarization},
241 | author={Zhu, Wencheng and Lu, Jiwen and Li, Jiahao and Zhou, Jie},
242 | journal={IEEE Transactions on Image Processing},
243 | volume={30},
244 | pages={948--962},
245 | year={2020}
246 | }
247 | ```
248 | ```
249 | @inproceedings{jiang2022joint,
250 | title={Joint video summarization and moment localization by cross-task sample transfer},
251 | author={Jiang, Hao and Mu, Yadong},
252 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
253 | pages={16388--16398},
254 | year={2022}
255 | }
256 | ```
257 | ```
258 | @inproceedings{ghauri2021MSVA,
259 | title={Supervised Video Summarization via Multiple Feature Sets with Parallel Attention},
260 | author={Ghauri, Junaid Ahmed and Hakimov, Sherzod and Ewerth, Ralph},
261 | booktitle={IEEE International Conference on Multimedia and Expo (ICME)},
262 | year={2021}
263 | }
264 | ```
265 | ```
266 | @INPROCEEDINGS{9666088,
267 | author = {Apostolidis, Evlampios and Balaouras, Georgios and Mezaris, Vasileios and Patras, Ioannis},
268 | title = {Combining Global and Local Attention with Positional Encoding for Video Summarization},
269 | booktitle = {2021 IEEE International Symposium on Multimedia (ISM)},
270 | month = {December},
271 | year = {2021},
272 | pages = {226-234}
273 | }
274 | ```
275 | ```
276 | @InProceedings{islam2020position,
277 | title={How much Position Information Do Convolutional Neural Networks Encode?},
278 | author={Islam, Md Amirul and Jia, Sen and Bruce, Neil},
279 | booktitle={International Conference on Learning Representations},
280 | year={2020}
281 | }
282 | ```
283 | ```
284 | @article{zhu2022relational,
285 | title={Relational reasoning over spatial-temporal graphs for video summarization},
286 | author={Zhu, Wencheng and Han, Yucheng and Lu, Jiwen and Zhou, Jie},
287 | journal={IEEE Transactions on Image Processing},
288 | volume={31},
289 | pages={3017--3031},
290 | year={2022},
291 | publisher={IEEE}
292 | }
293 | ```
294 | ```
295 | @inproceedings{li2023progressive,
296 | title={Progressive Video Summarization via Multimodal Self-supervised Learning},
297 | author={Li, Haopeng and Ke, Qiuhong and Gong, Mingming and Drummond, Tom},
298 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
299 | pages={5584--5593},
300 | year={2023}
301 | }
302 | ```
303 | ```
304 | @article{hsu2023video,
305 | title={Video summarization with spatiotemporal vision transformer},
306 | author={Hsu, Tzu-Chun and Liao, Yi-Sheng and Huang, Chun-Rong},
307 | journal={IEEE Transactions on Image Processing},
308 | year={2023},
309 | publisher={IEEE}
310 | }
311 | ```
312 | ```
313 | @misc{fajtl2018summarizing,
314 | title={Summarizing Videos with Attention},
315 | author={Jiri Fajtl and Hajar Sadeghi Sokeh and Vasileios Argyriou and Dorothy Monekosso and Paolo Remagnino},
316 | year={2018},
317 | eprint={1812.01969},
318 | archivePrefix={arXiv},
319 | primaryClass={cs.CV}
320 | }
321 | ```
322 | ```
323 | @article{li2022video,
324 | title={Video Joint Modelling Based on Hierarchical Transformer for Co-summarization},
325 | author={Li, Haopeng and Ke, Qiuhong and Gong, Mingming and Zhang, Rui},
326 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
327 | year={2022},
328 | publisher={IEEE}
329 | }
330 | ```
331 |
332 | [Back to top](https://github.com/thswodnjs3/CSTA?tab=readme-ov-file#csta-cnn-based-spatiotemporal-attention-for-video-summarization-cvpr-2024-paper)↑
333 |
--------------------------------------------------------------------------------
/models/GoogleNet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch import Tensor
7 | from typing import Any, Callable, List, Optional, Tuple
8 |
9 | from models.positional_encoding import FixedPositionalEncoding,LearnablePositionalEncoding,RelativePositionalEncoding,ConditionalPositionalEncoding
10 |
11 | # Edit GoogleNet by replacing last parts with adaptive average pooling layers
12 | class GoogleNet_Att(nn.Module):
13 | __constants__ = ["aux_logits", "transform_input"]
14 |
15 | def __init__(
16 | self,
17 | num_classes: int = 1000,
18 | init_weights: Optional[bool] = None
19 | ) -> None:
20 | super().__init__()
21 | conv_block = BasicConv2d
22 | inception_block = Inception
23 |
24 | self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
25 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
26 | self.conv2 = conv_block(64, 64, kernel_size=1)
27 | self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
28 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
29 |
30 | self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
31 | self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
32 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
33 |
34 | self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
35 | self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
36 | self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
37 | self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
38 | self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
39 | self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
40 |
41 | self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
42 | self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)
43 |
44 | self.fc = nn.Linear(1024, num_classes)
45 |
46 | if init_weights:
47 | for m in self.modules():
48 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
49 | torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
50 | elif isinstance(m, nn.BatchNorm2d):
51 | nn.init.constant_(m.weight, 1)
52 | nn.init.constant_(m.bias, 0)
53 |
54 | def _forward(self, x: Tensor, n_frame) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
55 | x = self.conv1(x)
56 | x = self.maxpool1(x)
57 | x = self.conv2(x)
58 | x = self.conv3(x)
59 | x = self.maxpool2(x)
60 |
61 | x = self.inception3a(x)
62 | x = self.inception3b(x)
63 | x = self.maxpool3(x)
64 | x = self.inception4a(x)
65 |
66 | x = self.inception4b(x)
67 | x = self.inception4c(x)
68 | x = self.inception4d(x)
69 |
70 | x = self.inception4e(x)
71 | x = self.maxpool4(x)
72 | x = self.inception5a(x)
73 | x = self.inception5b(x)
74 |
75 | ##############################################################################
76 | # The place I edit to resize feature maps, and to handle various lengths of input videos
77 | ##############################################################################
78 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1))
79 | x = self.avgpool(x)
80 | x = torch.squeeze(x)
81 | x = x.permute(1,0)
82 | return x
83 |
84 | def forward(self, x: Tensor):
85 | ##############################################################################
86 | # Takes the number of frames to handle various lengths of input videos
87 | ##############################################################################
88 | n_frame = x.shape[2]
89 | x = self._forward(x,n_frame)
90 | return x
91 |
92 | class Inception(nn.Module):
93 | def __init__(
94 | self,
95 | in_channels: int,
96 | ch1x1: int,
97 | ch3x3red: int,
98 | ch3x3: int,
99 | ch5x5red: int,
100 | ch5x5: int,
101 | pool_proj: int,
102 | conv_block: Optional[Callable[..., nn.Module]] = None,
103 | ) -> None:
104 | super().__init__()
105 | if conv_block is None:
106 | conv_block = BasicConv2d
107 | self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
108 |
109 | self.branch2 = nn.Sequential(
110 | conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
111 | )
112 |
113 | self.branch3 = nn.Sequential(
114 | conv_block(in_channels, ch5x5red, kernel_size=1),
115 | conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
116 | )
117 |
118 | self.branch4 = nn.Sequential(
119 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
120 | conv_block(in_channels, pool_proj, kernel_size=1),
121 | )
122 |
123 | def _forward(self, x: Tensor) -> List[Tensor]:
124 | branch1 = self.branch1(x)
125 | branch2 = self.branch2(x)
126 | branch3 = self.branch3(x)
127 | branch4 = self.branch4(x)
128 |
129 | outputs = [branch1, branch2, branch3, branch4]
130 | return outputs
131 |
132 | def forward(self, x: Tensor) -> Tensor:
133 | outputs = self._forward(x)
134 | return torch.cat(outputs, 1)
135 |
136 | class BasicConv2d(nn.Module):
137 | def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
138 | super().__init__()
139 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
140 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
141 |
142 | def forward(self, x: Tensor) -> Tensor:
143 | x = self.conv(x)
144 | x = self.bn(x)
145 | return F.relu(x, inplace=True)
146 |
147 | ##############################################################################
148 | # Define our proposed model
149 | ##############################################################################
150 | class CSTA_GoogleNet(nn.Module):
151 | def __init__(self,
152 | model_name,
153 | Scale,
154 | Softmax_axis,
155 | Balance,
156 | Positional_encoding,
157 | Positional_encoding_shape,
158 | Positional_encoding_way,
159 | Dropout_on,
160 | Dropout_ratio,
161 | Classifier_on,
162 | CLS_on,
163 | CLS_mix,
164 | key_value_emb,
165 | Skip_connection,
166 | Layernorm,
167 | dim=1024):
168 | super().__init__()
169 | self.googlenet = GoogleNet_Att()
170 |
171 | self.model_name = model_name
172 | self.Scale = Scale
173 | self.Softmax_axis = Softmax_axis
174 | self.Balance = Balance
175 |
176 | self.Positional_encoding = Positional_encoding
177 | self.Positional_encoding_shape = Positional_encoding_shape
178 | self.Positional_encoding_way = Positional_encoding_way
179 | self.Dropout_on = Dropout_on
180 | self.Dropout_ratio = Dropout_ratio
181 |
182 | self.Classifier_on = Classifier_on
183 | self.CLS_on = CLS_on
184 | self.CLS_mix = CLS_mix
185 |
186 | self.key_value_emb = key_value_emb
187 | self.Skip_connection = Skip_connection
188 | self.Layernorm = Layernorm
189 |
190 | self.dim = dim
191 |
192 | if self.Positional_encoding is not None:
193 | if self.Positional_encoding=='FPE':
194 | self.Positional_encoding_op = FixedPositionalEncoding(
195 | Positional_encoding_shape=self.Positional_encoding_shape,
196 | dim=self.dim
197 | )
198 | elif self.Positional_encoding=='RPE':
199 | self.Positional_encoding_op = RelativePositionalEncoding(
200 | Positional_encoding_shape=self.Positional_encoding_shape,
201 | dim=self.dim
202 | )
203 | elif self.Positional_encoding=='LPE':
204 | self.Positional_encoding_op = LearnablePositionalEncoding(
205 | Positional_encoding_shape=self.Positional_encoding_shape,
206 | dim=self.dim
207 | )
208 | elif self.Positional_encoding=='CPE':
209 | self.Positional_encoding_op = ConditionalPositionalEncoding(
210 | Positional_encoding_shape=self.Positional_encoding_shape,
211 | Positional_encoding_way=self.Positional_encoding_way,
212 | dim=self.dim
213 | )
214 | elif self.Positional_encoding is None:
215 | pass
216 | else:
217 | raise
218 |
219 | if self.Positional_encoding_way=='Transformer':
220 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim)
221 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
222 | pass
223 | else:
224 | raise
225 |
226 | if self.Dropout_on:
227 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio))
228 |
229 | if self.Classifier_on:
230 | self.linear1 = nn.Sequential(
231 | nn.Linear(in_features=self.dim, out_features=self.dim),
232 | nn.ReLU(),
233 | nn.Dropout(p=0.5),
234 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6)
235 | )
236 | self.linear2 = nn.Sequential(
237 | nn.Linear(in_features=self.dim, out_features=1),
238 | nn.Sigmoid()
239 | )
240 |
241 | for name,param in self.named_parameters():
242 | if name in ['linear1.0.weight','linear2.0.weight']:
243 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0))
244 | elif name in ['linear1.0.bias','linear2.0.bias']:
245 | nn.init.constant_(param, 0.1)
246 | else:
247 | self.gap = nn.AdaptiveAvgPool1d(1)
248 |
249 | if self.CLS_on:
250 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024))
251 |
252 | if self.key_value_emb is not None:
253 | if self.key_value_emb.lower()=='k':
254 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
255 | elif self.key_value_emb.lower()=='v':
256 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim)
257 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv':
258 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
259 | if self.model_name=='GoogleNet_Attention':
260 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim)
261 | else:
262 | raise
263 |
264 | if self.Layernorm:
265 | if self.Skip_connection=='KC':
266 | self.layernorm1 = nn.BatchNorm2d(num_features=1)
267 | elif self.Skip_connection=='CF':
268 | self.layernorm2 = nn.BatchNorm2d(num_features=1)
269 | elif self.Skip_connection=='IF':
270 | self.layernorm3 = nn.BatchNorm2d(num_features=1)
271 | elif self.Skip_connection is None:
272 | pass
273 | else:
274 | raise
275 |
276 | def forward(self, x):
277 | # Take the number of frames
278 | n_frame = x.shape[2]
279 |
280 | # Linear projection if using CLS token as transformer ways
281 | if self.Positional_encoding_way=='Transformer':
282 | x = self.Positional_encoding_embedding(x)
283 | # Stack CLS token
284 | if self.CLS_on:
285 | x = torch.cat((self.CLS,x),dim=2)
286 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim))
287 |
288 | # Positional encoding (Transformer ways)
289 | if self.Positional_encoding_way=='Transformer':
290 | if self.Positional_encoding is not None:
291 | x = self.Positional_encoding_op(x)
292 | # Dropout (Transformer ways)
293 | if self.Dropout_on:
294 | x = self.dropout(x)
295 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
296 | pass
297 | else:
298 | raise
299 |
300 | # Key Embedding
301 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']:
302 | key = self.key_embedding(x)
303 | elif self.key_value_emb is None:
304 | key = x
305 | else:
306 | raise
307 |
308 | # CNN as attention algorithm
309 | x_att = self.googlenet(key)
310 |
311 | # Skip connection (KC)
312 | if self.Skip_connection is not None:
313 | if self.Skip_connection=='KC':
314 | x_att = x_att + key.squeeze(0)[0]
315 | if self.Layernorm:
316 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
317 | elif self.Skip_connection in ['CF','IF']:
318 | pass
319 | else:
320 | raise
321 | elif self.Skip_connection is None:
322 | pass
323 | else:
324 | raise
325 |
326 | # Combine CLS token (CNN)
327 | if self.CLS_on:
328 | if self.CLS_mix=='CNN':
329 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0)
330 | x = CT_adjust(x.squeeze(0)).unsqueeze(0)
331 | elif self.CLS_mix in ['SM','Final']:
332 | pass
333 | else:
334 | raise
335 | else:
336 | pass
337 |
338 | # Scaling factor
339 | if self.Scale is not None:
340 | if self.Scale=='D':
341 | scaling_factor = x_att.shape[1]
342 | elif self.Scale=='T':
343 | scaling_factor = x_att.shape[0]
344 | elif self.Scale=='T_D':
345 | scaling_factor = x_att.shape[0] * x_att.shape[1]
346 | else:
347 | raise
348 | scaling_factor = scaling_factor ** 0.5
349 | x_att = x_att / scaling_factor
350 | elif self.Scale is None:
351 | pass
352 |
353 | # Positional encoding (PGL-SUM ways)
354 | if self.Positional_encoding_way=='PGL_SUM':
355 | if self.Positional_encoding is not None:
356 | x_att = self.Positional_encoding_op(x_att)
357 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None:
358 | pass
359 | else:
360 | raise
361 |
362 | # softmax_axis
363 | x = x.squeeze(0)[0]
364 | if self.Softmax_axis=='T':
365 | temporal_attention = F.softmax(x_att,dim=0)
366 | elif self.Softmax_axis=='D':
367 | spatial_attention = F.softmax(x_att,dim=1)
368 | elif self.Softmax_axis=='TD':
369 | temporal_attention = F.softmax(x_att,dim=0)
370 | spatial_attention = F.softmax(x_att,dim=1)
371 | elif self.Softmax_axis is None:
372 | pass
373 | else:
374 | raise
375 |
376 | # Combine CLS token for softmax outputs (SM)
377 | if self.CLS_on:
378 | if self.CLS_mix=='SM':
379 | if self.Softmax_axis=='T':
380 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
381 | elif self.Softmax_axis=='D':
382 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
383 | elif self.Softmax_axis=='TD':
384 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
385 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
386 | elif self.Softmax_axis is None:
387 | pass
388 | else:
389 | raise
390 | elif self.CLS_mix in ['CNN','Final']:
391 | pass
392 | else:
393 | raise
394 | else:
395 | pass
396 |
397 | # Dropout (PGL-SUM ways)
398 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM':
399 | if self.Softmax_axis=='T':
400 | temporal_attention = self.dropout(temporal_attention)
401 | elif self.Softmax_axis=='D':
402 | spatial_attention = self.dropout(spatial_attention)
403 | elif self.Softmax_axis=='TD':
404 | temporal_attention = self.dropout(temporal_attention)
405 | spatial_attention = self.dropout(spatial_attention)
406 | elif self.Softmax_axis is None:
407 | pass
408 | else:
409 | raise
410 |
411 | # Value Embedding
412 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']:
413 | if self.model_name=='GoogleNet_Attention':
414 | x_out = self.value_embedding(x)
415 | elif self.model_name=='GoogleNet':
416 | x_out = x_att
417 | else:
418 | raise
419 | elif self.key_value_emb is None:
420 | if self.model_name=='GoogleNet':
421 | x_out = x_att
422 | elif self.model_name=='GoogleNet_Attention':
423 | x_out = x
424 | else:
425 | raise
426 | else:
427 | raise
428 |
429 | # Combine CLS token for CNN outputs (SM)
430 | if self.CLS_on:
431 | if self.CLS_mix=='SM':
432 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0)
433 |
434 | # Apply Attention maps to input frame features
435 | if self.Softmax_axis=='T':
436 | x_out = x_out * temporal_attention
437 | elif self.Softmax_axis=='D':
438 | x_out = x_out * spatial_attention
439 | elif self.Softmax_axis=='TD':
440 | T,D = x_out.shape
441 | adjust_frame = T/D
442 | adjust_dimension = D/T
443 | if self.Balance=='T':
444 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
445 | elif self.Balance=='D':
446 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
447 | elif self.Balance=='BD':
448 | if T>D:
449 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
450 | elif T<D:
456 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
457 | elif T<D:
--------------------------------------------------------------------------------
/models/MobileNet.py:
--------------------------------------------------------------------------------
21 | ) -> None:
22 | super().__init__()
23 | self.stride = stride
24 | if stride not in [1, 2]:
25 | raise ValueError(f"stride should be 1 or 2 instead of {stride}")
26 |
27 | if norm_layer is None:
28 | norm_layer = nn.BatchNorm2d
29 |
30 | hidden_dim = int(round(inp * expand_ratio))
31 | self.use_res_connect = self.stride == 1 and inp == oup
32 |
33 | layers: List[nn.Module] = []
34 | if expand_ratio != 1:
35 | # pw
36 | layers.append(
37 | Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
38 | )
39 | layers.extend(
40 | [
41 | # dw
42 | Conv2dNormActivation(
43 | hidden_dim,
44 | hidden_dim,
45 | stride=stride,
46 | groups=hidden_dim,
47 | norm_layer=norm_layer,
48 | activation_layer=nn.ReLU6,
49 | ),
50 | # pw-linear
51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
52 | norm_layer(oup),
53 | ]
54 | )
55 | self.conv = nn.Sequential(*layers)
56 | self.out_channels = oup
57 | self._is_cn = stride > 1
58 |
59 | def forward(self, x: Tensor) -> Tensor:
60 | if self.use_res_connect:
61 | return x + self.conv(x)
62 | else:
63 | return self.conv(x)
64 |
65 | # MobileNet as attention
66 | class MobileNet_Att(nn.Module):
67 | def __init__(
68 | self,
69 | num_classes: int = 1000,
70 | width_mult: float = 1.0,
71 | inverted_residual_setting: Optional[List[List[int]]] = None,
72 | round_nearest: int = 8,
73 | block: Optional[Callable[..., nn.Module]] = None,
74 | norm_layer: Optional[Callable[..., nn.Module]] = None,
75 | dropout: float = 0.2,
76 | ) -> None:
77 | """
78 | MobileNet V2 main class
79 |
80 | Args:
81 | num_classes (int): Number of classes
82 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
83 | inverted_residual_setting: Network structure
84 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
85 | Set to 1 to turn off rounding
86 | block: Module specifying inverted residual building block for mobilenet
87 | norm_layer: Module specifying the normalization layer to use
88 | dropout (float): The dropout probability
89 |
90 | """
91 | super().__init__()
92 | _log_api_usage_once(self)
93 |
94 | if block is None:
95 | block = InvertedResidual
96 |
97 | if norm_layer is None:
98 | norm_layer = nn.BatchNorm2d
99 |
100 | input_channel = 32
101 | last_channel = 1280
102 |
103 | if inverted_residual_setting is None:
104 | inverted_residual_setting = [
105 | # t, c, n, s
106 | [1, 16, 1, 1],
107 | [6, 24, 2, 2],
108 | [6, 32, 3, 2],
109 | [6, 64, 4, 2],
110 | [6, 96, 3, 1],
111 | [6, 160, 3, 2],
112 | [6, 320, 1, 1],
113 | ]
114 |
115 | # only check the first element, assuming user knows t,c,n,s are required
116 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
117 | raise ValueError(
118 | f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
119 | )
120 |
121 | # building first layer
122 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
123 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
124 | features: List[nn.Module] = [
125 | Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
126 | ]
127 | # building inverted residual blocks
128 | for t, c, n, s in inverted_residual_setting:
129 | output_channel = _make_divisible(c * width_mult, round_nearest)
130 | for i in range(n):
131 | stride = s if i == 0 else 1
132 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
133 | input_channel = output_channel
134 | # building last several layers
135 | features.append(
136 | Conv2dNormActivation(
137 | input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
138 | )
139 | )
140 | # make it nn.Sequential
141 | self.features = nn.Sequential(*features)
142 |
143 | # building classifier
144 | self.classifier = nn.Sequential(
145 | nn.Dropout(p=dropout),
146 | nn.Linear(self.last_channel, num_classes),
147 | )
148 |
149 | # weight initialization
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
153 | if m.bias is not None:
154 | nn.init.zeros_(m.bias)
155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
156 | nn.init.ones_(m.weight)
157 | nn.init.zeros_(m.bias)
158 | elif isinstance(m, nn.Linear):
159 | nn.init.normal_(m.weight, 0, 0.01)
160 | nn.init.zeros_(m.bias)
161 |
162 | def _forward_impl(self, x: Tensor, n_frame) -> Tensor:
163 | x = self.features(x)
164 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1))
165 | x = self.avgpool(x)
166 | x = torch.squeeze(x)
167 | x = x.permute(1,0)
168 | return x
169 |
170 | def forward(self, x: Tensor) -> Tensor:
171 | n_frame = x.shape[2]
172 | return self._forward_impl(x, n_frame)
173 |
174 | _COMMON_META = {
175 | "num_params": 3504872,
176 | "min_size": (1, 1),
177 | "categories": _IMAGENET_CATEGORIES,
178 | }
179 |
180 | class MobileNet_V2_Weights(WeightsEnum):
181 | IMAGENET1K_V1 = Weights(
182 | url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
183 | transforms=partial(ImageClassification, crop_size=224),
184 | meta={
185 | **_COMMON_META,
186 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
187 | "_metrics": {
188 | "ImageNet-1K": {
189 | "acc@1": 71.878,
190 | "acc@5": 90.286,
191 | }
192 | },
193 | "_ops": 0.301,
194 | "_file_size": 13.555,
195 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
196 | },
197 | )
198 | IMAGENET1K_V2 = Weights(
199 | url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
200 | transforms=partial(ImageClassification, crop_size=224, resize_size=232),
201 | meta={
202 | **_COMMON_META,
203 | "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
204 | "_metrics": {
205 | "ImageNet-1K": {
206 | "acc@1": 72.154,
207 | "acc@5": 90.822,
208 | }
209 | },
210 | "_ops": 0.301,
211 | "_file_size": 13.598,
212 | "_docs": """
213 | These weights improve upon the results of the original paper by using a modified version of TorchVision's
214 | `new training recipe
215 | <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
216 | """,
217 | },
218 | )
219 | DEFAULT = IMAGENET1K_V2
220 |
221 | @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
222 | def mobilenet_v2(
223 | *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
224 | ) -> MobileNet_Att:
225 | """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
226 | Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.
227 |
228 | Args:
229 | weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
230 | pretrained weights to use. See
231 | :class:`~torchvision.models.MobileNet_V2_Weights` below for
232 | more details, and possible values. By default, no pre-trained
233 | weights are used.
234 | progress (bool, optional): If True, displays a progress bar of the
235 | download to stderr. Default is True.
236 | **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
237 | base class. Please refer to the `source code
238 | <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
239 | for more details about this class.
240 |
241 | .. autoclass:: torchvision.models.MobileNet_V2_Weights
242 | :members:
243 | """
244 | weights = MobileNet_V2_Weights.verify(weights)
245 |
246 | if weights is not None:
247 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
248 |
249 | model = MobileNet_Att(**kwargs)
250 |
251 | if weights is not None:
252 | model.load_state_dict(weights.get_state_dict(progress=progress))
253 |
254 | return model
255 |
256 | # MobileNet-based CSTA
257 | class CSTA_MobileNet(nn.Module):
258 | def __init__(self,
259 | model_name,
260 | Scale,
261 | Softmax_axis,
262 | Balance,
263 | Positional_encoding,
264 | Positional_encoding_shape,
265 | Positional_encoding_way,
266 | Dropout_on,
267 | Dropout_ratio,
268 | Classifier_on,
269 | CLS_on,
270 | CLS_mix,
271 | key_value_emb,
272 | Skip_connection,
273 | Layernorm,
274 | dim=1280):
275 | super().__init__()
276 | self.mobilenet = MobileNet_Att()
277 |
278 | self.model_name = model_name
279 | self.Scale = Scale
280 | self.Softmax_axis = Softmax_axis
281 | self.Balance = Balance
282 |
283 | self.Positional_encoding = Positional_encoding
284 | self.Positional_encoding_shape = Positional_encoding_shape
285 | self.Positional_encoding_way = Positional_encoding_way
286 | self.Dropout_on = Dropout_on
287 | self.Dropout_ratio = Dropout_ratio
288 |
289 | self.Classifier_on = Classifier_on
290 | self.CLS_on = CLS_on
291 | self.CLS_mix = CLS_mix
292 |
293 | self.key_value_emb = key_value_emb
294 | self.Skip_connection = Skip_connection
295 | self.Layernorm = Layernorm
296 |
297 | self.dim = dim
298 |
299 | if self.Positional_encoding is not None:
300 | if self.Positional_encoding=='FPE':
301 | self.Positional_encoding_op = FixedPositionalEncoding(
302 | Positional_encoding_shape=self.Positional_encoding_shape,
303 | dim=self.dim
304 | )
305 | elif self.Positional_encoding=='RPE':
306 | self.Positional_encoding_op = RelativePositionalEncoding(
307 | Positional_encoding_shape=self.Positional_encoding_shape,
308 | dim=self.dim
309 | )
310 | elif self.Positional_encoding=='LPE':
311 | self.Positional_encoding_op = LearnablePositionalEncoding(
312 | Positional_encoding_shape=self.Positional_encoding_shape,
313 | dim=self.dim
314 | )
315 | elif self.Positional_encoding=='CPE':
316 | self.Positional_encoding_op = ConditionalPositionalEncoding(
317 | Positional_encoding_shape=self.Positional_encoding_shape,
318 | Positional_encoding_way=self.Positional_encoding_way,
319 | dim=self.dim
320 | )
321 | elif self.Positional_encoding is None:
322 | pass
323 | else:
324 | raise
325 |
326 | if self.Positional_encoding_way=='Transformer':
327 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim)
328 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
329 | pass
330 | else:
331 | raise
332 |
333 | if self.Dropout_on:
334 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio))
335 |
336 | if self.Classifier_on:
337 | self.linear1 = nn.Sequential(
338 | nn.Linear(in_features=self.dim, out_features=self.dim),
339 | nn.ReLU(),
340 | nn.Dropout(p=0.5),
341 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6)
342 | )
343 | self.linear2 = nn.Sequential(
344 | nn.Linear(in_features=self.dim, out_features=1),
345 | nn.Sigmoid()
346 | )
347 |
348 | for name,param in self.named_parameters():
349 | if name in ['linear1.0.weight','linear2.0.weight']:
350 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0))
351 | elif name in ['linear1.0.bias','linear2.0.bias']:
352 | nn.init.constant_(param, 0.1)
353 | else:
354 | self.gap = nn.AdaptiveAvgPool1d(1)
355 |
356 | if self.CLS_on:
357 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024))
358 |
359 | if self.key_value_emb is not None:
360 | if self.key_value_emb.lower()=='k':
361 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
362 | elif self.key_value_emb.lower()=='v':
363 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim)
364 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv':
365 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
366 | if self.model_name=='MobileNet_Attention':
367 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim)
368 | else:
369 | raise
370 |
371 | if self.Layernorm:
372 | if self.Skip_connection=='KC':
373 | self.layernorm1 = nn.BatchNorm2d(num_features=1)
374 | elif self.Skip_connection=='CF':
375 | self.layernorm2 = nn.BatchNorm2d(num_features=1)
376 | elif self.Skip_connection=='IF':
377 | self.layernorm3 = nn.BatchNorm2d(num_features=1)
378 | elif self.Skip_connection is None:
379 | pass
380 | else:
381 | raise
382 |
383 | def forward(self, x):
384 | n_frame = x.shape[2]
385 |
386 | if self.Positional_encoding_way=='Transformer':
387 | x = self.Positional_encoding_embedding(x)
388 | if self.CLS_on:
389 | x = torch.cat((self.CLS,x),dim=2)
390 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim))
391 |
392 | if self.Positional_encoding_way=='Transformer':
393 | if self.Positional_encoding is not None:
394 | x = self.Positional_encoding_op(x)
395 | if self.Dropout_on:
396 | x = self.dropout(x)
397 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
398 | pass
399 | else:
400 | raise
401 |
402 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']:
403 | key = self.key_embedding(x)
404 | elif self.key_value_emb is None:
405 | key = x
406 | else:
407 | raise
408 |
409 | x_att = self.mobilenet(key)
410 |
411 | if self.Skip_connection is not None:
412 | if self.Skip_connection=='KC':
413 | x_att = x_att + key.squeeze(0)[0]
414 | if self.Layernorm:
415 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
416 | elif self.Skip_connection in ['CF','IF']:
417 | pass
418 | else:
419 | raise
420 | elif self.Skip_connection is None:
421 | pass
422 | else:
423 | raise
424 |
425 | if self.CLS_on:
426 | if self.CLS_mix=='CNN':
427 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0)
428 | x = CT_adjust(x.squeeze(0)).unsqueeze(0)
429 | elif self.CLS_mix in ['SM','Final']:
430 | pass
431 | else:
432 | raise
433 | else:
434 | pass
435 |
436 | if self.Scale is not None:
437 | if self.Scale=='D':
438 | scaling_factor = x_att.shape[1]
439 | elif self.Scale=='T':
440 | scaling_factor = x_att.shape[0]
441 | elif self.Scale=='T_D':
442 | scaling_factor = x_att.shape[0] * x_att.shape[1]
443 | else:
444 | raise
445 | scaling_factor = scaling_factor ** 0.5
446 | x_att = x_att / scaling_factor
447 | elif self.Scale is None:
448 | pass
449 |
450 | if self.Positional_encoding_way=='PGL_SUM':
451 | if self.Positional_encoding is not None:
452 | x_att = self.Positional_encoding_op(x_att)
453 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None:
454 | pass
455 | else:
456 | raise
457 |
458 | x = x.squeeze(0)[0]
459 | if self.Softmax_axis=='T':
460 | temporal_attention = F.softmax(x_att,dim=0)
461 | elif self.Softmax_axis=='D':
462 | spatial_attention = F.softmax(x_att,dim=1)
463 | elif self.Softmax_axis=='TD':
464 | temporal_attention = F.softmax(x_att,dim=0)
465 | spatial_attention = F.softmax(x_att,dim=1)
466 | elif self.Softmax_axis is None:
467 | pass
468 | else:
469 | raise
470 |
471 | if self.CLS_on:
472 | if self.CLS_mix=='SM':
473 | if self.Softmax_axis=='T':
474 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
475 | elif self.Softmax_axis=='D':
476 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
477 | elif self.Softmax_axis=='TD':
478 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
479 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
480 | elif self.Softmax_axis is None:
481 | pass
482 | else:
483 | raise
484 | elif self.CLS_mix in ['CNN','Final']:
485 | pass
486 | else:
487 | raise
488 | else:
489 | pass
490 |
491 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM':
492 | if self.Softmax_axis=='T':
493 | temporal_attention = self.dropout(temporal_attention)
494 | elif self.Softmax_axis=='D':
495 | spatial_attention = self.dropout(spatial_attention)
496 | elif self.Softmax_axis=='TD':
497 | temporal_attention = self.dropout(temporal_attention)
498 | spatial_attention = self.dropout(spatial_attention)
499 | elif self.Softmax_axis is None:
500 | pass
501 | else:
502 | raise
503 |
504 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']:
505 | if self.model_name=='MobileNet_Attention':
506 | x_out = self.value_embedding(x)
507 | elif self.model_name=='MobileNet':
508 | x_out = x_att
509 | else:
510 | raise
511 | elif self.key_value_emb is None:
512 | if self.model_name=='MobileNet':
513 | x_out = x_att
514 | elif self.model_name=='MobileNet_Attention':
515 | x_out = x
516 | else:
517 | raise
518 | else:
519 | raise
520 |
521 | if self.CLS_on:
522 | if self.CLS_mix=='SM':
523 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0)
524 |
525 | if self.Softmax_axis=='T':
526 | x_out = x_out * temporal_attention
527 | elif self.Softmax_axis=='D':
528 | x_out = x_out * spatial_attention
529 | elif self.Softmax_axis=='TD':
530 | T,D = x_out.shape
531 | adjust_frame = T/D
532 | adjust_dimension = D/T
533 | if self.Balance=='T':
534 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
535 | elif self.Balance=='D':
536 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
537 | elif self.Balance=='BD':
538 | if T>D:
539 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
540 | elif T<D:
546 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
547 | elif T<D:
--------------------------------------------------------------------------------
/models/ResNet.py:
--------------------------------------------------------------------------------
22 | def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
23 | def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
24 | key = name if name is not None else fn.__name__
25 | if key in BUILTIN_MODELS:
26 | raise ValueError(f"An entry is already registered under the name '{key}'.")
27 | BUILTIN_MODELS[key] = fn
28 | return fn
29 |
30 | return wrapper
31 |
32 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
33 | """3x3 convolution with padding"""
34 | return nn.Conv2d(
35 | in_planes,
36 | out_planes,
37 | kernel_size=3,
38 | stride=stride,
39 | padding=dilation,
40 | groups=groups,
41 | bias=False,
42 | dilation=dilation,
43 | )
44 |
45 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
46 | """1x1 convolution"""
47 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
48 |
49 | class BasicBlock(nn.Module):
50 | expansion: int = 1
51 |
52 | def __init__(
53 | self,
54 | inplanes: int,
55 | planes: int,
56 | stride: int = 1,
57 | downsample: Optional[nn.Module] = None,
58 | groups: int = 1,
59 | base_width: int = 64,
60 | dilation: int = 1,
61 | norm_layer: Optional[Callable[..., nn.Module]] = None,
62 | ) -> None:
63 | super().__init__()
64 | if norm_layer is None:
65 | norm_layer = nn.BatchNorm2d
66 | if groups != 1 or base_width != 64:
67 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
68 | if dilation > 1:
69 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
70 | self.conv1 = conv3x3(inplanes, planes, stride)
71 | self.bn1 = norm_layer(planes)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.conv2 = conv3x3(planes, planes)
74 | self.bn2 = norm_layer(planes)
75 | self.downsample = downsample
76 | self.stride = stride
77 |
78 | def forward(self, x: Tensor) -> Tensor:
79 | identity = x
80 |
81 | out = self.conv1(x)
82 | out = self.bn1(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv2(out)
86 | out = self.bn2(out)
87 |
88 | if self.downsample is not None:
89 | identity = self.downsample(x)
90 |
91 | out += identity
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 | class Bottleneck(nn.Module):
97 | expansion: int = 4
98 |
99 | def __init__(
100 | self,
101 | inplanes: int,
102 | planes: int,
103 | stride: int = 1,
104 | downsample: Optional[nn.Module] = None,
105 | groups: int = 1,
106 | base_width: int = 64,
107 | dilation: int = 1,
108 | norm_layer: Optional[Callable[..., nn.Module]] = None,
109 | ) -> None:
110 | super().__init__()
111 | if norm_layer is None:
112 | norm_layer = nn.BatchNorm2d
113 | width = int(planes * (base_width / 64.0)) * groups
114 | self.conv1 = conv1x1(inplanes, width)
115 | self.bn1 = norm_layer(width)
116 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
117 | self.bn2 = norm_layer(width)
118 | self.conv3 = conv1x1(width, planes * self.expansion)
119 | self.bn3 = norm_layer(planes * self.expansion)
120 | self.relu = nn.ReLU(inplace=True)
121 | self.downsample = downsample
122 | self.stride = stride
123 |
124 | def forward(self, x: Tensor) -> Tensor:
125 | identity = x
126 |
127 | out = self.conv1(x)
128 | out = self.bn1(out)
129 | out = self.relu(out)
130 |
131 | out = self.conv2(out)
132 | out = self.bn2(out)
133 | out = self.relu(out)
134 |
135 | out = self.conv3(out)
136 | out = self.bn3(out)
137 |
138 | if self.downsample is not None:
139 | identity = self.downsample(x)
140 |
141 | out += identity
142 | out = self.relu(out)
143 |
144 | return out
145 |
146 | # ResNet as attention
147 | class ResNet_Att(nn.Module):
148 | def __init__(
149 | self,
150 | block: Type[Union[BasicBlock, Bottleneck]] = BasicBlock,
151 | layers: List[int] = [2, 2, 2, 2],
152 | num_classes: int = 1000,
153 | zero_init_residual: bool = False,
154 | groups: int = 1,
155 | width_per_group: int = 64,
156 | replace_stride_with_dilation: Optional[List[bool]] = None,
157 | norm_layer: Optional[Callable[..., nn.Module]] = None,
158 | ) -> None:
159 | super().__init__()
160 | _log_api_usage_once(self)
161 | if norm_layer is None:
162 | norm_layer = nn.BatchNorm2d
163 | self._norm_layer = norm_layer
164 |
165 | self.inplanes = 64
166 | self.dilation = 1
167 | if replace_stride_with_dilation is None:
168 | replace_stride_with_dilation = [False, False, False]
169 | if len(replace_stride_with_dilation) != 3:
170 | raise ValueError(
171 | "replace_stride_with_dilation should be None "
172 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
173 | )
174 | self.groups = groups
175 | self.base_width = width_per_group
176 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
177 | self.bn1 = norm_layer(self.inplanes)
178 | self.relu = nn.ReLU(inplace=True)
179 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
180 | self.layer1 = self._make_layer(block, 64, layers[0])
181 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
184 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
185 | self.fc = nn.Linear(512 * block.expansion, num_classes)
186 |
187 | for m in self.modules():
188 | if isinstance(m, nn.Conv2d):
189 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
190 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
191 | nn.init.constant_(m.weight, 1)
192 | nn.init.constant_(m.bias, 0)
193 |
194 | if zero_init_residual:
195 | for m in self.modules():
196 | if isinstance(m, Bottleneck) and m.bn3.weight is not None:
197 | nn.init.constant_(m.bn3.weight, 0)
198 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
199 | nn.init.constant_(m.bn2.weight, 0)
200 |
201 | def _make_layer(
202 | self,
203 | block: Type[Union[BasicBlock, Bottleneck]],
204 | planes: int,
205 | blocks: int,
206 | stride: int = 1,
207 | dilate: bool = False,
208 | ) -> nn.Sequential:
209 | norm_layer = self._norm_layer
210 | downsample = None
211 | previous_dilation = self.dilation
212 | if dilate:
213 | self.dilation *= stride
214 | stride = 1
215 | if stride != 1 or self.inplanes != planes * block.expansion:
216 | downsample = nn.Sequential(
217 | conv1x1(self.inplanes, planes * block.expansion, stride),
218 | norm_layer(planes * block.expansion),
219 | )
220 |
221 | layers = []
222 | layers.append(
223 | block(
224 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
225 | )
226 | )
227 | self.inplanes = planes * block.expansion
228 | for _ in range(1, blocks):
229 | layers.append(
230 | block(
231 | self.inplanes,
232 | planes,
233 | groups=self.groups,
234 | base_width=self.base_width,
235 | dilation=self.dilation,
236 | norm_layer=norm_layer,
237 | )
238 | )
239 |
240 | return nn.Sequential(*layers)
241 |
242 | def _forward_impl(self, x: Tensor,n_frame) -> Tensor:
243 | x = self.conv1(x)
244 | x = self.bn1(x)
245 | x = self.relu(x)
246 | x = self.maxpool(x)
247 |
248 | x = self.layer1(x)
249 | x = self.layer2(x)
250 | x = self.layer3(x)
251 | x = self.layer4(x)
252 |
253 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1))
254 | x = self.avgpool(x)
255 | x = torch.squeeze(x)
256 | x = x.permute(1,0)
257 |
258 | return x
259 |
260 | def forward(self, x: Tensor) -> Tensor:
261 | n_frame = x.shape[2]
262 | return self._forward_impl(x,n_frame)
263 |
264 | def _resnet(
265 | block: Type[Union[BasicBlock, Bottleneck]],
266 | layers: List[int],
267 | weights: Optional[WeightsEnum],
268 | progress: bool,
269 | **kwargs: Any,
270 | ) -> ResNet_Att:
271 | if weights is not None:
272 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
273 |
274 | model = ResNet_Att(block, layers, **kwargs)
275 |
276 | if weights is not None:
277 | model.load_state_dict(weights.get_state_dict(progress=progress))
278 |
279 | return model
280 |
281 | _COMMON_META = {
282 | "min_size": (1, 1),
283 | "categories": _IMAGENET_CATEGORIES,
284 | }
285 |
286 | class ResNet18_Weights(WeightsEnum):
287 | IMAGENET1K_V1 = Weights(
288 | url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
289 | transforms=partial(ImageClassification, crop_size=224),
290 | meta={
291 | **_COMMON_META,
292 | "num_params": 11689512,
293 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
294 | "_metrics": {
295 | "ImageNet-1K": {
296 | "acc@1": 69.758,
297 | "acc@5": 89.078,
298 | }
299 | },
300 | "_ops": 1.814,
301 | "_file_size": 44.661,
302 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
303 | },
304 | )
305 | DEFAULT = IMAGENET1K_V1
306 |
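The enum mirrors torchvision's weight registry, so the checkpoint metadata listed above can be queried directly:

```python
w = ResNet18_Weights.DEFAULT                        # alias for IMAGENET1K_V1
print(w.meta["num_params"])                         # 11689512
print(w.meta["_metrics"]["ImageNet-1K"]["acc@1"])   # 69.758
print(w.url)                                        # the download URL shown above
```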
307 | @register_model()
308 | @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
309 | def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet_Att:
310 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__.
311 |
312 | Args:
313 | weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
314 | pretrained weights to use. See
315 | :class:`~torchvision.models.ResNet18_Weights` below for
316 | more details, and possible values. By default, no pre-trained
317 | weights are used.
318 | progress (bool, optional): If True, displays a progress bar of the
319 | download to stderr. Default is True.
320 | **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
321 | base class. Please refer to the `source code
322 | `_
323 | for more details about this class.
324 |
325 | .. autoclass:: torchvision.models.ResNet18_Weights
326 | :members:
327 | """
328 | weights = ResNet18_Weights.verify(weights)
329 |
330 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
331 |
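A hypothetical call to the factory above; it mirrors torchvision's resnet18 API but returns the modified ResNet_Att. Passing weights=None keeps random initialization and avoids the checkpoint download:

```python
backbone = resnet18(weights=None)                               # BasicBlock, [2, 2, 2, 2]
# backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)   # or start from the ImageNet-1K checkpoint
print(isinstance(backbone, ResNet_Att))                         # True
```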
332 | # ResNet-based CSTA
333 | class CSTA_ResNet(nn.Module):
334 | def __init__(self,
335 | model_name,
336 | Scale,
337 | Softmax_axis,
338 | Balance,
339 | Positional_encoding,
340 | Positional_encoding_shape,
341 | Positional_encoding_way,
342 | Dropout_on,
343 | Dropout_ratio,
344 | Classifier_on,
345 | CLS_on,
346 | CLS_mix,
347 | key_value_emb,
348 | Skip_connection,
349 | Layernorm, dim=512):
350 | super().__init__()
351 | self.resnet = ResNet_Att()
352 |
353 | self.model_name = model_name
354 | self.Scale = Scale
355 | self.Softmax_axis = Softmax_axis
356 | self.Balance = Balance
357 |
358 | self.Positional_encoding = Positional_encoding
359 | self.Positional_encoding_shape = Positional_encoding_shape
360 | self.Positional_encoding_way = Positional_encoding_way
361 | self.Dropout_on = Dropout_on
362 | self.Dropout_ratio = Dropout_ratio
363 |
364 | self.Classifier_on = Classifier_on
365 | self.CLS_on = CLS_on
366 | self.CLS_mix = CLS_mix
367 |
368 | self.key_value_emb = key_value_emb
369 | self.Skip_connection = Skip_connection
370 | self.Layernorm = Layernorm
371 |
372 | self.dim = dim
373 |
374 | if self.Positional_encoding is not None:
375 | if self.Positional_encoding=='FPE':
376 | self.Positional_encoding_op = FixedPositionalEncoding(
377 | Positional_encoding_shape=self.Positional_encoding_shape,
378 | dim=self.dim
379 | )
380 | elif self.Positional_encoding=='RPE':
381 | self.Positional_encoding_op = RelativePositionalEncoding(
382 | Positional_encoding_shape=self.Positional_encoding_shape,
383 | dim=self.dim
384 | )
385 | elif self.Positional_encoding=='LPE':
386 | self.Positional_encoding_op = LearnablePositionalEncoding(
387 | Positional_encoding_shape=self.Positional_encoding_shape,
388 | dim=self.dim
389 | )
390 | elif self.Positional_encoding=='CPE':
391 | self.Positional_encoding_op = ConditionalPositionalEncoding(
392 | Positional_encoding_shape=self.Positional_encoding_shape,
393 | Positional_encoding_way=self.Positional_encoding_way,
394 | dim=self.dim
395 | )
396 | elif self.Positional_encoding is None:
397 | pass
398 | else:
399 | raise
400 |
401 | if self.Positional_encoding_way=='Transformer':
402 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim)
403 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
404 | pass
405 | else:
406 | raise
407 |
408 | if self.Dropout_on:
409 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio))
410 |
411 | if self.Classifier_on:
412 | self.linear1 = nn.Sequential(
413 | nn.Linear(in_features=self.dim, out_features=self.dim),
414 | nn.ReLU(),
415 | nn.Dropout(p=0.5),
416 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6)
417 | )
418 | self.linear2 = nn.Sequential(
419 | nn.Linear(in_features=self.dim, out_features=1),
420 | nn.Sigmoid()
421 | )
422 |
423 | for name,param in self.named_parameters():
424 | if name in ['linear1.0.weight','linear2.0.weight']:
425 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0))
426 | elif name in ['linear1.0.bias','linear2.0.bias']:
427 | nn.init.constant_(param, 0.1)
428 | else:
429 | self.gap = nn.AdaptiveAvgPool1d(1)
430 |
431 | if self.CLS_on:
432 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024))
433 |
434 | if self.key_value_emb is not None:
435 | if self.key_value_emb.lower()=='k':
436 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
437 | elif self.key_value_emb.lower()=='v':
438 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim)
439 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv':
440 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
441 | if self.model_name=='ResNet_Attention':
442 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim)
443 | else:
444 | raise
445 |
446 | if self.Layernorm:
447 | if self.Skip_connection=='KC':
448 | self.layernorm1 = nn.BatchNorm2d(num_features=1)
449 | elif self.Skip_connection=='CF':
450 | self.layernorm2 = nn.BatchNorm2d(num_features=1)
451 | elif self.Skip_connection=='IF':
452 | self.layernorm3 = nn.BatchNorm2d(num_features=1)
453 | elif self.Skip_connection is None:
454 | pass
455 | else:
456 | raise
457 |
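A minimal instantiation sketch. The flag values below are illustrative placeholders, not the defaults from config.py (which is outside this excerpt); positional encoding is switched off so the sketch does not depend on the positional_encoding module, and the tail of forward (the classifier head) is not shown in this listing.

```python
import torch

model = CSTA_ResNet(
    model_name='ResNet_Attention',
    Scale='D', Softmax_axis='TD', Balance='BD',
    Positional_encoding=None, Positional_encoding_shape=None, Positional_encoding_way=None,
    Dropout_on=False, Dropout_ratio=0.0,
    Classifier_on=True, CLS_on=True, CLS_mix='Final',
    key_value_emb='kv', Skip_connection=None, Layernorm=False,
).eval()

x = torch.randn(1, 3, 200, 1024)    # (batch, 3 stacked copies, n_frames, frame-feature dim)
with torch.no_grad():
    scores = model(x)               # presumably one importance score per frame from the sigmoid head
```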
458 | def forward(self, x):
459 | n_frame = x.shape[2]
460 |
461 | if self.Positional_encoding_way=='Transformer':
462 | x = self.Positional_encoding_embedding(x)
463 | if self.CLS_on:
464 | x = torch.cat((self.CLS,x),dim=2)
465 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim))
466 |
467 | if self.Positional_encoding_way=='Transformer':
468 | if self.Positional_encoding is not None:
469 | x = self.Positional_encoding_op(x)
470 | if self.Dropout_on:
471 | x = self.dropout(x)
472 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
473 | pass
474 | else:
475 | raise
476 |
477 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']:
478 | key = self.key_embedding(x)
479 | elif self.key_value_emb is None:
480 | key = x
481 | else:
482 | raise
483 |
484 | x_att = self.resnet(key)
485 |
486 | if self.Skip_connection is not None:
487 | if self.Skip_connection=='KC':
488 | x_att = x_att + key.squeeze(0)[0]
489 | if self.Layernorm:
490 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
491 | elif self.Skip_connection in ['CF','IF']:
492 | pass
493 | else:
494 | raise
495 | elif self.Skip_connection is None:
496 | pass
497 | else:
498 | raise
499 |
500 | if self.CLS_on:
501 | if self.CLS_mix=='CNN':
502 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0)
503 | x = CT_adjust(x.squeeze(0)).unsqueeze(0)
504 | elif self.CLS_mix in ['SM','Final']:
505 | pass
506 | else:
507 | raise
508 | else:
509 | pass
510 |
511 | if self.Scale is not None:
512 | if self.Scale=='D':
513 | scaling_factor = x_att.shape[1]
514 | elif self.Scale=='T':
515 | scaling_factor = x_att.shape[0]
516 | elif self.Scale=='T_D':
517 | scaling_factor = x_att.shape[0] * x_att.shape[1]
518 | else:
519 | raise
520 | scaling_factor = scaling_factor ** 0.5
521 | x_att = x_att / scaling_factor
522 | elif self.Scale is None:
523 | pass
524 |
525 | if self.Positional_encoding_way=='PGL_SUM':
526 | if self.Positional_encoding is not None:
527 | x_att = self.Positional_encoding_op(x_att)
528 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None:
529 | pass
530 | else:
531 | raise
532 |
533 | x = x.squeeze(0)[0]
534 | if self.Softmax_axis=='T':
535 | temporal_attention = F.softmax(x_att,dim=0)
536 | elif self.Softmax_axis=='D':
537 | spatial_attention = F.softmax(x_att,dim=1)
538 | elif self.Softmax_axis=='TD':
539 | temporal_attention = F.softmax(x_att,dim=0)
540 | spatial_attention = F.softmax(x_att,dim=1)
541 | elif self.Softmax_axis is None:
542 | pass
543 | else:
544 | raise
545 |
546 | if self.CLS_on:
547 | if self.CLS_mix=='SM':
548 | if self.Softmax_axis=='T':
549 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
550 | elif self.Softmax_axis=='D':
551 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
552 | elif self.Softmax_axis=='TD':
553 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
554 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
555 | elif self.Softmax_axis is None:
556 | pass
557 | else:
558 | raise
559 | elif self.CLS_mix in ['CNN','Final']:
560 | pass
561 | else:
562 | raise
563 | else:
564 | pass
565 |
566 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM':
567 | if self.Softmax_axis=='T':
568 | temporal_attention = self.dropout(temporal_attention)
569 | elif self.Softmax_axis=='D':
570 | spatial_attention = self.dropout(spatial_attention)
571 | elif self.Softmax_axis=='TD':
572 | temporal_attention = self.dropout(temporal_attention)
573 | spatial_attention = self.dropout(spatial_attention)
574 | elif self.Softmax_axis is None:
575 | pass
576 | else:
577 | raise
578 |
579 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']:
580 | if self.model_name=='ResNet_Attention':
581 | x_out = self.value_embedding(x)
582 | elif self.model_name=='ResNet':
583 | x_out = x_att
584 | else:
585 | raise
586 | elif self.key_value_emb is None:
587 | if self.model_name=='ResNet':
588 | x_out = x_att
589 | elif self.model_name=='ResNet_Attention':
590 | x_out = x
591 | else:
592 | raise
593 | else:
594 | raise
595 |
596 | if self.CLS_on:
597 | if self.CLS_mix=='SM':
598 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0)
599 |
600 | if self.Softmax_axis=='T':
601 | x_out = x_out * temporal_attention
602 | elif self.Softmax_axis=='D':
603 | x_out = x_out * spatial_attention
604 | elif self.Softmax_axis=='TD':
605 | T,D = x_out.shape
606 | adjust_frame = T/D
607 | adjust_dimension = D/T
608 | if self.Balance=='T':
609 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
610 | elif self.Balance=='D':
611 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
612 | elif self.Balance=='BD':
613 | if T>D:
614 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
615 | elif T<D:
621 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
622 | elif T<D:
--------------------------------------------------------------------------------
/models/EfficientNet.py:
--------------------------------------------------------------------------------
30 | def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
31 | return _make_divisible(channels * width_mult, 8, min_value)
32 |
33 | class MBConvConfig(_MBConvConfig):
34 | def __init__(
35 | self,
36 | expand_ratio: float,
37 | kernel: int,
38 | stride: int,
39 | input_channels: int,
40 | out_channels: int,
41 | num_layers: int,
42 | width_mult: float = 1.0,
43 | depth_mult: float = 1.0,
44 | block: Optional[Callable[..., nn.Module]] = None,
45 | ) -> None:
46 | input_channels = self.adjust_channels(input_channels, width_mult)
47 | out_channels = self.adjust_channels(out_channels, width_mult)
48 | num_layers = self.adjust_depth(num_layers, depth_mult)
49 | if block is None:
50 | block = MBConv
51 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
52 |
53 | @staticmethod
54 | def adjust_depth(num_layers: int, depth_mult: float):
55 | return int(math.ceil(num_layers * depth_mult))
56 |
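A quick numeric check of the compound-scaling helpers above (the multiplier values are arbitrary examples):

```python
print(MBConvConfig.adjust_channels(32, width_mult=1.2))  # 38.4 rounded to 40, a multiple of 8, by _make_divisible
print(MBConvConfig.adjust_depth(3, depth_mult=1.8))      # ceil(5.4) = 6 repeats of the block
```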
57 | class MBConv(nn.Module):
58 | def __init__(
59 | self,
60 | cnf: MBConvConfig,
61 | stochastic_depth_prob: float,
62 | norm_layer: Callable[..., nn.Module],
63 | se_layer: Callable[..., nn.Module] = SqueezeExcitation,
64 | ) -> None:
65 | super().__init__()
66 |
67 | if not (1 <= cnf.stride <= 2):
68 | raise ValueError("illegal stride value")
69 |
70 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
71 |
72 | layers: List[nn.Module] = []
73 | activation_layer = nn.SiLU
74 |
75 | expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
76 | if expanded_channels != cnf.input_channels:
77 | layers.append(
78 | Conv2dNormActivation(
79 | cnf.input_channels,
80 | expanded_channels,
81 | kernel_size=1,
82 | norm_layer=norm_layer,
83 | activation_layer=activation_layer,
84 | )
85 | )
86 |
87 | layers.append(
88 | Conv2dNormActivation(
89 | expanded_channels,
90 | expanded_channels,
91 | kernel_size=cnf.kernel,
92 | stride=cnf.stride,
93 | groups=expanded_channels,
94 | norm_layer=norm_layer,
95 | activation_layer=activation_layer,
96 | )
97 | )
98 |
99 | squeeze_channels = max(1, cnf.input_channels // 4)
100 | layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
101 |
102 | layers.append(
103 | Conv2dNormActivation(
104 | expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
105 | )
106 | )
107 |
108 | self.block = nn.Sequential(*layers)
109 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
110 | self.out_channels = cnf.out_channels
111 |
112 | def forward(self, input: Tensor) -> Tensor:
113 | result = self.block(input)
114 | if self.use_res_connect:
115 | result = self.stochastic_depth(result)
116 | result += input
117 | return result
118 |
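A hypothetical standalone MBConv block, roughly matching the second block of the b0 stage that keeps 24 channels (see _efficientnet_conf below); stride 1 with equal in/out channels activates the stochastic-depth residual path:

```python
import torch
import torch.nn as nn

cnf = MBConvConfig(expand_ratio=6, kernel=3, stride=1, input_channels=24, out_channels=24,
                   num_layers=1, width_mult=1.0, depth_mult=1.0)
block = MBConv(cnf, stochastic_depth_prob=0.1, norm_layer=nn.BatchNorm2d).eval()
print(block(torch.randn(2, 24, 56, 56)).shape)   # torch.Size([2, 24, 56, 56])
```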
119 | class MBConvConfig(_MBConvConfig):
120 | def __init__(
121 | self,
122 | expand_ratio: float,
123 | kernel: int,
124 | stride: int,
125 | input_channels: int,
126 | out_channels: int,
127 | num_layers: int,
128 | width_mult: float = 1.0,
129 | depth_mult: float = 1.0,
130 | block: Optional[Callable[..., nn.Module]] = None,
131 | ) -> None:
132 | input_channels = self.adjust_channels(input_channels, width_mult)
133 | out_channels = self.adjust_channels(out_channels, width_mult)
134 | num_layers = self.adjust_depth(num_layers, depth_mult)
135 | if block is None:
136 | block = MBConv
137 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
138 |
139 | @staticmethod
140 | def adjust_depth(num_layers: int, depth_mult: float):
141 | return int(math.ceil(num_layers * depth_mult))
142 |
143 | class FusedMBConvConfig(_MBConvConfig):
144 | def __init__(
145 | self,
146 | expand_ratio: float,
147 | kernel: int,
148 | stride: int,
149 | input_channels: int,
150 | out_channels: int,
151 | num_layers: int,
152 | block: Optional[Callable[..., nn.Module]] = None,
153 | ) -> None:
154 | if block is None:
155 | block = FusedMBConv
156 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
157 |
158 | class FusedMBConv(nn.Module):
159 | def __init__(
160 | self,
161 | cnf: FusedMBConvConfig,
162 | stochastic_depth_prob: float,
163 | norm_layer: Callable[..., nn.Module],
164 | ) -> None:
165 | super().__init__()
166 |
167 | if not (1 <= cnf.stride <= 2):
168 | raise ValueError("illegal stride value")
169 |
170 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
171 |
172 | layers: List[nn.Module] = []
173 | activation_layer = nn.SiLU
174 |
175 | expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
176 | if expanded_channels != cnf.input_channels:
177 | layers.append(
178 | Conv2dNormActivation(
179 | cnf.input_channels,
180 | expanded_channels,
181 | kernel_size=cnf.kernel,
182 | stride=cnf.stride,
183 | norm_layer=norm_layer,
184 | activation_layer=activation_layer,
185 | )
186 | )
187 |
188 | layers.append(
189 | Conv2dNormActivation(
190 | expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
191 | )
192 | )
193 | else:
194 | layers.append(
195 | Conv2dNormActivation(
196 | cnf.input_channels,
197 | cnf.out_channels,
198 | kernel_size=cnf.kernel,
199 | stride=cnf.stride,
200 | norm_layer=norm_layer,
201 | activation_layer=activation_layer,
202 | )
203 | )
204 |
205 | self.block = nn.Sequential(*layers)
206 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
207 | self.out_channels = cnf.out_channels
208 |
209 | def forward(self, input: Tensor) -> Tensor:
210 | result = self.block(input)
211 | if self.use_res_connect:
212 | result = self.stochastic_depth(result)
213 | result += input
214 | return result
215 |
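The fused counterpart for comparison: with expand_ratio=1 there is no 1x1 expansion, so the block collapses to a single 3x3 conv + BN + SiLU plus the residual:

```python
import torch
import torch.nn as nn

fused_cnf = FusedMBConvConfig(expand_ratio=1, kernel=3, stride=1,
                              input_channels=24, out_channels=24, num_layers=1)
fused = FusedMBConv(fused_cnf, stochastic_depth_prob=0.0, norm_layer=nn.BatchNorm2d).eval()
print(fused(torch.randn(2, 24, 32, 32)).shape)   # torch.Size([2, 24, 32, 32])
```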
216 | class FusedMBConvConfig(_MBConvConfig):
217 | def __init__(
218 | self,
219 | expand_ratio: float,
220 | kernel: int,
221 | stride: int,
222 | input_channels: int,
223 | out_channels: int,
224 | num_layers: int,
225 | block: Optional[Callable[..., nn.Module]] = None,
226 | ) -> None:
227 | if block is None:
228 | block = FusedMBConv
229 | super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
230 |
231 | def _efficientnet_conf(
232 | arch: str,
233 | **kwargs: Any,
234 | ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
235 | inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
236 | if arch.startswith("efficientnet_b"):
237 | bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
238 | inverted_residual_setting = [
239 | bneck_conf(1, 3, 1, 32, 16, 1),
240 | bneck_conf(6, 3, 2, 16, 24, 2),
241 | bneck_conf(6, 5, 2, 24, 40, 2),
242 | bneck_conf(6, 3, 2, 40, 80, 3),
243 | bneck_conf(6, 5, 1, 80, 112, 3),
244 | bneck_conf(6, 5, 2, 112, 192, 4),
245 | bneck_conf(6, 3, 1, 192, 320, 1),
246 | ]
247 | last_channel = None
248 | elif arch.startswith("efficientnet_v2_s"):
249 | inverted_residual_setting = [
250 | FusedMBConvConfig(1, 3, 1, 24, 24, 2),
251 | FusedMBConvConfig(4, 3, 2, 24, 48, 4),
252 | FusedMBConvConfig(4, 3, 2, 48, 64, 4),
253 | MBConvConfig(4, 3, 2, 64, 128, 6),
254 | MBConvConfig(6, 3, 1, 128, 160, 9),
255 | MBConvConfig(6, 3, 2, 160, 256, 15),
256 | ]
257 | last_channel = 1280
258 | elif arch.startswith("efficientnet_v2_m"):
259 | inverted_residual_setting = [
260 | FusedMBConvConfig(1, 3, 1, 24, 24, 3),
261 | FusedMBConvConfig(4, 3, 2, 24, 48, 5),
262 | FusedMBConvConfig(4, 3, 2, 48, 80, 5),
263 | MBConvConfig(4, 3, 2, 80, 160, 7),
264 | MBConvConfig(6, 3, 1, 160, 176, 14),
265 | MBConvConfig(6, 3, 2, 176, 304, 18),
266 | MBConvConfig(6, 3, 1, 304, 512, 5),
267 | ]
268 | last_channel = 1280
269 | elif arch.startswith("efficientnet_v2_l"):
270 | inverted_residual_setting = [
271 | FusedMBConvConfig(1, 3, 1, 32, 32, 4),
272 | FusedMBConvConfig(4, 3, 2, 32, 64, 7),
273 | FusedMBConvConfig(4, 3, 2, 64, 96, 7),
274 | MBConvConfig(4, 3, 2, 96, 192, 10),
275 | MBConvConfig(6, 3, 1, 192, 224, 19),
276 | MBConvConfig(6, 3, 2, 224, 384, 25),
277 | MBConvConfig(6, 3, 1, 384, 640, 7),
278 | ]
279 | last_channel = 1280
280 | else:
281 | raise ValueError(f"Unsupported model type {arch}")
282 |
283 | return inverted_residual_setting, last_channel
284 |
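For reference, the b0 configuration that EfficientNet_Att below uses by default:

```python
settings, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
print(len(settings))                                        # 7 stages
print(settings[0].out_channels, settings[-1].out_channels)  # 16 320
print(last_channel)                                         # None -> head width falls back to 4 * 320 = 1280
```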
285 | # EfficientNet as attention
286 | class EfficientNet_Att(nn.Module):
287 | def __init__(
288 | self,
289 | inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)[0],
290 | dropout: float = 0.2,
291 | stochastic_depth_prob: float = 0.2,
292 | num_classes: int = 1000,
293 | norm_layer: Optional[Callable[..., nn.Module]] = None,
294 | last_channel: Optional[int] = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)[1],
295 | ) -> None:
296 | """
297 | EfficientNet V1 and V2 main class
298 |
299 | Args:
300 | inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
301 | dropout (float): The dropout probability
302 | stochastic_depth_prob (float): The stochastic depth probability
303 | num_classes (int): Number of classes
304 | norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
305 | last_channel (int): The number of channels on the penultimate layer
306 | """
307 | super().__init__()
308 | _log_api_usage_once(self)
309 |
310 | if not inverted_residual_setting:
311 | raise ValueError("The inverted_residual_setting should not be empty")
312 | elif not (
313 | isinstance(inverted_residual_setting, Sequence)
314 | and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
315 | ):
316 | raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
317 |
318 | if norm_layer is None:
319 | norm_layer = nn.BatchNorm2d
320 |
321 | layers: List[nn.Module] = []
322 |
323 | firstconv_output_channels = inverted_residual_setting[0].input_channels
324 | layers.append(
325 | Conv2dNormActivation(
326 | 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
327 | )
328 | )
329 |
330 | total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
331 | stage_block_id = 0
332 | for cnf in inverted_residual_setting:
333 | stage: List[nn.Module] = []
334 | for _ in range(cnf.num_layers):
335 |
336 | block_cnf = copy.copy(cnf)
337 |
338 | if stage:
339 | block_cnf.input_channels = block_cnf.out_channels
340 | block_cnf.stride = 1
341 |
342 | sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
343 |
344 | stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
345 | stage_block_id += 1
346 |
347 | layers.append(nn.Sequential(*stage))
348 |
349 | lastconv_input_channels = inverted_residual_setting[-1].out_channels
350 | lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
351 | layers.append(
352 | Conv2dNormActivation(
353 | lastconv_input_channels,
354 | lastconv_output_channels,
355 | kernel_size=1,
356 | norm_layer=norm_layer,
357 | activation_layer=nn.SiLU,
358 | )
359 | )
360 |
361 | self.features = nn.Sequential(*layers)
362 | self.avgpool = nn.AdaptiveAvgPool2d(1)
363 | self.classifier = nn.Sequential(
364 | nn.Dropout(p=dropout, inplace=True),
365 | nn.Linear(lastconv_output_channels, num_classes),
366 | )
367 |
368 | for m in self.modules():
369 | if isinstance(m, nn.Conv2d):
370 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
371 | if m.bias is not None:
372 | nn.init.zeros_(m.bias)
373 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
374 | nn.init.ones_(m.weight)
375 | nn.init.zeros_(m.bias)
376 | elif isinstance(m, nn.Linear):
377 | init_range = 1.0 / math.sqrt(m.out_features)
378 | nn.init.uniform_(m.weight, -init_range, init_range)
379 | nn.init.zeros_(m.bias)
380 |
381 | def _forward_impl(self, x: Tensor, n_frame: int) -> Tensor:
382 | x = self.features(x)
383 |
384 | self.avgpool = nn.AdaptiveAvgPool2d((n_frame,1))
385 | x = self.avgpool(x)
386 | x = torch.squeeze(x)
387 | x = x.permute(1,0)
388 |
389 | return x
390 |
391 | def forward(self, x: Tensor) -> Tensor:
392 | n_frame = x.shape[2]
393 | return self._forward_impl(x,n_frame)
394 |
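A shape sketch of the attention pathway with random weights; the CSTA wrapper below feeds it a (1, 3, T, dim) "feature image" built from per-frame features:

```python
import torch

att = EfficientNet_Att().eval()
key = torch.randn(1, 3, 200, 1280)   # dim=1280 matches the key/value embedding size used below
with torch.no_grad():
    x_att = att(key)
print(x_att.shape)                   # torch.Size([200, 1280]): one attention vector per frame
```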
395 | # EfficientNet-based CSTA
396 | class CSTA_EfficientNet(nn.Module):
397 | def __init__(self,
398 | model_name,
399 | Scale,
400 | Softmax_axis,
401 | Balance,
402 | Positional_encoding,
403 | Positional_encoding_shape,
404 | Positional_encoding_way,
405 | Dropout_on,
406 | Dropout_ratio,
407 | Classifier_on,
408 | CLS_on,
409 | CLS_mix,
410 | key_value_emb,
411 | Skip_connection,
412 | Layernorm,
413 | dim=1280):
414 | super().__init__()
415 | self.efficientnet = EfficientNet_Att()
416 |
417 | self.model_name = model_name
418 | self.Scale = Scale
419 | self.Softmax_axis = Softmax_axis
420 | self.Balance = Balance
421 |
422 | self.Positional_encoding = Positional_encoding
423 | self.Positional_encoding_shape = Positional_encoding_shape
424 | self.Positional_encoding_way = Positional_encoding_way
425 | self.Dropout_on = Dropout_on
426 | self.Dropout_ratio = Dropout_ratio
427 |
428 | self.Classifier_on = Classifier_on
429 | self.CLS_on = CLS_on
430 | self.CLS_mix = CLS_mix
431 |
432 | self.key_value_emb = key_value_emb
433 | self.Skip_connection = Skip_connection
434 | self.Layernorm = Layernorm
435 |
436 | self.dim = dim
437 |
438 | if self.Positional_encoding is not None:
439 | if self.Positional_encoding=='FPE':
440 | self.Positional_encoding_op = FixedPositionalEncoding(
441 | Positional_encoding_shape=self.Positional_encoding_shape,
442 | dim=self.dim
443 | )
444 | elif self.Positional_encoding=='RPE':
445 | self.Positional_encoding_op = RelativePositionalEncoding(
446 | Positional_encoding_shape=self.Positional_encoding_shape,
447 | dim=self.dim
448 | )
449 | elif self.Positional_encoding=='LPE':
450 | self.Positional_encoding_op = LearnablePositionalEncoding(
451 | Positional_encoding_shape=self.Positional_encoding_shape,
452 | dim=self.dim
453 | )
454 | elif self.Positional_encoding=='CPE':
455 | self.Positional_encoding_op = ConditionalPositionalEncoding(
456 | Positional_encoding_shape=self.Positional_encoding_shape,
457 | Positional_encoding_way=self.Positional_encoding_way,
458 | dim=self.dim
459 | )
460 | elif self.Positional_encoding is None:
461 | pass
462 | else:
463 | raise
464 |
465 | if self.Positional_encoding_way=='Transformer':
466 | self.Positional_encoding_embedding = nn.Linear(in_features=self.dim, out_features=self.dim)
467 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
468 | pass
469 | else:
470 | raise
471 |
472 | if self.Dropout_on:
473 | self.dropout = nn.Dropout(p=float(self.Dropout_ratio))
474 |
475 | if self.Classifier_on:
476 | self.linear1 = nn.Sequential(
477 | nn.Linear(in_features=self.dim, out_features=self.dim),
478 | nn.ReLU(),
479 | nn.Dropout(p=0.5),
480 | nn.LayerNorm(normalized_shape=self.dim, eps=1e-6)
481 | )
482 | self.linear2 = nn.Sequential(
483 | nn.Linear(in_features=self.dim, out_features=1),
484 | nn.Sigmoid()
485 | )
486 |
487 | for name,param in self.named_parameters():
488 | if name in ['linear1.0.weight','linear2.0.weight']:
489 | nn.init.xavier_uniform_(param, gain=np.sqrt(2.0))
490 | elif name in ['linear1.0.bias','linear2.0.bias']:
491 | nn.init.constant_(param, 0.1)
492 | else:
493 | self.gap = nn.AdaptiveAvgPool1d(1)
494 |
495 | if self.CLS_on:
496 | self.CLS = nn.Parameter(torch.zeros(1,3,1,1024))
497 |
498 | if self.key_value_emb is not None:
499 | if self.key_value_emb.lower()=='k':
500 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
501 | elif self.key_value_emb.lower()=='v':
502 | self.value_embedding = nn.Linear(in_features=self.dim,out_features=self.dim)
503 | elif ''.join(sorted(self.key_value_emb.lower()))=='kv':
504 | self.key_embedding = nn.Linear(in_features=1024,out_features=self.dim)
505 | if self.model_name=='EfficientNet_Attention':
506 | self.value_embedding = nn.Linear(in_features=1024,out_features=self.dim)
507 | else:
508 | raise
509 |
510 | if self.Layernorm:
511 | if self.Skip_connection=='KC':
512 | self.layernorm1 = nn.BatchNorm2d(num_features=1)
513 | elif self.Skip_connection=='CF':
514 | self.layernorm2 = nn.BatchNorm2d(num_features=1)
515 | elif self.Skip_connection=='IF':
516 | self.layernorm3 = nn.BatchNorm2d(num_features=1)
517 | elif self.Skip_connection is None:
518 | pass
519 | else:
520 | raise
521 |
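How the CLS token interacts with CT_adjust in the forward pass that follows: the token is concatenated along the frame axis, and the adaptive pool later averages the T+1 rows back down to T while also mapping the last dimension to self.dim. A small sketch with assumed shapes:

```python
import torch
import torch.nn as nn

T, dim = 4, 1280
x = torch.randn(1, 3, T, 1024)                  # frame features
cls = torch.zeros(1, 3, 1, 1024)                # same shape as self.CLS
x = torch.cat((cls, x), dim=2)                  # (1, 3, T + 1, 1024)
ct_adjust = nn.AdaptiveAvgPool2d((T, dim))      # CT_adjust in forward
print(ct_adjust(x.squeeze(0)).unsqueeze(0).shape)   # torch.Size([1, 3, 4, 1280])
```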
522 | def forward(self, x):
523 | n_frame = x.shape[2]
524 |
525 | if self.Positional_encoding_way=='Transformer':
526 | x = self.Positional_encoding_embedding(x)
527 | if self.CLS_on:
528 | x = torch.cat((self.CLS,x),dim=2)
529 | CT_adjust = nn.AdaptiveAvgPool2d((n_frame,self.dim))
530 |
531 | if self.Positional_encoding_way=='Transformer':
532 | if self.Positional_encoding is not None:
533 | x = self.Positional_encoding_op(x)
534 | if self.Dropout_on:
535 | x = self.dropout(x)
536 | elif self.Positional_encoding_way=='PGL_SUM' or self.Positional_encoding_way is None:
537 | pass
538 | else:
539 | raise
540 |
541 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['k','kv']:
542 | key = self.key_embedding(x)
543 | elif self.key_value_emb is None:
544 | key = x
545 | else:
546 | raise
547 |
548 | x_att = self.efficientnet(key)
549 |
550 | if self.Skip_connection is not None:
551 | if self.Skip_connection=='KC':
552 | x_att = x_att + key.squeeze(0)[0]
553 | if self.Layernorm:
554 | x_att = self.layernorm1(x_att.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
555 | elif self.Skip_connection in ['CF','IF']:
556 | pass
557 | else:
558 | raise
559 | elif self.Skip_connection is None:
560 | pass
561 | else:
562 | raise
563 |
564 | if self.CLS_on:
565 | if self.CLS_mix=='CNN':
566 | x_att = CT_adjust(x_att.unsqueeze(0)).squeeze(0)
567 | x = CT_adjust(x.squeeze(0)).unsqueeze(0)
568 | elif self.CLS_mix in ['SM','Final']:
569 | pass
570 | else:
571 | raise
572 | else:
573 | pass
574 |
575 | if self.Scale is not None:
576 | if self.Scale=='D':
577 | scaling_factor = x_att.shape[1]
578 | elif self.Scale=='T':
579 | scaling_factor = x_att.shape[0]
580 | elif self.Scale=='T_D':
581 | scaling_factor = x_att.shape[0] * x_att.shape[1]
582 | else:
583 | raise
584 | scaling_factor = scaling_factor ** 0.5
585 | x_att = x_att / scaling_factor
586 | elif self.Scale is None:
587 | pass
588 |
589 | if self.Positional_encoding_way=='PGL_SUM':
590 | if self.Positional_encoding is not None:
591 | x_att = self.Positional_encoding_op(x_att)
592 | elif self.Positional_encoding_way=='Transformer' or self.Positional_encoding_way is None:
593 | pass
594 | else:
595 | raise
596 |
597 | x = x.squeeze(0)[0]
598 | if self.Softmax_axis=='T':
599 | temporal_attention = F.softmax(x_att,dim=0)
600 | elif self.Softmax_axis=='D':
601 | spatial_attention = F.softmax(x_att,dim=1)
602 | elif self.Softmax_axis=='TD':
603 | temporal_attention = F.softmax(x_att,dim=0)
604 | spatial_attention = F.softmax(x_att,dim=1)
605 | elif self.Softmax_axis is None:
606 | pass
607 | else:
608 | raise
609 |
610 | if self.CLS_on:
611 | if self.CLS_mix=='SM':
612 | if self.Softmax_axis=='T':
613 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
614 | elif self.Softmax_axis=='D':
615 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
616 | elif self.Softmax_axis=='TD':
617 | temporal_attention = CT_adjust(temporal_attention.unsqueeze(0)).squeeze(0)
618 | spatial_attention = CT_adjust(spatial_attention.unsqueeze(0)).squeeze(0)
619 | elif self.Softmax_axis is None:
620 | pass
621 | else:
622 | raise
623 | elif self.CLS_mix in ['CNN','Final']:
624 | pass
625 | else:
626 | raise
627 | else:
628 | pass
629 |
630 | if self.Dropout_on and self.Positional_encoding_way=='PGL_SUM':
631 | if self.Softmax_axis=='T':
632 | temporal_attention = self.dropout(temporal_attention)
633 | elif self.Softmax_axis=='D':
634 | spatial_attention = self.dropout(spatial_attention)
635 | elif self.Softmax_axis=='TD':
636 | temporal_attention = self.dropout(temporal_attention)
637 | spatial_attention = self.dropout(spatial_attention)
638 | elif self.Softmax_axis is None:
639 | pass
640 | else:
641 | raise
642 |
643 | if self.key_value_emb is not None and self.key_value_emb.lower() in ['v','kv']:
644 | if self.model_name=='EfficientNet_Attention':
645 | x_out = self.value_embedding(x)
646 | elif self.model_name=='EfficientNet':
647 | x_out = x_att
648 | else:
649 | raise
650 | elif self.key_value_emb is None:
651 | if self.model_name=='EfficientNet':
652 | x_out = x_att
653 | elif self.model_name=='EfficientNet_Attention':
654 | x_out = x
655 | else:
656 | raise
657 | else:
658 | raise
659 |
660 | if self.CLS_on:
661 | if self.CLS_mix=='SM':
662 | x_out = CT_adjust(x_out.unsqueeze(0)).squeeze(0)
663 |
664 | if self.Softmax_axis=='T':
665 | x_out = x_out * temporal_attention
666 | elif self.Softmax_axis=='D':
667 | x_out = x_out * spatial_attention
668 | elif self.Softmax_axis=='TD':
669 | T,D = x_out.shape
670 | adjust_frame = T/D
671 | adjust_dimension = D/T
672 | if self.Balance=='T':
673 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
674 | elif self.Balance=='D':
675 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
676 | elif self.Balance=='BD':
677 | if T>D:
678 | x_out = x_out * temporal_attention + x_out * spatial_attention * adjust_dimension
679 | elif T<D:
685 | x_out = x_out * temporal_attention * adjust_frame + x_out * spatial_attention
686 | elif T<D: