├── figures
│   ├── fig_teaser.jpg
│   ├── fig_dataset_examples.jpg
│   ├── fig_dataset_statistics.jpg
│   └── fig_related_datasets.jpg
├── src
│   ├── data
│   │   ├── PAVS10K_seqs_test.txt
│   │   └── PAVS10K_seqs_train.txt
│   ├── options.py
│   ├── models
│   │   ├── convgru.py
│   │   ├── non_local_dot_product.py
│   │   ├── rcrnet_vit.py
│   │   ├── resnet_dilation.py
│   │   ├── blocks.py
│   │   ├── CAVNet.py
│   │   └── VIT.py
│   ├── CAVNet_test.py
│   ├── CAVNet_train.py
│   ├── utils.py
│   └── data.py
├── video_to_frames.py
├── video_seq_link
└── README.md
/figures/fig_teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_teaser.jpg
--------------------------------------------------------------------------------
/figures/fig_dataset_examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_dataset_examples.jpg
--------------------------------------------------------------------------------
/figures/fig_dataset_statistics.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_dataset_statistics.jpg
--------------------------------------------------------------------------------
/figures/fig_related_datasets.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_related_datasets.jpg
--------------------------------------------------------------------------------
/src/data/PAVS10K_seqs_test.txt:
--------------------------------------------------------------------------------
1 | _-6QUCaLvQ_3I
2 | _Ellen2
3 | _-Ngj6C_RMK1g_1
4 | _-J0Q4L68o3xE_1
5 | _-I-43DzvhxX8_1
6 | _-J0Q4L68o3xE_2
7 | _-SdGCX2H-_Uk
8 | _-RbgxpagCY_c
9 | _-Oak26yVbibQ
10 | _conversation
11 | _-V3zp7XOGBhs
12 | _-MYmMZxmSc1U
13 | _-Uy5LTocHmoA
14 | _-Bvu9m__ZX60
15 | _-g4fQ5iOVzsI
16 | _-JBrp8aG4lro_1
17 | _-72f3ayGhMEA_6
18 | _-1An41lDIJ6Q
19 | _-ey9J7w98wlI_2
20 | _-ByBF08H-wDA
21 | _conversation2
22 | _-HNQMF7e6IL0
23 | _-72f3ayGhMEA_4
24 | _-o7JEBWV4CmY
25 | _-5h95uTtPeck
26 | _-kZB3KMhqqyI
27 | _-RrGhiInqXhc
28 |
--------------------------------------------------------------------------------
/src/data/PAVS10K_seqs_train.txt:
--------------------------------------------------------------------------------
1 | _-RSYbTSTz91g
2 | _-IRG9Z7Y2uS4
3 | _-ey9J7w98wlI
4 | _-MzcdEI-tSUc_4
5 | _-Ngj6C_RMK1g_2
6 | _-72f3ayGhMEA_1
7 | _-P4KyjvsceZQ
8 | _-4fxKBGthpaw
9 | _-0suxwissusc
10 | _-nZJGt3ZVg3g
11 | _-bO43msZTfwA
12 | _Ellen
13 | _-SJIbpqgYWGw
14 | _-0cfJOmUaNNI_1
15 | _-1LM84FSzW0g_3
16 | _band
17 | _-0cfJOmUaNNI_2
18 | _-SZYXQ-6bfiQ
19 | _-jb5YxiXIsjU
20 | _-72f3ayGhMEA_3
21 | _-ZuXCMpVR24I
22 | _-MFVmxoXgeNQ
23 | _-72f3ayGhMEA_2
24 | _-eqmjLZGZ36k
25 | _-gTB1nfK-0Ac
26 | _-Gq4Y4gL3zSg
27 | _-LThm7VYvwxY
28 | _-n524y8uPUaU
29 | _-TCUsegqBZ_M
30 | _-eGGFGota5_A
31 | _-UpFZ2YaKeqM
32 | _-idLVnagjl_s
33 | _-1LM84FSzW0g_1
34 | _-nDu57CGqbLM
35 | _-1LM84FSzW0g_2
36 | _-gy4TI-6j5po
37 | _-4SilhsTuDU0
38 | _-MDwhMMSFkJ8
39 | _-G8pABGosD38
40 | _-dBM3eM9HOoA
41 |
--------------------------------------------------------------------------------
/video_to_frames.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 |
4 |
5 | def Vid2Frm():
6 | seqs = os.listdir(os.path.join(os.getcwd(), 'Videos')) # please take care of your own path
7 | for seq in seqs:
8 | seq_path = os.path.join(os.getcwd(), 'Videos', seq)
9 | save_path = os.path.join(os.getcwd(), 'Frames', seq[:-4])
10 | if not os.path.exists(save_path): os.makedirs(save_path)
11 | cap = cv2.VideoCapture(seq_path)
12 |         frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total number of frames in the sequence
13 |         countF = 0
14 |         for i in range(frames_num):
15 |             ret, frame = cap.read()
16 |             if not ret: break  # stop early if decoding fails or the stream ends
17 |             cv2.imwrite(os.path.join(save_path, format(str(countF), '0>5s') + '.png'), frame)
18 |             countF += 1
19 |         cap.release()
20 |         print(" {} frames are extracted from {}.".format(countF, seq))
21 | if __name__ == '__main__':
22 | Vid2Frm()
--------------------------------------------------------------------------------
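A hypothetical usage sketch for the script above, assuming the downloaded PAVS10K sequences have been placed in a `Videos/` folder next to the script (the folder names mirror the code; nothing here is created by the dataset download itself):

```python
# Illustrative only: Vid2Frm() reads every file in ./Videos and writes
# ./Frames/<sequence-name>/00000.png, 00001.png, ... for each of them.
import os
from video_to_frames import Vid2Frm

os.makedirs('Videos', exist_ok=True)  # put the downloaded video sequences in here first
Vid2Frm()
```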
/src/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--epoch', type=int, default=21, help='epoch number')
5 | parser.add_argument('--lr', type=float, default=2.5e-6, help='learning rate')
6 | parser.add_argument('--batchsize', type=int, default=1, help='training batch size') # set as 1
7 | parser.add_argument('--trainsize', type=int, default=416, help='training image resolution')
8 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
9 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints')
10 | parser.add_argument('--lat_weight', type=int, default=10, help='weighting latent loss')
11 | parser.add_argument('--gpu_id', type=str, default='0', help='GPU id to use for training')
12 | parser.add_argument('--tr_root', type=str, default=os.getcwd() + '/data/PAVS10K_seqs_train.txt', help='path to the training sequence list')
13 | parser.add_argument('--te_root', type=str, default=os.getcwd() + '/data/PAVS10K_seqs_test.txt', help='path to the test sequence list')
14 | parser.add_argument('--save_path', type=str, default=os.getcwd() + '/CAVNet_cpts/', help='the path to save models and logs')
15 | opt = parser.parse_args()
16 |
--------------------------------------------------------------------------------
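A minimal sketch of how these options are consumed; the command line below is illustrative, and the flag values shown are the defaults rather than recommended settings:

```python
# Hypothetical launch, e.g.:  python CAVNet_train.py --gpu_id 0 --lr 2.5e-6 --epoch 21
from options import opt   # parses sys.argv on import

print(opt.lr)             # 2.5e-06 unless overridden with --lr
print(opt.save_path)      # <cwd>/CAVNet_cpts/ by default
```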
/src/models/convgru.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | import torch
8 | from torch import nn
9 |
10 | class ConvGRUCell(nn.Module):
11 | """
12 | ICLR2016: Delving Deeper into Convolutional Networks for Learning Video Representations
13 | url: https://arxiv.org/abs/1511.06432
14 | """
15 | def __init__(self, input_channels, hidden_channels, kernel_size, cuda_flag=True):
16 | super(ConvGRUCell, self).__init__()
17 | self.input_channels = input_channels
18 | self.cuda_flag = cuda_flag
19 | self.hidden_channels = hidden_channels
20 | self.kernel_size = kernel_size
21 |
22 | padding = self.kernel_size // 2
23 |         self.reset_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=padding)
24 |         self.update_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=padding)
25 |         self.output_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=padding)
26 | # init
27 |         for m in self.modules():  # iterate modules (state_dict() only yields parameter names)
28 | if isinstance(m, nn.Conv2d):
29 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
30 | nn.init.constant_(m.bias, 0)
31 |
32 | def forward(self, x, hidden):
33 | if hidden is None:
34 | size_h = [x.data.size()[0], self.hidden_channels] + list(x.data.size()[2:])
35 | if self.cuda_flag:
36 | hidden = torch.zeros(size_h).cuda()
37 | else:
38 | hidden = torch.zeros(size_h)
39 |
40 | inputs = torch.cat((x, hidden), dim=1)
41 | reset_gate = torch.sigmoid(self.reset_gate(inputs))
42 | update_gate = torch.sigmoid(self.update_gate(inputs))
43 |
44 | reset_hidden = reset_gate * hidden
45 | reset_inputs = torch.tanh(self.output_gate(torch.cat((x, reset_hidden), dim=1)))
46 | new_hidden = (1 - update_gate)*reset_inputs + update_gate*hidden
47 |
48 | return new_hidden
--------------------------------------------------------------------------------
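A minimal, self-contained sketch of rolling `ConvGRUCell` over a short clip of feature maps; the channel count, spatial size and `cuda_flag=False` are assumptions for illustration (run from `src/` so the `models` package is importable):

```python
import torch
from models.convgru import ConvGRUCell

cell = ConvGRUCell(input_channels=256, hidden_channels=256, kernel_size=3, cuda_flag=False)
clip = [torch.randn(1, 256, 26, 26) for _ in range(3)]  # e.g. three consecutive frame features

hidden = None
for feat in clip:
    hidden = cell(feat, hidden)  # the hidden state is created lazily on the first step
print(hidden.shape)              # torch.Size([1, 256, 26, 26])
```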
/video_seq_link:
--------------------------------------------------------------------------------
1 | Class-Sequence link
2 | Speaking-French https://www.youtube.com/watch?v=IRG9Z7Y2uS4
3 | Speaking-WaitingRoom https://www.youtube.com/watch?v=ey9J7w98wlI
4 | Speaking-Cooking https://www.youtube.com/watch?v=MzcdEI-tSUc
5 | Speaking-AudiIntro https://www.youtube.com/watch?v=72f3ayGhMEA
6 | Speaking-Ellen https://www.youtube.com/watch?v=-SbZ4_ir148
7 | Speaking-GroveAction https://www.youtube.com/watch?v=0cfJOmUaNNI
8 | Speaking-Warehouse https://www.youtube.com/watch?v=1LM84FSzW0g
9 | Speaking-GroveConvo https://www.youtube.com/watch?v=0cfJOmUaNNI
10 | Speaking-Surfing https://www.youtube.com/watch?v=SZYXQ-6bfiQ
11 | Speaking-Passageway https://www.youtube.com/watch?v=jb5YxiXIsjU
12 | Speaking-RuralDriving https://www.youtube.com/watch?v=72f3ayGhMEA
13 | Speaking-Lawn https://www.youtube.com/watch?v=ZuXCMpVR24I
14 | Speaking-AudiAd https://www.youtube.com/watch?v=72f3ayGhMEA
15 | Speaking-ScenePlay https://www.youtube.com/watch?v=Gq4Y4gL3zSg
16 | Speaking-UrbanDriving https://www.youtube.com/watch?v=n524y8uPUaU
17 | Speaking-Interview https://www.youtube.com/watch?v=UpFZ2YaKeqM
18 | Speaking-Telephone https://www.youtube.com/watch?v=idLVnagjl_s
19 | Speaking-Walking https://www.youtube.com/watch?v=1LM84FSzW0g
20 | Speaking-Bridge https://www.youtube.com/watch?v=1LM84FSzW0g
21 | Speaking-Breakfast https://www.youtube.com/watch?v=4SilhsTuDU0
22 | Speaking-Debate https://www.youtube.com/watch?v=-SbZ4_ir148
23 | Speaking-BadmintonConvo https://www.youtube.com/watch?v=Ngj6C_RMK1g
24 | Speaking-Director https://www.youtube.com/watch?v=J0Q4L68o3xE
25 | Speaking-ChineseAd https://www.youtube.com/watch?v=J0Q4L68o3xE
26 | Speaking-Exhibition https://www.youtube.com/watch?v=RbgxpagCY_c
27 | Speaking-PianoConvo https://www.youtube.com/watch?v=Pxe920CL0GY
28 | Speaking-FilmingSite https://www.youtube.com/watch?v=MYmMZxmSc1U
29 | Speaking-Brothers https://www.youtube.com/watch?v=Uy5LTocHmoA
30 | Speaking-Rap https://www.youtube.com/watch?v=g4fQ5iOVzsI
31 | Speaking-Spanish https://www.youtube.com/watch?v=JBrp8aG4lro
32 | Speaking-Questions https://www.youtube.com/watch?v=ey9J7w98wlI
33 | Speaking-PianoMono https://www.youtube.com/watch?v=Pxe920CL0GY
34 | Speaking-Snowfield https://www.youtube.com/watch?v=o7JEBWV4CmY
35 | Speaking-Melodrama https://www.youtube.com/watch?v=5h95uTtPeck
36 | Speaking-Gymnasium https://www.youtube.com/watch?v=kZB3KMhqqyI
37 | Music-Guitar https://www.youtube.com/watch?v=4fxKBGthpaw
38 | Music-Subway https://www.youtube.com/watch?v=bO43msZTfwA
39 | Music-Jazz https://www.youtube.com/watch?v=SJIbpqgYWGw
40 | Music-Bass https://www.youtube.com/watch?v=baRj0O4cqgI
41 | Music-Canon https://www.youtube.com/watch?v=MFVmxoXgeNQ
42 | Music-MICOSinging https://www.youtube.com/watch?v=gTB1nfK-0Ac
43 | Music-Clarinet https://www.youtube.com/watch?v=TCUsegqBZ_M
44 | Music-Trumpet https://www.youtube.com/watch?v=eGGFGota5_A
45 | Music-PianoSaxophone https://www.youtube.com/watch?v=MDwhMMSFkJ8
46 | Music-Chorus https://www.youtube.com/watch?v=dBM3eM9HOoA
47 | Music-Studio https://www.youtube.com/watch?v=6QUCaLvQ_3I
48 | Music-Church https://www.youtube.com/watch?v=SdGCX2H-_Uk
49 | Music-Duet https://www.youtube.com/watch?v=Oak26yVbibQ
50 | Music-Blues https://www.youtube.com/watch?v=V3zp7XOGBhs
51 | Music-Violins https://www.youtube.com/watch?v=Bvu9m__ZX60
52 | Music-SingingDancing https://www.youtube.com/watch?v=1An41lDIJ6Q
53 | Miscellanea-Beach https://www.youtube.com/watch?v=RSYbTSTz91g
54 | Miscellanea-BadmintonGym https://www.youtube.com/watch?v=Ngj6C_RMK1g
55 | Miscellanea-InVehicle https://www.youtube.com/watch?v=P4KyjvsceZQ
56 | Miscellanea-Japanese https://www.youtube.com/watch?v=0suxwissusc
57 | Miscellanea-Tennis https://www.youtube.com/watch?v=nZJGt3ZVg3g
58 | Miscellanea-Diesel https://www.youtube.com/watch?v=eqmjLZGZ36k
59 | Miscellanea-Park https://www.youtube.com/watch?v=LThm7VYvwxY
60 | Miscellanea-Lion https://www.youtube.com/watch?v=nDu57CGqbLM
61 | Miscellanea-Carriage https://www.youtube.com/watch?v=gy4TI-6j5po
62 | Miscellanea-Platform https://www.youtube.com/watch?v=G8pABGosD38
63 | Miscellanea-Dog https://www.youtube.com/watch?v=I-43DzvhxX8
64 | Miscellanea-RacingCar https://www.youtube.com/watch?v=72f3ayGhMEA
65 | Miscellanea-Train https://www.youtube.com/watch?v=ByBF08H-wDA
66 | Miscellanea-Football https://www.youtube.com/watch?v=HNQMF7e6IL0
67 | Miscellanea-ParkingLot https://www.youtube.com/watch?v=72f3ayGhMEA
68 | Miscellanea-Skiing https://www.youtube.com/watch?v=RrGhiInqXhc
69 |
--------------------------------------------------------------------------------
/src/CAVNet_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os, argparse
4 | import cv2
5 | from data import dataset_inference, clip_length
6 |
7 | from models.CAVNet import cavnet
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--testsize', type=int, default=416, help='testing size')
11 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
12 | parser.add_argument('--test_path', type=str, default=os.getcwd() + '/data/', help='test dataset path')
13 | parser.add_argument('--forward_iter', type=int, default=5, help='sample times')
14 | opt = parser.parse_args()
15 |
16 | dataset_path = opt.test_path
17 |
18 | # set device for test
19 | if opt.gpu_id == '0':
20 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
21 | print('USE GPU 0')
22 | elif opt.gpu_id == '1':
23 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
24 | print('USE GPU 1')
25 |
26 | # load the model
27 | model = cavnet()
28 | model.load_state_dict(torch.load(os.getcwd() + '/model_test/CAVNet_final.pth'))
29 | model.cuda()
30 | model.eval()
31 |
32 | #test
33 | TIME = []
34 | save_path_pred = './predictions/'
35 | save_path_sigma = './uncertainties/'
36 | if not os.path.exists(save_path_pred): os.makedirs(save_path_pred)
37 | if not os.path.exists(save_path_sigma): os.makedirs(save_path_sigma)
38 | root = os.path.join(dataset_path, 'PAVS10K_seqs_test.txt')
39 | test_loader = dataset_inference(root, opt.testsize)
40 | for i in range(test_loader.size):
41 | imgs, gts, audios, seq_name, frm_names = test_loader.load_data()
42 | for idx in range(clip_length):
43 | gts[idx] = np.asarray(gts[idx], np.float32)
44 | gts[idx] /= (gts[idx].max() + 1e-8)
45 | imgs[idx] = imgs[idx].cuda()
46 | audios = audios.cuda()
47 |
48 | # multiple forward
49 | pred1, pred2, pred3, ent1, ent2, ent3 = [], [], [], [], [], []
50 | with torch.no_grad():
51 | for ff in range(opt.forward_iter):
52 | pred_curr = model(imgs, audios)
53 | p1, p2, p3 = torch.sigmoid(pred_curr[0]), torch.sigmoid(pred_curr[1]), torch.sigmoid(pred_curr[2])
54 |
55 | pred1.append(p1)
56 | pred2.append(p2)
57 | pred3.append(p3)
58 | ent1.append(-1 * p1 * torch.log(p1 + 1e-8))
59 | ent2.append(-1 * p2 * torch.log(p2 + 1e-8))
60 | ent3.append(-1 * p3 * torch.log(p3 + 1e-8))
61 |
62 | pred1_c, pred2_c, pred3_c = torch.cat(pred1, dim=1), torch.cat(pred2, dim=1), torch.cat(pred3, dim=1)
63 | pred1_mu, pred2_mu, pred3_mu = torch.mean(pred1_c, 1, keepdim=True), torch.mean(pred2_c, 1, keepdim=True), \
64 | torch.mean(pred3_c, 1, keepdim=True)
65 | ent1_c, ent2_c, ent3_c = torch.cat(ent1, dim=1), torch.cat(ent2, dim=1), torch.cat(ent3, dim=1)
66 | ent1_mu, ent2_mu, ent3_mu = torch.mean(ent1_c, 1, keepdim=True), torch.mean(ent2_c, 1, keepdim=True), \
67 | torch.mean(ent3_c, 1, keepdim=True)
68 |
69 | # before save results
70 | pred1_mu, pred2_mu, pred3_mu = pred1_mu.data.cpu().numpy().squeeze(), pred2_mu.data.cpu().numpy().squeeze(),\
71 | pred3_mu.data.cpu().numpy().squeeze()
72 | ent1_mu, ent2_mu, ent3_mu = ent1_mu.data.cpu().numpy().squeeze(), ent2_mu.data.cpu().numpy().squeeze(), \
73 | ent3_mu.data.cpu().numpy().squeeze()
74 |
75 | pred1_mu = (pred1_mu - pred1_mu.min()) / (pred1_mu.max() - pred1_mu.min() + 1e-8)
76 | pred2_mu = (pred2_mu - pred2_mu.min()) / (pred2_mu.max() - pred2_mu.min() + 1e-8)
77 | pred3_mu = (pred3_mu - pred3_mu.min()) / (pred3_mu.max() - pred3_mu.min() + 1e-8)
78 | ent1_mu = 255 * (ent1_mu - ent1_mu.min()) / (ent1_mu.max() - ent1_mu.min() + 1e-8)
79 | ent2_mu = 255 * (ent2_mu - ent2_mu.min()) / (ent2_mu.max() - ent2_mu.min() + 1e-8)
80 | ent3_mu = 255 * (ent3_mu - ent3_mu.min()) / (ent3_mu.max() - ent3_mu.min() + 1e-8)
81 |
82 | uc1, uc2, uc3 = ent1_mu.astype(np.uint8), ent2_mu.astype(np.uint8), ent3_mu.astype(np.uint8)
83 | uc1, uc2, uc3 = cv2.applyColorMap(uc1, cv2.COLORMAP_JET), cv2.applyColorMap(uc2, cv2.COLORMAP_JET), \
84 | cv2.applyColorMap(uc3, cv2.COLORMAP_JET)
85 |
86 | # save results
87 | curr_save_pth_pred = os.path.join(save_path_pred, seq_name)
88 | curr_save_pth_sigma = os.path.join(save_path_sigma, seq_name)
89 | if not os.path.exists(curr_save_pth_pred): os.makedirs(curr_save_pth_pred)
90 | if not os.path.exists(curr_save_pth_sigma): os.makedirs(curr_save_pth_sigma)
91 |
92 | print('save prediction to: ', os.path.join(curr_save_pth_pred, frm_names[0]))
93 | cv2.imwrite(os.path.join(curr_save_pth_pred, frm_names[0]), pred1_mu * 255)
94 | print('save prediction to: ', os.path.join(curr_save_pth_pred, frm_names[1]))
95 | cv2.imwrite(os.path.join(curr_save_pth_pred, frm_names[1]), pred2_mu * 255)
96 | print('save prediction to: ', os.path.join(curr_save_pth_pred, frm_names[2]))
97 | cv2.imwrite(os.path.join(curr_save_pth_pred, frm_names[2]), pred3_mu * 255)
98 |
99 | print('save uncertainty to: ', os.path.join(curr_save_pth_sigma, frm_names[0]))
100 | cv2.imwrite(os.path.join(curr_save_pth_sigma, frm_names[0]), uc1)
101 | print('save uncertainty to: ', os.path.join(curr_save_pth_sigma, frm_names[1]))
102 | cv2.imwrite(os.path.join(curr_save_pth_sigma, frm_names[1]), uc2)
103 | print('save uncertainty to: ', os.path.join(curr_save_pth_sigma, frm_names[2]))
104 | cv2.imwrite(os.path.join(curr_save_pth_sigma, frm_names[2]), uc3)
105 |
106 | print('Test Done!')
107 |
--------------------------------------------------------------------------------
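The loop above estimates uncertainty by averaging several stochastic forward passes and their per-pixel entropies. Below is a stripped-down sketch of that pattern with a dummy stochastic predictor standing in for CAV-Net; the dummy model and the 416x416 map size are assumptions:

```python
import torch

def mc_prediction(model, inputs, n_samples=5):
    """Mean saliency map and mean per-pixel entropy over several stochastic passes."""
    preds, ents = [], []
    with torch.no_grad():
        for _ in range(n_samples):
            p = torch.sigmoid(model(inputs))       # one stochastic forward pass
            preds.append(p)
            ents.append(-p * torch.log(p + 1e-8))  # same entropy proxy as in the script above
    pred_mu = torch.cat(preds, dim=1).mean(dim=1, keepdim=True)
    ent_mu = torch.cat(ents, dim=1).mean(dim=1, keepdim=True)
    return pred_mu, ent_mu

# Dummy stand-in: any module whose output varies across calls (e.g. via latent sampling).
dummy_model = lambda x: x + 0.1 * torch.randn_like(x)
pred, unc = mc_prediction(dummy_model, torch.randn(1, 1, 416, 416))
print(pred.shape, unc.shape)  # torch.Size([1, 1, 416, 416]) for both
```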
/src/models/non_local_dot_product.py:
--------------------------------------------------------------------------------
1 |
2 | #!/usr/bin/env python
3 | # coding: utf-8
4 | #
5 | # Author: AlexHex7
6 | # URL: https://github.com/AlexHex7/Non-local_pytorch
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 |
12 |
13 | class _NonLocalBlockND(nn.Module):
14 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
15 | super(_NonLocalBlockND, self).__init__()
16 |
17 | assert dimension in [1, 2, 3]
18 |
19 | self.dimension = dimension
20 | self.sub_sample = sub_sample
21 |
22 | self.in_channels = in_channels
23 | self.inter_channels = inter_channels
24 |
25 | if self.inter_channels is None:
26 | self.inter_channels = in_channels // 2
27 | if self.inter_channels == 0:
28 | self.inter_channels = 1
29 |
30 | if dimension == 3:
31 | conv_nd = nn.Conv3d
32 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
33 | bn = nn.BatchNorm3d
34 | elif dimension == 2:
35 | conv_nd = nn.Conv2d
36 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
37 | bn = nn.BatchNorm2d
38 | else:
39 | conv_nd = nn.Conv1d
40 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
41 | bn = nn.BatchNorm1d
42 |
43 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
44 | kernel_size=1, stride=1, padding=0)
45 |
46 | if bn_layer:
47 | self.W = nn.Sequential(
48 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0),
50 | bn(self.in_channels)
51 | )
52 | nn.init.constant_(self.W[1].weight, 0)
53 | nn.init.constant_(self.W[1].bias, 0)
54 | else:
55 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
56 | kernel_size=1, stride=1, padding=0)
57 | nn.init.constant_(self.W.weight, 0)
58 | nn.init.constant_(self.W.bias, 0)
59 |
60 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
61 | kernel_size=1, stride=1, padding=0)
62 |
63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64 | kernel_size=1, stride=1, padding=0)
65 |
66 | if sub_sample:
67 | self.g = nn.Sequential(self.g, max_pool_layer)
68 | self.phi = nn.Sequential(self.phi, max_pool_layer)
69 |
70 | def forward(self, x):
71 | '''
72 | :param x: (b, c, t, h, w)
73 | :return:
74 | '''
75 |
76 | batch_size = x.size(0)
77 |
78 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
79 | g_x = g_x.permute(0, 2, 1)
80 |
81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
82 | theta_x = theta_x.permute(0, 2, 1)
83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
84 | f = torch.matmul(theta_x, phi_x)
85 | N = f.size(-1)
86 | f_div_C = f / N
87 |
88 | y = torch.matmul(f_div_C, g_x)
89 | y = y.permute(0, 2, 1).contiguous()
90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
91 | W_y = self.W(y)
92 | z = W_y + x
93 |
94 | return z
95 |
96 |
97 | class NONLocalBlock1D(_NonLocalBlockND):
98 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
99 | super(NONLocalBlock1D, self).__init__(in_channels,
100 | inter_channels=inter_channels,
101 | dimension=1, sub_sample=sub_sample,
102 | bn_layer=bn_layer)
103 |
104 |
105 | class NONLocalBlock2D(_NonLocalBlockND):
106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
107 | super(NONLocalBlock2D, self).__init__(in_channels,
108 | inter_channels=inter_channels,
109 | dimension=2, sub_sample=sub_sample,
110 | bn_layer=bn_layer)
111 |
112 |
113 | class NONLocalBlock3D(_NonLocalBlockND):
114 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
115 | super(NONLocalBlock3D, self).__init__(in_channels,
116 | inter_channels=inter_channels,
117 | dimension=3, sub_sample=sub_sample,
118 | bn_layer=bn_layer)
119 |
120 |
121 | if __name__ == '__main__':
122 | import torch
123 |
124 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
125 | img = torch.zeros(2, 3, 20)
126 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
127 | out = net(img)
128 | print(out.size())
129 |
130 | img = torch.zeros(2, 3, 20, 20)
131 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
132 | out = net(img)
133 | print(out.size())
134 |
135 | img = torch.randn(2, 3, 8, 20, 20)
136 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
137 | out = net(img)
138 | print(out.size())
--------------------------------------------------------------------------------
/src/CAVNet_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from datetime import datetime
5 | from data import get_loader, test_dataset, clip_length
6 | from utils import clip_gradient
7 | import logging
8 | import torch.backends.cudnn as cudnn
9 | from options import opt
10 | from utils import print_network, structure_loss, linear_annealing, l2_regularisation
11 | import cv2
12 |
13 | from models.CAVNet import cavnet
14 |
15 |
16 | #set the device for training
17 | if opt.gpu_id == '0':
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 | print('USE GPU 0')
20 | elif opt.gpu_id == '1':
21 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
22 | print('USE GPU 1')
23 | cudnn.benchmark = True
24 |
25 | #build the model
26 | model = cavnet()
27 | print_network(model, 'CAVNet')
28 | if(opt.load is not None):
29 | model.load_state_dict(torch.load(opt.load))
30 | print('load model from ', opt.load)
31 | model.cuda()
32 |
33 | params = model.parameters()
34 | optimizer = torch.optim.Adam(params, opt.lr)
35 |
36 | #set the path
37 | tr_root = opt.tr_root
38 | te_root = opt.te_root
39 | save_path = opt.save_path
40 |
41 | if not os.path.exists(save_path):
42 | os.makedirs(save_path)
43 |
44 | #load data
45 | print('load data...')
46 | train_loader = get_loader(tr_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
47 | test_loader = test_dataset(te_root, opt.trainsize)
48 | total_step = len(train_loader)
49 |
50 | logging.basicConfig(filename=save_path+'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
51 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p')
52 | logging.info("CAVNet-Train")
53 | logging.info("Config")
54 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};load:{};save_path:{}'.
55 | format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.load,save_path))
56 |
57 | step = 0
58 | best_mae = 1
59 | best_epoch = 0
60 |
61 | #train function
62 | def train(train_loader, model, optimizer, epoch, save_path):
63 | global step
64 | model.train()
65 | loss_all = 0
66 | epoch_step = 0
67 | try:
68 | for i, (imgs, gts, audios, seq_name) in enumerate(train_loader, start=1):
69 | optimizer.zero_grad()
70 |
71 | audios = audios.cuda()
72 | for idx in range(clip_length): # default as three consecutive frames
73 | imgs[idx], gts[idx] = imgs[idx].cuda(), gts[idx].cuda()
74 |
75 | # debug
76 | # cv2.imwrite('img.png', imgs[1][0].permute(1, 2, 0).cpu().data.numpy() * 255)
77 | # cv2.imwrite('er_img.png', er_imgs[1][0].permute(1, 2, 0).cpu().data.numpy() * 255)
78 | # cv2.imwrite('gt.png', gts[1][0].permute(1, 2, 0).cpu().data.numpy() * 255)
79 | # cv2.imwrite('aem.png', aems[1][0].permute(1, 2, 0).cpu().data.numpy() * 255)
80 | # cv2.imwrite('cube1.png', cube_gts[1][0].permute(1, 2, 0).cpu().data.numpy() * 255)
81 | # cv2.imwrite('cube2.png', cube_gts[1][1].permute(1, 2, 0).cpu().data.numpy() * 255)
82 | # cv2.imwrite('cube3.png', cube_gts[1][2].permute(1, 2, 0).cpu().data.numpy() * 255)
83 | # cv2.imwrite('cube4.png', cube_gts[1][3].permute(1, 2, 0).cpu().data.numpy() * 255)
84 | # cv2.imwrite('cube5.png', cube_gts[1][4].permute(1, 2, 0).cpu().data.numpy() * 255)
85 | # cv2.imwrite('cube6.png', cube_gts[1][5].permute(1, 2, 0).cpu().data.numpy() * 255)
86 | # torchaudio.save('audio.wav', audios[0].cpu(), 48000)
87 |
88 | preds_prior, preds_post, lat_loss = model(imgs, audios, gts)
89 |
90 | anneal_reg = linear_annealing(0, 1, epoch, opt.epoch)
91 | reg_loss = (l2_regularisation(model.enc_mm_x) + l2_regularisation(model.enc_mm_xy) + \
92 | l2_regularisation(model.backbone_dec_prior) + l2_regularisation(model.backbone_dec_post)) * 1e-4
93 | loss_list = []
94 | for rr in range(clip_length): # prior
95 | loss_list.append(structure_loss(preds_prior[rr], gts[rr]))
96 | for pp in range(clip_length): # post
97 | loss_list.append(structure_loss(preds_post[pp], gts[pp]))
98 | for cc in range(clip_length): # latent loss
99 | loss_list.append(lat_loss[cc] * anneal_reg * opt.lat_weight)
100 | loss = sum(loss_list) / (clip_length * 3)
101 | loss = loss + reg_loss
102 |
103 |             loss_vis_prior, loss_vis_post, loss_vis_lat = sum(loss_list[:clip_length]) / clip_length, \
104 |                 sum(loss_list[clip_length:2 * clip_length]) / clip_length, sum(loss_list[2 * clip_length:]) / clip_length
105 |
106 | loss.backward()
107 |
108 | clip_gradient(optimizer, opt.clip)
109 | optimizer.step()
110 | step += 1
111 | epoch_step += 1
112 | loss_all += loss.data
113 | if i % 500 == 0 or i == total_step or i == 1:
114 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, '
115 | 'LossPri: {:.4f}, LossPos: {:.4f}, LossLat: {:.4f}'.
116 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss.data,
117 | loss_vis_prior.data, loss_vis_post.data, loss_vis_lat.data))
118 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, '
119 | 'LossPri: {:.4f}, LossPos: {:.4f}, LossLat: {:.4f}'.
120 | format(epoch, opt.epoch, i, total_step, loss.data,
121 | loss_vis_prior.data, loss_vis_post.data, loss_vis_lat.data))
122 |
123 | loss_all /= epoch_step
124 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all))
125 | if (epoch) % 5 == 0:
126 | torch.save(model.state_dict(), save_path+'CAVNet_epoch_{}.pth'.format(epoch))
127 | except KeyboardInterrupt:
128 | print('Keyboard Interrupt: save model and exit.')
129 | if not os.path.exists(save_path):
130 | os.makedirs(save_path)
131 | torch.save(model.state_dict(), save_path+'CAVNet_epoch_{}.pth'.format(epoch+1))
132 | print('save checkpoints successfully!')
133 | raise
134 |
135 |
136 | if __name__ == '__main__':
137 | print("Start train...")
138 | for epoch in range(1, opt.epoch):
139 | train(train_loader, model, optimizer, epoch, save_path)
140 |
--------------------------------------------------------------------------------
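The latent loss above is scaled by `linear_annealing(0, 1, epoch, opt.epoch)`, which ramps the KL weight from near zero toward one over training to reduce the risk of posterior collapse. A small sketch of how that factor evolves with the default of 21 epochs (the choice of printed epochs is arbitrary):

```python
from utils import linear_annealing

total_epochs = 21  # the opt.epoch default in options.py
for epoch in (1, 5, 10, 20):
    w = linear_annealing(0, 1, epoch, total_epochs)
    print("epoch {:2d}: latent-loss anneal factor = {:.3f}".format(epoch, w))
# -> 0.048, 0.238, 0.476, 0.952 (approximately)
```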
/src/models/rcrnet_vit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from models.resnet_dilation import Bottleneck, conv1x1
14 | from models.blocks import (
15 | _make_encoder,
16 | forward_vit,
17 | )
18 |
19 | class _ConvBatchNormReLU(nn.Sequential):
20 | def __init__(self,
21 | in_channels,
22 | out_channels,
23 | kernel_size,
24 | stride,
25 | padding,
26 | dilation,
27 | relu=True,
28 | ):
29 | super(_ConvBatchNormReLU, self).__init__()
30 | self.add_module(
31 | "conv",
32 | nn.Conv2d(
33 | in_channels=in_channels,
34 | out_channels=out_channels,
35 | kernel_size=kernel_size,
36 | stride=stride,
37 | padding=padding,
38 | dilation=dilation,
39 | bias=False,
40 | ),
41 | )
42 | #self.add_module(
43 | # "bn",
44 | # nn.BatchNorm2d(out_channels),
45 | #)
46 |
47 | if relu:
48 | self.add_module("relu", nn.ReLU())
49 |
50 | def forward(self, x):
51 | return super(_ConvBatchNormReLU, self).forward(x)
52 |
53 | class _ASPPModule(nn.Module):
54 | """Atrous Spatial Pyramid Pooling with image pool"""
55 |
56 | def __init__(self, in_channels, out_channels, output_stride):
57 | super(_ASPPModule, self).__init__()
58 | if output_stride == 8:
59 | pyramids = [12, 24, 36]
60 | elif output_stride == 16:
61 | pyramids = [6, 12, 18]
62 | self.stages = nn.Module()
63 | self.stages.add_module(
64 | "c0", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)
65 | )
66 | for i, (dilation, padding) in enumerate(zip(pyramids, pyramids)):
67 | self.stages.add_module(
68 | "c{}".format(i + 1),
69 | _ConvBatchNormReLU(in_channels, out_channels, 3, 1, padding, dilation),
70 | )
71 | self.imagepool = nn.Sequential(
72 | OrderedDict(
73 | [
74 | ("pool", nn.AdaptiveAvgPool2d((1,1))),
75 | ("conv", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)),
76 | ]
77 | )
78 | )
79 | self.fire = nn.Sequential(
80 | OrderedDict(
81 | [
82 | ("conv", _ConvBatchNormReLU(out_channels * 5, out_channels, 3, 1, 1, 1)),
83 | ("dropout", nn.Dropout2d(0.1))
84 | ]
85 | )
86 | )
87 |
88 | def forward(self, x):
89 | h = self.imagepool(x)
90 | h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)]
91 | for stage in self.stages.children():
92 | h += [stage(x)]
93 | h = torch.cat(h, dim=1)
94 | h = self.fire(h)
95 |
96 | return h
97 |
98 | class _RefinementModule(nn.Module):
99 | """ Reduce channels and refinment module"""
100 |
101 | def __init__(self,
102 | bottom_up_channels,
103 | reduce_channels,
104 | top_down_channels,
105 | refinement_channels,
106 | expansion=2
107 | ):
108 | super(_RefinementModule, self).__init__()
109 | downsample = None
110 | if bottom_up_channels != reduce_channels:
111 | downsample = nn.Sequential(
112 | conv1x1(bottom_up_channels, reduce_channels),
113 | nn.BatchNorm2d(reduce_channels),
114 | )
115 | self.skip = Bottleneck(bottom_up_channels, reduce_channels // expansion, 1, 1, downsample, expansion)
116 | self.refine = _ConvBatchNormReLU(reduce_channels + top_down_channels, refinement_channels, 3, 1, 1, 1)
117 |
118 | def forward(self, td, bu):
119 | td = self.skip(td)
120 | x = torch.cat((bu, td), dim=1)
121 | x = self.refine(x)
122 |
123 | return x
124 |
125 | class RCRNet_vit(nn.Module):
126 |
127 | def __init__(self, n_classes, output_stride):
128 | super(RCRNet_vit, self).__init__()
129 |
130 | # vit encoder
131 | hooks = {
132 | "vitb_rn50_384": [0, 1, 8, 11],
133 | "vitb16_384": [2, 5, 8, 11],
134 | "vitl16_384": [5, 11, 17, 23],
135 | }
136 | self.pretrained = _make_encoder(
137 | "vitb_rn50_384",
138 | 256,
139 | False, # the pre-train model will be loaded anyway
140 | groups=1,
141 | expand=False,
142 | exportable=False,
143 | hooks=hooks["vitb_rn50_384"],
144 | use_readout="project",
145 | enable_attention_hooks=False,
146 | )
147 |
148 | self.aspp = _ASPPModule(768, 256, output_stride)
149 | # Decoder
150 | self.decoder = nn.Sequential(
151 | OrderedDict(
152 | [
153 | ("conv1", _ConvBatchNormReLU(128, 256, 3, 1, 1, 1)),
154 | ("conv2", nn.Conv2d(256, n_classes, kernel_size=1)),
155 | ]
156 | )
157 | )
158 | self.add_module("refinement1", _RefinementModule(768, 96, 256, 128, 2))
159 | self.add_module("refinement2", _RefinementModule(512, 96, 128, 128, 2))
160 | self.add_module("refinement3", _RefinementModule(256, 96, 128, 128, 2))
161 |
162 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
163 |
164 | #if pretrained:
165 | # for key in self.state_dict():
166 | # if 'resnet' not in key:
167 | # self.init_layer(key)
168 |
169 | # def init_layer(self, key):
170 | # if key.split('.')[-1] == 'weight':
171 | # if 'conv' in key:
172 | # if self.state_dict()[key].ndimension() >= 2:
173 | # nn.init.kaiming_normal_(self.state_dict()[key], mode='fan_out', nonlinearity='relu')
174 | # elif 'bn' in key:
175 | # self.state_dict()[key][...] = 1
176 | # elif key.split('.')[-1] == 'bias':
177 | # self.state_dict()[key][...] = 0.001
178 |
179 | def feat_conv(self, x):
180 | '''
181 | Spatial feature extractor
182 | '''
183 | block1, block2, block3, block4 = forward_vit(self.pretrained, x)
184 | block4 = self.aspp(block4)
185 |
186 | return block1, block2, block3, block4
187 |
188 | def seg_conv(self, block1, block2, block3, block4, shape):
189 | '''
190 |         Pixel-wise classifier
191 | '''
192 | block4 = self.upsample2(block4)
193 |
194 | bu1 = self.refinement1(block3, block4)
195 | bu1 = F.interpolate(bu1, size=block2.shape[2:], mode="bilinear", align_corners=False)
196 | bu2 = self.refinement2(block2, bu1)
197 | bu2 = F.interpolate(bu2, size=block1.shape[2:], mode="bilinear", align_corners=False)
198 | bu3 = self.refinement3(block1, bu2)
199 | bu3 = F.interpolate(bu3, size=shape, mode="bilinear", align_corners=False)
200 | seg = self.decoder(bu3)
201 |
202 | return seg
203 |
204 | def forward(self, x):
205 | block1, block2, block3, block4 = self.feat_conv(x)
206 | seg = self.seg_conv(block1, block2, block3, block4, x.shape[2:])
207 |
208 | return seg
--------------------------------------------------------------------------------
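`_ASPPModule` can be exercised on its own; a minimal sketch follows, where the channel sizes mirror the `RCRNet_vit` constructor above and the 26x26 feature resolution (a 416x416 input at stride 16) is an assumption. It assumes the repo's `models` package and its dependencies are importable:

```python
import torch
from models.rcrnet_vit import _ASPPModule

aspp = _ASPPModule(in_channels=768, out_channels=256, output_stride=16)  # pyramids [6, 12, 18]
feats = torch.randn(1, 768, 26, 26)
out = aspp(feats)   # image pool + 1x1 + three dilated 3x3 branches, fused by the 'fire' conv
print(out.shape)    # torch.Size([1, 256, 26, 26])
```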
/src/models/resnet_dilation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # This code is based on torchvision resnet
5 | # URL: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
6 |
7 | import torch.nn as nn
8 | import torch.utils.model_zoo as model_zoo
9 |
10 |
11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 | 'resnet152']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
25 | """3x3 convolution with padding"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=padding, dilation=dilation, bias=False)
28 |
29 |
30 | def conv1x1(in_planes, out_planes, stride=1):
31 | """1x1 convolution"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33 |
34 |
35 | class BasicBlock(nn.Module):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride, dilation, downsample=None):
39 | super(BasicBlock, self).__init__()
40 | self.conv1 = conv3x3(inplanes, planes, stride)
41 | self.bn1 = nn.BatchNorm2d(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = conv3x3(planes, planes, 1, dilation, dilation)
44 | self.bn2 = nn.BatchNorm2d(planes)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out += residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class Bottleneck(nn.Module):
68 | expansion = 4
69 |
70 | def __init__(self, inplanes, planes, stride, dilation, downsample=None, expansion=4):
71 | super(Bottleneck, self).__init__()
72 | self.expansion = expansion
73 | self.conv1 = conv1x1(inplanes, planes)
74 | self.bn1 = nn.BatchNorm2d(planes)
75 | self.conv2 = conv3x3(planes, planes, stride, dilation, dilation)
76 | self.bn2 = nn.BatchNorm2d(planes)
77 | self.conv3 = conv1x1(planes, planes * self.expansion)
78 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.relu(out)
102 |
103 | return out
104 |
105 |
106 | class ResNet(nn.Module):
107 |
108 | def __init__(self, block, layers, output_stride, num_classes=1000, input_channels=3):
109 | super(ResNet, self).__init__()
110 | if output_stride == 8:
111 | stride = [1, 2, 1, 1]
112 | dilation = [1, 1, 2, 2]
113 | elif output_stride == 16:
114 | stride = [1, 2, 2, 1]
115 | dilation = [1, 1, 1, 2]
116 |
117 | self.inplanes = 64
118 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
119 | bias=False)
120 | self.bn1 = nn.BatchNorm2d(64)
121 | self.relu = nn.ReLU(inplace=True)
122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
123 | self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0], dilation=dilation[0])
124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1], dilation=dilation[1])
125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2], dilation=dilation[2])
126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3], dilation=dilation[3])
127 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
128 | self.fc = nn.Linear(512 * block.expansion, num_classes)
129 |
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
133 | elif isinstance(m, nn.BatchNorm2d):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 |
137 | def _make_layer(self, block, planes, blocks, stride, dilation):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | conv1x1(self.inplanes, planes * block.expansion, stride),
142 | nn.BatchNorm2d(planes * block.expansion),
143 | )
144 |
145 | layers = []
146 | layers.append(block(self.inplanes, planes, stride, dilation, downsample))
147 | self.inplanes = planes * block.expansion
148 | for _ in range(1, blocks):
149 | layers.append(block(self.inplanes, planes, 1, dilation))
150 |
151 | return nn.Sequential(*layers)
152 |
153 | def forward(self, x):
154 | x = self.conv1(x)
155 | x = self.bn1(x)
156 | x = self.relu(x)
157 | x = self.maxpool(x)
158 |
159 | x = self.layer1(x)
160 | x = self.layer2(x)
161 | x = self.layer3(x)
162 | x = self.layer4(x)
163 |
164 | x = self.avgpool(x)
165 | x = x.view(x.size(0), -1)
166 | x = self.fc(x)
167 |
168 | return x
169 |
170 |
171 | def resnet18(pretrained=False, **kwargs):
172 | """Constructs a ResNet-18 model.
173 | Args:
174 | pretrained (bool): If True, returns a model pre-trained on ImageNet
175 | """
176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
177 | if pretrained:
178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
179 | return model
180 |
181 |
182 | def resnet34(pretrained=False, **kwargs):
183 | """Constructs a ResNet-34 model.
184 | Args:
185 | pretrained (bool): If True, returns a model pre-trained on ImageNet
186 | """
187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
188 | if pretrained:
189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
190 | return model
191 |
192 |
193 | def resnet50(pretrained=False, **kwargs):
194 | """Constructs a ResNet-50 model.
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | """
198 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
199 | if pretrained:
200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
201 | return model
202 |
203 |
204 | def resnet101(pretrained=False, **kwargs):
205 | """Constructs a ResNet-101 model.
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | """
209 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
210 | if pretrained:
211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
212 | return model
213 |
214 |
215 | def resnet152(pretrained=False, **kwargs):
216 | """Constructs a ResNet-152 model.
217 | Args:
218 | pretrained (bool): If True, returns a model pre-trained on ImageNet
219 | """
220 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
221 | if pretrained:
222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
223 | return model
224 |
--------------------------------------------------------------------------------
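A quick sketch showing the effect of `output_stride` on the dilated backbone; the 416x416 input and the manual layer-by-layer call (to expose the feature map before global pooling) are assumptions for illustration:

```python
import torch
from models.resnet_dilation import resnet50

net = resnet50(pretrained=False, output_stride=16).eval()
x = torch.randn(1, 3, 416, 416)
with torch.no_grad():
    f = net.maxpool(net.relu(net.bn1(net.conv1(x))))       # stride-4 stem
    f = net.layer4(net.layer3(net.layer2(net.layer1(f))))  # layer4 uses dilation instead of stride
print(f.shape)  # torch.Size([1, 2048, 26, 26]) -> overall stride 16 (output_stride=8 gives 52x52)
```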
/README.md:
--------------------------------------------------------------------------------
1 | # [PAV-SOD: A New Task Towards Panoramic Audiovisual Saliency Detection (TOMM 2022)](https://drive.google.com/file/d/1-1RcARcbz4pACFzkjXcp6MP8R9CGScqI/view?usp=sharing)
2 |
3 | Object-level audiovisual saliency detection in 360° panoramic, real-life dynamic scenes is important for exploring and modeling human perception in immersive environments, and for aiding the development of virtual, augmented and mixed reality applications in fields such as education, social networking, entertainment and training. To this end, we propose a new task, panoramic audiovisual salient object detection (PAV-SOD), which aims to segment the objects that attract most of the human attention in 360° panoramic videos reflecting real-life daily scenes. To support the task, we collect PAVS10K, the first panoramic video dataset for audiovisual salient object detection. It consists of 67 4K-resolution equirectangular videos with per-video labels, including hierarchical scene categories and associated attributes depicting specific challenges for conducting PAV-SOD, and 10,465 uniformly sampled video frames with manually annotated object-level and instance-level pixel-wise masks. The coarse-to-fine annotations enable multi-perspective analysis of PAV-SOD modeling. We further systematically benchmark 13 state-of-the-art salient object detection (SOD)/video object segmentation (VOS) methods on PAVS10K. In addition, we propose a new baseline network that takes advantage of both visual and audio cues of 360° video frames by using a new conditional variational auto-encoder (CVAE). Our CVAE-based audiovisual network, namely CAV-Net, consists of a spatial-temporal visual segmentation network, a convolutional audio-encoding network and audiovisual distribution estimation modules. As a result, our CAV-Net outperforms all competing models and is able to estimate the aleatoric uncertainties within PAVS10K. Extensive experimental results yield several findings about the challenges of PAV-SOD and insights into PAV-SOD model interpretability. We hope that our work can serve as a starting point for advancing SOD towards immersive media.
4 |
5 | ------
6 |
7 | # PAVS10K
8 |
9 |
10 |
11 |
12 | Figure 1: An example from our PAVS10K where coarse-to-fine annotations are provided, guided by fixations acquired from subjective experiments conducted with multiple (N) subjects wearing Head-Mounted Displays (HMDs) and headphones. Each of the T equirectangular (ER) video frames of the sequence “Speaking” (super-class)-“Brothers” (sub-class) (e.g., fk, fl and fn, where the random integer indices {k, l, n} ∈ [1, T]) is manually labeled with both object-level and instance-level pixel-wise masks. According to the characteristics of the defined salient objects within each sequence, multiple attributes, e.g., “multiple objects” (MO), “competing sounds” (CS), “geometrical distortion” (GD), “motion blur” (MB), “occlusions” (OC) and “low resolution” (LR), are further annotated to enable detailed analysis for PAV-SOD modeling.
13 |
14 |
15 |
16 |
17 |
18 |
19 | Figure 2: Summary of widely used salient object detection (SOD)/video object segmentation (VOS) datasets and PAVS10K. #Img: The number of images/video frames. #GT: The number of object-level pixel-wise masks (ground truth for SOD). Pub. = Publication. Obj.-Level = Object-Level Labels. Ins.-Level = Instance-Level Labels. Fix. GT = Fixation Maps. † denotes equirectangular images.
20 |
21 |
22 |
23 |
24 |
25 |
26 | Figure 3: Examples of challenging attributes on equirectangular images from our PAVS10K, with instance-level ground truth and fixations as annotation guidance. {𝑓𝑘, 𝑓𝑙, 𝑓𝑛} denote random frames of a given video.
27 |
28 |
29 |
30 |
31 |
32 |
33 | Figure 4: Statistics of the proposed PAVS10K. (a) Super-/sub-category information. (b) Instance density (labeled frames per sequence) of each sub-class. (c) Sound sources of PAVS10K scenes, such as musical instruments, human instances and animals.
34 |
35 |
36 |
37 | ------
38 |
39 | # Benchmark Models
40 |
41 | **No.** | **Year** | **Pub.** | **Title** | **Links**
42 | :-: | :-:| :-: | :- | :-:
43 | 01 | **2019**| **CVPR** | Cascaded Partial Decoder for Fast and Accurate Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_Cascaded_Partial_Decoder_for_Fast_and_Accurate_Salient_Object_Detection_CVPR_2019_paper.pdf)/[Code](https://github.com/wuzhe71/CPD)
44 | 02 | **2019**| **CVPR** | See More, Know More: Unsupervised Video Object Segmentation with Co-Attention Siamese Networks | [Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Lu_See_More_Know_More_Unsupervised_Video_Object_Segmentation_With_Co-Attention_CVPR_2019_paper.pdf)/[Code](https://github.com/carrierlxk/COSNet)
45 | 03 | **2019**| **ICCV** | Stacked Cross Refinement Network for Edge-Aware Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_ICCV_2019/papers/Wu_Stacked_Cross_Refinement_Network_for_Edge-Aware_Salient_Object_Detection_ICCV_2019_paper.pdf)/[Code](https://github.com/wuzhe71/SCRN)
46 | 04 | **2019**| **ICCV** | Semi-Supervised Video Salient Object Detection Using Pseudo-Labels | [Paper](https://openaccess.thecvf.com/content_ICCV_2019/papers/Yan_Semi-Supervised_Video_Salient_Object_Detection_Using_Pseudo-Labels_ICCV_2019_paper.pdf)/[Code](https://github.com/Kinpzz/RCRNet-Pytorch)
47 | 05 | **2020**| **AAAI** | F³Net: Fusion, Feedback and Focus for Salient Object Detection | [Paper](https://ojs.aaai.org/index.php/AAAI/article/download/6916/6770)/[Code](https://github.com/weijun88/F3Net)
48 | 06 | **2020**| **AAAI** | Pyramid Constrained Self-Attention Network for Fast Video Salient Object Detection | [Paper](https://ojs.aaai.org/index.php/AAAI/article/view/6718/6572)/[Code](https://github.com/guyuchao/PyramidCSA)
49 | 07 | **2020**| **CVPR** | Multi-scale Interactive Network for Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Pang_Multi-Scale_Interactive_Network_for_Salient_Object_Detection_CVPR_2020_paper.pdf)/[Code](https://github.com/lartpang/MINet)
50 | 08 | **2020**| **CVPR** | Label Decoupling Framework for Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Wei_Label_Decoupling_Framework_for_Salient_Object_Detection_CVPR_2020_paper.pdf)/[Code](https://github.com/weijun88/LDF)
51 | 09 | **2020**| **ECCV** | Highly Efficient Salient Object Detection with 100K Parameters | [Paper](http://mftp.mmcheng.net/Papers/20EccvSal100k.pdf)/[Code](https://github.com/ShangHua-Gao/SOD100K)
52 | 10 | **2020**| **ECCV** | Suppress and Balance: A Simple Gated Network for Salient Object Detection | [Paper](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123470035.pdf)/[Code](https://github.com/Xiaoqi-Zhao-DLUT/GateNet-RGB-Saliency)
53 | 11 | **2020**| **BMVC** | Making a Case for 3D Convolutions for Object Segmentation in Videos | [Paper](https://www.bmvc2020-conference.com/assets/papers/0233.pdf)/[Code](https://github.com/sabarim/3DC-Seg)
54 | 12 | **2020**| **SPL** | FANet: Features Adaptation Network for 360° Omnidirectional Salient Object Detection | [Paper](https://ieeexplore.ieee.org/document/9211754)/[Code](https://github.com/DreaMKHuang/FANet)
55 | 13 | **2021**| **CVPR** | Reciprocal Transformations for Unsupervised Video Object Segmentation | [Paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ren_Reciprocal_Transformations_for_Unsupervised_Video_Object_Segmentation_CVPR_2021_paper.pdf)/[Code](https://github.com/OliverRensu/RTNet)
56 |
57 | ------
58 |
59 | # CAV-Net
60 |
61 | The code is available at [src](https://github.com/PanoAsh/PAV-SOD/tree/main/src).
62 |
63 | The pre-trained models can be downloaded at [Google Drive](https://drive.google.com/file/d/1gNWmgmlBfJqCYE5phDuHTMFIou1TAmXs/view?usp=sharing).
64 |
65 | ------
66 |
67 | # Dataset Downloads
68 |
69 | The whole object-/instance-level ground truth with default split can be downloaded from [Google Drive](https://drive.google.com/file/d/1Whp_ftuXza8-vkjNtICdxdRebcmzcrFi/view?usp=sharing).
70 |
71 | The videos (with ambisonics) with default split can be downloaded from [Google Drive](https://drive.google.com/file/d/13FEv1yAyMmK4GkiZ2Mce6gJxQuME7vG3/view?usp=sharing).
72 |
73 | The videos (with mono sound) can be downloaded from [Google Drive](https://drive.google.com/file/d/1klJnHSiUM7Ow2LkdaLe-O6CEQ9qmdo2F/view?usp=sharing).
74 |
75 | The audio files (.wav) can be downloaded from [Google Drive](https://drive.google.com/file/d/1-jqDArcm8vBhyku3Xb8HLopG1XmpAS13/view?usp=sharing).
76 |
77 | The head movement and eye fixation data can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1EpWc7GVcGFAn5VigV3c2-ZtIZElfXPX1?usp=sharing).
78 |
79 | To generate video frames, please refer to [video_to_frames.py](https://github.com/PanoAsh/ASOD60K/blob/main/video_to_frames.py).
80 |
81 | To get access to raw videos on YouTube, please refer to [video_seq_link](https://github.com/PanoAsh/ASOD60K/blob/main/video_seq_link).
82 |
83 | > Note: The PAVS10K dataset does not own the copyright of the videos. Only researchers and educators who wish to use the videos for non-commercial research and/or educational purposes have access to PAVS10K.
84 |
85 | ------
86 |
87 | # Citation
88 |
89 | @article{zhang2023pav,
90 | title={PAV-SOD: A New Task towards Panoramic Audiovisual Saliency Detection},
91 | author={Zhang, Yi and Chao, Fang-Yi and Hamidouche, Wassim and Deforges, Olivier},
92 | journal={ACM Transactions on Multimedia Computing, Communications and Applications},
93 | volume={19},
94 | number={3},
95 | pages={1--26},
96 | year={2023},
97 | publisher={ACM New York, NY}
98 | }
99 |
100 | ------
101 |
102 | # Contact
103 |
104 | yi23zhang.2022@gmail.com
105 | or
106 | fangyichao428@gmail.com (for details of head movement and eye fixation data).
107 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from torch.nn.parameter import Parameter
5 | import scipy.stats as st
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | CE = torch.nn.BCEWithLogitsLoss()
12 |
13 | def l2_regularisation(m):
14 | l2_reg = None
15 |
16 | for W in m.parameters():
17 | if l2_reg is None:
18 | l2_reg = W.norm(2)
19 | else:
20 | l2_reg = l2_reg + W.norm(2)
21 |
22 | return l2_reg
23 |
24 | # linear annealing to avoid posterior collapse
25 | def linear_annealing(init, fin, step, annealing_steps):
26 | """Linear annealing of a parameter."""
27 | if annealing_steps == 0:
28 | return fin
29 | assert fin > init
30 | delta = fin - init
31 | annealed = min(init + delta * step / annealing_steps, fin)
32 |
33 | return annealed
34 |
35 | def structure_loss(pred, mask):
36 | # adaptive weighting mask
37 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
38 |
39 | # weighted binary cross entropy loss function
40 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')  # keep per-pixel values for weighting
41 | wbce = ((weit * wbce).sum(dim=(2, 3)) + 1e-8) / (weit.sum(dim=(2, 3)) + 1e-8)
42 |
43 | pred = torch.sigmoid(pred)
44 |
45 | # weighted iou loss function
46 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
47 | union = ((pred + mask) * weit).sum(dim=(2, 3))
48 | wiou = 1.0 - (inter + 1 + 1e-8) / (union - inter + 1 + 1e-8)
49 |
50 | return (wbce + wiou).mean()
51 |
52 | def weighted_e_loss(pred, mask):
53 | # adaptive weighting mask
54 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
55 |
56 | # weighted binary cross entropy loss function
57 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')  # keep per-pixel values for weighting
58 | wbce = ((weit * wbce).sum(dim=(2, 3)) + 1e-8) / (weit.sum(dim=(2, 3)) + 1e-8)
59 |
60 | # weighted e loss function
61 | pred = torch.sigmoid(pred)
62 | mpred = pred.mean(dim=(2, 3)).view(pred.shape[0], pred.shape[1], 1, 1).repeat(1, 1, pred.shape[2], pred.shape[3])
63 | phiFM = pred - mpred
64 |
65 | mmask = mask.mean(dim=(2, 3)).view(mask.shape[0], mask.shape[1], 1, 1).repeat(1, 1, mask.shape[2], mask.shape[3])
66 | phiGT = mask - mmask
67 |
68 | EFM = (2.0 * phiFM * phiGT + 1e-8) / (phiFM * phiFM + phiGT * phiGT + 1e-8)
69 | QFM = (1 + EFM) * (1 + EFM) / 4.0
70 | eloss = 1.0 - QFM.mean(dim=(2, 3))
71 |
72 | # weighted iou loss function
73 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
74 | union = ((pred + mask) * weit).sum(dim=(2, 3))
75 | wiou = 1.0 - (inter + 1 + 1e-8) / (union - inter + 1 + 1e-8)
76 |
77 | return (wbce + eloss + wiou).mean()
78 |
79 | def hybrid_e_loss(pred, mask):
80 | # adaptive weighting mask
81 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
82 |
83 | # weighted binary cross entropy loss function
84 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')  # keep per-pixel values for weighting
85 | wbce = ((weit * wbce).sum(dim=(2, 3)) + 1e-8) / (weit.sum(dim=(2, 3)) + 1e-8)
86 |
87 | # weighted e loss function
88 | pred = torch.sigmoid(pred)
89 | mpred = pred.mean(dim=(2, 3)).view(pred.shape[0], pred.shape[1], 1, 1).repeat(1, 1, pred.shape[2], pred.shape[3])
90 | phiFM = pred - mpred
91 |
92 | mmask = mask.mean(dim=(2, 3)).view(mask.shape[0], mask.shape[1], 1, 1).repeat(1, 1, mask.shape[2], mask.shape[3])
93 | phiGT = mask - mmask
94 |
95 | EFM = (2.0 * phiFM * phiGT + 1e-8) / (phiFM * phiFM + phiGT * phiGT + 1e-8)
96 | QFM = (1 + EFM) * (1 + EFM) / 4.0
97 | eloss = 1.0 - QFM.mean(dim=(2, 3))
98 |
99 | # weighted iou loss function
100 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
101 | union = ((pred + mask) * weit).sum(dim=(2, 3))
102 | wiou = 1.0 - (inter + 1 + 1e-8) / (union - inter + 1 + 1e-8)
103 |
104 | return (wbce + eloss + wiou).mean()
105 |
106 | def gkern(kernlen=16, nsig=3):
107 | interval = (2*nsig+1.)/kernlen
108 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
109 | kern1d = np.diff(st.norm.cdf(x))
110 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
111 | kernel = kernel_raw/kernel_raw.sum()
112 | return kernel
113 |
114 | def min_max_norm(in_):
115 | max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
116 | min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
117 | in_ = in_ - min_
118 | return in_.div(max_-min_+1e-8)
119 |
120 | class HA(nn.Module):
121 | # holistic attention module
122 | def __init__(self):
123 | super(HA, self).__init__()
124 | gaussian_kernel = np.float32(gkern(31, 4))
125 | gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
126 | self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))
127 |
128 | def forward(self, attention, x):
129 | soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)
130 | soft_attention = min_max_norm(soft_attention)
131 | x = torch.mul(x, soft_attention.max(attention))
132 | return x
133 |
134 |
135 | def clip_gradient(optimizer, grad_clip):
136 | for group in optimizer.param_groups:
137 | for param in group['params']:
138 | if param.grad is not None:
139 | param.grad.data.clamp_(-grad_clip, grad_clip)
140 |
141 |
142 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
143 | decay = decay_rate ** (epoch // decay_epoch)
144 | for param_group in optimizer.param_groups:
145 | param_group['lr'] = decay * init_lr
146 | return
147 |
148 | def print_network(model, name):
149 | num_params = 0
150 | for p in model.parameters():
151 | num_params += p.numel()
152 | print(model)
153 | print(name)
154 | print("The number of parameters: {}".format(num_params))
155 |
156 |
157 | def calculate_cdf(histogram):
158 | """
159 | This method calculates the cumulative distribution function
160 | :param array histogram: The values of the histogram
161 | :return: normalized_cdf: The normalized cumulative distribution function
162 | :rtype: array
163 | """
164 | # Get the cumulative sum of the elements
165 | cdf = histogram.cumsum()
166 |
167 | # Normalize the cdf
168 | normalized_cdf = cdf / float(cdf.max())
169 |
170 | return normalized_cdf
171 |
172 | def calculate_lookup(src_cdf, ref_cdf):
173 | """
174 | This method creates the lookup table
175 | :param array src_cdf: The cdf for the source image
176 | :param array ref_cdf: The cdf for the reference image
177 | :return: lookup_table: The lookup table
178 | :rtype: array
179 | """
180 | lookup_table = np.zeros(256)
181 | lookup_val = 0
182 | for src_pixel_val in range(len(src_cdf)):
183 | # find the first reference value whose cdf is >= the source cdf
184 | for ref_pixel_val in range(len(ref_cdf)):
185 | if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
186 | lookup_val = ref_pixel_val
187 | break
188 | lookup_table[src_pixel_val] = lookup_val
189 | return lookup_table
190 |
191 | def match_histograms(src_image, ref_image):
192 | """
193 | This method matches the source image histogram to the
194 | reference signal
195 | :param image src_image: The original source image
196 | :param image ref_image: The reference image
197 | :return: image_after_matching
198 | :rtype: image (array)
199 | """
200 | # Split the images into the different color channels
201 | # b means blue, g means green and r means red
202 | src_b, src_g, src_r = cv2.split(src_image)
203 | ref_b, ref_g, ref_r = cv2.split(ref_image)
204 |
205 | # Compute the b, g, and r histograms separately
206 | # The flatten() NumPy method returns a copy of the array
207 | # collapsed into one dimension.
208 | src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
209 | src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
210 | src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
211 | ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
212 | ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
213 | ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
214 |
215 | # Compute the normalized cdf for the source and reference image
216 | src_cdf_blue = calculate_cdf(src_hist_blue)
217 | src_cdf_green = calculate_cdf(src_hist_green)
218 | src_cdf_red = calculate_cdf(src_hist_red)
219 | ref_cdf_blue = calculate_cdf(ref_hist_blue)
220 | ref_cdf_green = calculate_cdf(ref_hist_green)
221 | ref_cdf_red = calculate_cdf(ref_hist_red)
222 |
223 | # Make a separate lookup table for each color
224 | blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
225 | green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
226 | red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
227 |
228 | # Use the lookup function to transform the colors of the original
229 | # source image
230 | blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
231 | green_after_transform = cv2.LUT(src_g, green_lookup_table)
232 | red_after_transform = cv2.LUT(src_r, red_lookup_table)
233 |
234 | # Put the image back together
235 | image_after_matching = cv2.merge([
236 | blue_after_transform, green_after_transform, red_after_transform])
237 | image_after_matching = cv2.convertScaleAbs(image_after_matching)
238 |
239 | return image_after_matching
240 |
241 | def histogram():
242 | ref_pth = os.getcwd() + '/Ref/0013.png'
243 | ori_pth = os.getcwd() + '/ori/'
244 | fin_pth = os.getcwd() + '/prep/'
245 | ori_list = os.listdir(ori_pth)
246 | count = 0
247 | for item in ori_list:
248 | ori_img = cv2.imread(os.path.join(ori_pth, item))
249 | ref_img = cv2.imread(ref_pth)
250 | fin_img = match_histograms(ori_img, ref_img)
251 | cv2.imwrite(os.path.join(fin_pth, item), fin_img)
252 | count += 1
253 | print(count)
254 |
255 |
256 | if __name__ == '__main__':
257 | histogram()
258 |
--------------------------------------------------------------------------------
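
A minimal usage sketch for the loss and training helpers listed above (hybrid_e_loss, clip_gradient, adjust_lr); the tensor shapes, the module name used for the import, and the optimizer settings are assumptions for illustration, not part of the repository.

import torch
from utils import hybrid_e_loss, clip_gradient, adjust_lr  # assuming the file above is importable as "utils"

# hypothetical logits and binary ground-truth masks of shape (B, 1, H, W)
pred = torch.randn(2, 1, 224, 448, requires_grad=True)
mask = (torch.rand(2, 1, 224, 448) > 0.5).float()

loss = hybrid_e_loss(pred, mask)  # weighted BCE + enhanced-alignment + weighted IoU
loss.backward()

# in a real training step with a model/optimizer (hypothetical names):
#   optimizer.zero_grad(); loss.backward()
#   clip_gradient(optimizer, grad_clip=0.5)
#   optimizer.step()
#   adjust_lr(optimizer, init_lr=1e-4, epoch=epoch, decay_rate=0.1, decay_epoch=30)
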
/src/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.VIT import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 |
12 | def _make_encoder(
13 | backbone,
14 | features,
15 | use_pretrained,
16 | groups=1,
17 | expand=False,
18 | exportable=True,
19 | hooks=None,
20 | use_vit_only=False,
21 | use_readout="ignore",
22 | enable_attention_hooks=False,
23 | ):
24 | if backbone == "vitl16_384":
25 | pretrained = _make_pretrained_vitl16_384(
26 | use_pretrained,
27 | hooks=hooks,
28 | use_readout=use_readout,
29 | enable_attention_hooks=enable_attention_hooks,
30 | )
31 |
32 | elif backbone == "vitb_rn50_384":
33 | pretrained = _make_pretrained_vitb_rn50_384(
34 | use_pretrained,
35 | hooks=hooks,
36 | use_vit_only=use_vit_only,
37 | use_readout=use_readout,
38 | enable_attention_hooks=enable_attention_hooks,
39 | )
40 |
41 | elif backbone == "vitb16_384":
42 | pretrained = _make_pretrained_vitb16_384(
43 | use_pretrained,
44 | hooks=hooks,
45 | use_readout=use_readout,
46 | enable_attention_hooks=enable_attention_hooks,
47 | )
48 |
49 | elif backbone == "resnext101_wsl":
50 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
51 |
52 | else:
53 | print(f"Backbone '{backbone}' not implemented")
54 | assert False
55 |
56 | return pretrained
57 |
58 |
59 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
60 | scratch = nn.Module()
61 |
62 | out_shape1 = out_shape
63 | out_shape2 = out_shape
64 | out_shape3 = out_shape
65 | out_shape4 = out_shape
66 | if expand == True:
67 | out_shape1 = out_shape
68 | out_shape2 = out_shape * 2
69 | out_shape3 = out_shape * 4
70 | out_shape4 = out_shape * 8
71 |
72 | scratch.layer1_rn = nn.Conv2d(
73 | in_shape[0],
74 | out_shape1,
75 | kernel_size=3,
76 | stride=1,
77 | padding=1,
78 | bias=False,
79 | groups=groups,
80 | )
81 | scratch.layer2_rn = nn.Conv2d(
82 | in_shape[1],
83 | out_shape2,
84 | kernel_size=3,
85 | stride=1,
86 | padding=1,
87 | bias=False,
88 | groups=groups,
89 | )
90 | scratch.layer3_rn = nn.Conv2d(
91 | in_shape[2],
92 | out_shape3,
93 | kernel_size=3,
94 | stride=1,
95 | padding=1,
96 | bias=False,
97 | groups=groups,
98 | )
99 | scratch.layer4_rn = nn.Conv2d(
100 | in_shape[3],
101 | out_shape4,
102 | kernel_size=3,
103 | stride=1,
104 | padding=1,
105 | bias=False,
106 | groups=groups,
107 | )
108 |
109 | return scratch
110 |
111 |
112 | def _make_resnet_backbone(resnet):
113 | pretrained = nn.Module()
114 | pretrained.layer1 = nn.Sequential(
115 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
116 | )
117 |
118 | pretrained.layer2 = resnet.layer2
119 | pretrained.layer3 = resnet.layer3
120 | pretrained.layer4 = resnet.layer4
121 |
122 | return pretrained
123 |
124 |
125 | def _make_pretrained_resnext101_wsl(use_pretrained):
126 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
127 | return _make_resnet_backbone(resnet)
128 |
129 |
130 | class Interpolate(nn.Module):
131 | """Interpolation module."""
132 |
133 | def __init__(self, scale_factor, mode, align_corners=False):
134 | """Init.
135 |
136 | Args:
137 | scale_factor (float): scaling
138 | mode (str): interpolation mode
139 | """
140 | super(Interpolate, self).__init__()
141 |
142 | self.interp = nn.functional.interpolate
143 | self.scale_factor = scale_factor
144 | self.mode = mode
145 | self.align_corners = align_corners
146 |
147 | def forward(self, x):
148 | """Forward pass.
149 |
150 | Args:
151 | x (tensor): input
152 |
153 | Returns:
154 | tensor: interpolated data
155 | """
156 |
157 | x = self.interp(
158 | x,
159 | scale_factor=self.scale_factor,
160 | mode=self.mode,
161 | align_corners=self.align_corners,
162 | )
163 |
164 | return x
165 |
166 |
167 | class ResidualConvUnit(nn.Module):
168 | """Residual convolution module."""
169 |
170 | def __init__(self, features):
171 | """Init.
172 |
173 | Args:
174 | features (int): number of features
175 | """
176 | super().__init__()
177 |
178 | self.conv1 = nn.Conv2d(
179 | features, features, kernel_size=3, stride=1, padding=1, bias=True
180 | )
181 |
182 | self.conv2 = nn.Conv2d(
183 | features, features, kernel_size=3, stride=1, padding=1, bias=True
184 | )
185 |
186 | self.relu = nn.ReLU(inplace=True)
187 |
188 | def forward(self, x):
189 | """Forward pass.
190 |
191 | Args:
192 | x (tensor): input
193 |
194 | Returns:
195 | tensor: output
196 | """
197 | out = self.relu(x)
198 | out = self.conv1(out)
199 | out = self.relu(out)
200 | out = self.conv2(out)
201 |
202 | return out + x
203 |
204 |
205 | class FeatureFusionBlock(nn.Module):
206 | """Feature fusion block."""
207 |
208 | def __init__(self, features):
209 | """Init.
210 |
211 | Args:
212 | features (int): number of features
213 | """
214 | super(FeatureFusionBlock, self).__init__()
215 |
216 | self.resConfUnit1 = ResidualConvUnit(features)
217 | self.resConfUnit2 = ResidualConvUnit(features)
218 |
219 | def forward(self, *xs):
220 | """Forward pass.
221 |
222 | Returns:
223 | tensor: output
224 | """
225 | output = xs[0]
226 |
227 | if len(xs) == 2:
228 | output += self.resConfUnit1(xs[1])
229 |
230 | output = self.resConfUnit2(output)
231 |
232 | output = nn.functional.interpolate(
233 | output, scale_factor=2, mode="bilinear", align_corners=True
234 | )
235 |
236 | return output
237 |
238 |
239 | class ResidualConvUnit_custom(nn.Module):
240 | """Residual convolution module."""
241 |
242 | def __init__(self, features, activation, bn):
243 | """Init.
244 |
245 | Args:
246 | features (int): number of features
247 | """
248 | super().__init__()
249 |
250 | self.bn = bn
251 |
252 | self.groups = 1
253 |
254 | self.conv1 = nn.Conv2d(
255 | features,
256 | features,
257 | kernel_size=3,
258 | stride=1,
259 | padding=1,
260 | bias=not self.bn,
261 | groups=self.groups,
262 | )
263 |
264 | self.conv2 = nn.Conv2d(
265 | features,
266 | features,
267 | kernel_size=3,
268 | stride=1,
269 | padding=1,
270 | bias=not self.bn,
271 | groups=self.groups,
272 | )
273 |
274 | if self.bn == True:
275 | self.bn1 = nn.BatchNorm2d(features)
276 | self.bn2 = nn.BatchNorm2d(features)
277 |
278 | self.activation = activation
279 |
280 | self.skip_add = nn.quantized.FloatFunctional()
281 |
282 | def forward(self, x):
283 | """Forward pass.
284 |
285 | Args:
286 | x (tensor): input
287 |
288 | Returns:
289 | tensor: output
290 | """
291 |
292 | out = self.activation(x)
293 | out = self.conv1(out)
294 | if self.bn == True:
295 | out = self.bn1(out)
296 |
297 | out = self.activation(out)
298 | out = self.conv2(out)
299 | if self.bn == True:
300 | out = self.bn2(out)
301 |
302 | if self.groups > 1:
303 | out = self.conv_merge(out)
304 |
305 | return self.skip_add.add(out, x)
306 |
307 | # return out + x
308 |
309 |
310 | class FeatureFusionBlock_custom(nn.Module):
311 | """Feature fusion block."""
312 |
313 | def __init__(
314 | self,
315 | features,
316 | activation,
317 | deconv=False,
318 | bn=False,
319 | expand=False,
320 | align_corners=True,
321 | ):
322 | """Init.
323 |
324 | Args:
325 | features (int): number of features
326 | """
327 | super(FeatureFusionBlock_custom, self).__init__()
328 |
329 | self.deconv = deconv
330 | self.align_corners = align_corners
331 |
332 | self.groups = 1
333 |
334 | self.expand = expand
335 | out_features = features
336 | if self.expand == True:
337 | out_features = features // 2
338 |
339 | self.out_conv = nn.Conv2d(
340 | features,
341 | out_features,
342 | kernel_size=1,
343 | stride=1,
344 | padding=0,
345 | bias=True,
346 | groups=1,
347 | )
348 |
349 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
350 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
351 |
352 | self.skip_add = nn.quantized.FloatFunctional()
353 |
354 | def forward(self, *xs):
355 | """Forward pass.
356 |
357 | Returns:
358 | tensor: output
359 | """
360 | output = xs[0]
361 |
362 | if len(xs) == 2:
363 | res = self.resConfUnit1(xs[1])
364 | output = self.skip_add.add(output, res)
365 | # output += res
366 |
367 | output = self.resConfUnit2(output)
368 |
369 | output = nn.functional.interpolate(
370 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
371 | )
372 |
373 | output = self.out_conv(output)
374 |
375 | return output
376 |
--------------------------------------------------------------------------------
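
A rough sketch of how the encoder factory above can be exercised on its own, assuming a timm version that still provides the model names used in VIT.py; the input size (384x768, divisible by 16) and batch size are arbitrary choices for illustration.

import torch
from models.blocks import _make_encoder, forward_vit

# build the hybrid ViT-ResNet50 backbone (downloads timm weights on first use)
pretrained = _make_encoder(
    backbone="vitb_rn50_384",
    features=256,
    use_pretrained=True,
    use_readout="ignore",
)

with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(pretrained, torch.randn(1, 3, 384, 768))
for i, feat in enumerate((layer_1, layer_2, layer_3, layer_4), start=1):
    print(f"layer_{i}:", tuple(feat.shape))  # four multi-scale feature maps
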
/src/models/CAVNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 |
4 | from models.rcrnet_vit import RCRNet_vit, _ConvBatchNormReLU, _RefinementModule
5 | from models.blocks import forward_vit
6 | from models.convgru import ConvGRUCell
7 | from models.non_local_dot_product import NONLocalBlock3D
8 |
9 | from collections import OrderedDict
10 | from torch.autograd import Variable
11 | import torch
12 | import torch.nn as nn
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | import torch.nn.functional as F
15 | import numpy as np
16 | from torch.distributions import Normal, Independent, kl
17 |
18 |
19 | class InferenceModel_mm_x(nn.Module):
20 | def __init__(self, input_channels, channels, latent_size):
21 | super(InferenceModel_mm_x, self).__init__()
22 | self.contracting_path = nn.ModuleList()
23 | self.input_channels = input_channels
24 | self.relu = nn.ReLU(inplace=True)
25 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1)
26 | self.bn1 = nn.BatchNorm2d(channels)
27 | self.layer2 = nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2, padding=1)
28 | self.bn2 = nn.BatchNorm2d(channels * 2)
29 | self.layer3 = nn.Conv2d(2 * channels, 4 * channels, kernel_size=4, stride=2, padding=1)
30 | self.bn3 = nn.BatchNorm2d(channels * 4)
31 | self.layer4 = nn.Conv2d(4 * channels, 8 * channels, kernel_size=4, stride=2, padding=1)
32 | self.bn4 = nn.BatchNorm2d(channels * 8)
33 | self.layer5 = nn.Conv2d(8 * channels, 8 * channels, kernel_size=4, stride=2, padding=1)
34 | self.bn5 = nn.BatchNorm2d(channels * 8)
35 | self.channel = channels
36 |
37 | self.fc1 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size
38 | self.fc2 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size
39 |
40 | self.a_fc1 = nn.Linear(1024, latent_size)
41 | self.a_fc2 = nn.Linear(1024, latent_size)
42 |
43 | self.av_fc1 = nn.Bilinear(latent_size, latent_size, latent_size)
44 | self.av_fc2 = nn.Bilinear(latent_size, latent_size, latent_size)
45 |
46 | self.leakyrelu = nn.LeakyReLU()
47 |
48 | def forward(self, input, aux_input):
49 | output = self.leakyrelu(self.bn1(self.layer1(input)))
50 | # print(output.size())
51 | output = self.leakyrelu(self.bn2(self.layer2(output)))
52 | # print(output.size())
53 | output = self.leakyrelu(self.bn3(self.layer3(output)))
54 | # print(output.size())
55 | output = self.leakyrelu(self.bn4(self.layer4(output)))
56 | # print(output.size())
57 | output = self.leakyrelu(self.bn5(self.layer5(output)))
58 | output = output.view(-1, self.channel * 8 * 13 * 26) # adjust according to input size
59 | # print(output.size())
60 | # output = self.tanh(output)
61 |
62 | # audio visual fusion
63 | mu = self.fc1(output)
64 | a_mu = self.a_fc1(aux_input)
65 | av_mu = self.av_fc1(mu, a_mu)
66 |
67 | logvar = self.fc2(output)
68 | a_logvar = self.a_fc2(aux_input)
69 | av_logvar = self.av_fc2(logvar, a_logvar)
70 | dist = Independent(Normal(loc=av_mu, scale=torch.exp(av_logvar)), 1)
71 |
72 | return av_mu, av_logvar, dist
73 |
74 |
75 | class InferenceModel_mm_xy(nn.Module):
76 | def __init__(self, input_channels, channels, latent_size):
77 | super(InferenceModel_mm_xy, self).__init__()
78 | self.contracting_path = nn.ModuleList()
79 | self.input_channels = input_channels
80 | self.relu = nn.ReLU(inplace=True)
81 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1)
82 | self.bn1 = nn.BatchNorm2d(channels)
83 | self.layer2 = nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2, padding=1)
84 | self.bn2 = nn.BatchNorm2d(channels * 2)
85 | self.layer3 = nn.Conv2d(2 * channels, 4 * channels, kernel_size=4, stride=2, padding=1)
86 | self.bn3 = nn.BatchNorm2d(channels * 4)
87 | self.layer4 = nn.Conv2d(4 * channels, 8 * channels, kernel_size=4, stride=2, padding=1)
88 | self.bn4 = nn.BatchNorm2d(channels * 8)
89 | self.layer5 = nn.Conv2d(8 * channels, 8 * channels, kernel_size=4, stride=2, padding=1)
90 | self.bn5 = nn.BatchNorm2d(channels * 8)
91 | self.channel = channels
92 |
93 | self.fc1 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size
94 | self.fc2 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size
95 |
96 | self.a_fc1 = nn.Linear(1024, latent_size)
97 | self.a_fc2 = nn.Linear(1024, latent_size)
98 |
99 | self.av_fc1 = nn.Bilinear(latent_size, latent_size, latent_size)
100 | self.av_fc2 = nn.Bilinear(latent_size, latent_size, latent_size)
101 |
102 | self.leakyrelu = nn.LeakyReLU()
103 |
104 | def forward(self, input, aux_input):
105 | output = self.leakyrelu(self.bn1(self.layer1(input)))
106 | # print(output.size())
107 | output = self.leakyrelu(self.bn2(self.layer2(output)))
108 | # print(output.size())
109 | output = self.leakyrelu(self.bn3(self.layer3(output)))
110 | # print(output.size())
111 | output = self.leakyrelu(self.bn4(self.layer4(output)))
112 | # print(output.size())
113 | output = self.leakyrelu(self.bn5(self.layer5(output)))
114 | output = output.view(-1, self.channel * 8 * 13 * 26) # adjust according to input size
115 | # print(output.size())
116 | # output = self.tanh(output)
117 |
118 | # audio visual fusion
119 | mu = self.fc1(output)
120 | a_mu = self.a_fc1(aux_input)
121 | av_mu = self.av_fc1(mu, a_mu)
122 |
123 | logvar = self.fc2(output)
124 | a_logvar = self.a_fc2(aux_input)
125 | av_logvar = self.av_fc2(logvar, a_logvar)
126 | dist = Independent(Normal(loc=av_mu, scale=torch.exp(av_logvar)), 1)
127 |
128 | return av_mu, av_logvar, dist
129 |
130 |
131 | class cavnet(nn.Module):
132 | def __init__(self, output_stride=16, lat_channel=16, lat_dim=32):
133 | super(cavnet, self).__init__()
134 | # for visual backbone
135 | # encoder
136 | self.backbone_enc_vit = RCRNet_vit(n_classes=1, output_stride=output_stride).pretrained
137 | self.backbone_enc_aspp = RCRNet_vit(n_classes=1, output_stride=output_stride).aspp
138 |
139 | # decoder
140 | self.backbone_dec_prior = RCRNet_decoder()
141 | self.backbone_dec_post = RCRNet_decoder()
142 |
143 | # for temporal module
144 | self.convgru_forward = ConvGRUCell(256, 256, 3)
145 | self.convgru_backward = ConvGRUCell(256, 256, 3)
146 | self.bidirection_conv = nn.Conv2d(512, 256, 3, 1, 1)
147 | self.non_local_block = NONLocalBlock3D(256, sub_sample=False, bn_layer=False)
148 | self.non_local_block2 = NONLocalBlock3D(256, sub_sample=False, bn_layer=False)
149 |
150 | # for CVAE
151 | self.enc_mm_x = InferenceModel_mm_x(3, lat_channel, lat_dim)
152 | self.enc_mm_xy = InferenceModel_mm_xy(4, lat_channel, lat_dim)
153 | self.spatial_axes = [2, 3]
154 | self.noise_conv_prior = nn.Conv2d(256 + lat_dim, 256, kernel_size=1, padding=0)
155 | self.noise_conv_post = nn.Conv2d(256 + lat_dim, 256, kernel_size=1, padding=0)
156 |
157 | # for audio
158 | # encoder
159 | self.soundnet = nn.Sequential( # 7 layers used
160 | nn.Conv2d(1, 16, (1, 64), (1, 2), (0, 32)), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d((1, 8), (1, 8)),
161 | nn.Conv2d(16, 32, (1, 32), (1, 2), (0, 16)), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d((1, 8), (1, 8)),
162 | nn.Conv2d(32, 64, (1, 16), (1, 2), (0, 8)), nn.BatchNorm2d(64), nn.ReLU(),
163 | nn.Conv2d(64, 128, (1, 8), (1, 2), (0, 4)), nn.BatchNorm2d(128), nn.ReLU(),
164 | nn.Conv2d(128, 256, (1, 4), (1, 2), (0, 2)), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d((1, 4), (1, 4)),
165 | nn.Conv2d(256, 512, (1, 4), (1, 2), (0, 2)), nn.BatchNorm2d(512), nn.ReLU(),
166 | nn.Conv2d(512, 1024, (1, 4), (1, 2), (0, 2)), nn.BatchNorm2d(1024), nn.ReLU(), nn.MaxPool2d((1, 2))
167 | )
168 |
169 | # load pre-trained models
170 | if self.training:
171 | self.initialize_pretrain() # load visual pretrain (static visual part pre-trained on DUTS-tr)
172 | self.initialize_soundnet() # load audio pretrain (soundnet)
173 |
174 | # auditory branch kept trainable here; set requires_grad = False to freeze it
175 | for param in self.soundnet.parameters(): param.requires_grad = True
176 |
177 | def initialize_pretrain(self):
178 | backbone_pretrain = torch.load(os.getcwd() + '/pretrain/static_visual_pretrain.pth')
179 |
180 | all_params_enc_vit = {}
181 | for k, v in self.backbone_enc_vit.state_dict().items():
182 | if 'pretrained.' + k in backbone_pretrain.keys():
183 | v = backbone_pretrain['pretrained.' + k]
184 | all_params_enc_vit[k] = v
185 | self.backbone_enc_vit.load_state_dict(all_params_enc_vit)
186 |
187 | all_params_enc_aspp = {}
188 | for k, v in self.backbone_enc_aspp.state_dict().items():
189 | if 'aspp.' + k in backbone_pretrain.keys():
190 | v = backbone_pretrain['aspp.' + k]
191 | all_params_enc_aspp[k] = v
192 | self.backbone_enc_aspp.load_state_dict(all_params_enc_aspp)
193 |
194 | def initialize_soundnet(self):
195 | audio_pretrain_weights = torch.load(os.getcwd() + '/pretrain/soundnet8.pth')
196 |
197 | all_params = {}
198 | for k, v in self.soundnet.state_dict().items():
199 | if 'module.soundnet8.' + k in audio_pretrain_weights.keys():
200 | v = audio_pretrain_weights['module.soundnet8.' + k]
201 | all_params[k] = v
202 | self.soundnet.load_state_dict(all_params)
203 |
204 | def reparametrize(self, mu, logvar):
205 | std = logvar.mul(0.5).exp_()
206 | eps = torch.randn_like(std)  # device-agnostic standard-normal noise
207 | # (replaces torch.cuda.FloatTensor(...).normal_() and the deprecated Variable wrapper)
208 |
209 | return eps.mul(std).add_(mu)
210 |
211 | def kl_divergence(self, posterior_latent_space, prior_latent_space):
212 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space)
213 |
214 | return kl_div
215 |
216 | def tile(self, a, dim, n_tile):
217 | init_dim = a.size(dim)
218 | repeat_idx = [1] * a.dim()
219 | repeat_idx[dim] = n_tile
220 | a = a.repeat(*(repeat_idx))
221 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(
222 | device)
223 | return torch.index_select(a, dim, order_index)
224 |
225 | def forward(self, seq, audio_clip, gt=None):
226 | # --------------------------------------------------------------------------------------------------------------
227 | # encoder for mono sound
228 | feats_audio = self.soundnet(audio_clip.unsqueeze(0))
229 | feats_audio = torch.mean(feats_audio, dim=-1).squeeze().unsqueeze(0)
230 |
231 | # --------------------------------------------------------------------------------------------------------------
232 | # encoder for audio-visual-temporal
233 | feats_seq = [forward_vit(self.backbone_enc_vit, frame) for frame in seq]
234 | feats_time = [self.backbone_enc_aspp(feats[-1]) for feats in feats_seq] # bottleneck features
235 | feats_time = torch.stack(feats_time, dim=2)
236 | feats_time = self.non_local_block(feats_time)
237 |
238 | # Deep Bidirectional ConvGRU
239 | frame = seq[0]
240 | feat = feats_time[:, :, 0, :, :]
241 | feats_forward = []
242 | # forward
243 | for i in range(len(seq)):
244 | feat = self.convgru_forward(feats_time[:, :, i, :, :], feat)
245 | feats_forward.append(feat)
246 | # backward
247 | feat = feats_forward[-1]
248 | feats_backward = []
249 | for i in range(len(seq)):
250 | feat = self.convgru_backward(feats_forward[len(seq)-1-i], feat)
251 | feats_backward.append(feat)
252 |
253 | feats_backward = feats_backward[::-1]
254 | feats = []
255 | for i in range(len(seq)):
256 | feat = torch.tanh(self.bidirection_conv(torch.cat((feats_forward[i], feats_backward[i]), dim=1)))
257 | feats.append(feat)
258 | feats = torch.stack(feats, dim=2)
259 |
260 | feats = self.non_local_block2(feats) # spatial-temporal bottleneck features
261 |
262 | # --------------------------------------------------------------------------------------------------------------
263 | if gt is None:
264 | # model inference
265 | feats_prior = []
266 | for rr in range(len(seq)):
267 | mu_prior, logvar_prior, _ = self.enc_mm_x(seq[rr], feats_audio)
268 | z_prior = self.reparametrize(mu_prior, logvar_prior) # instantiate latent variable
269 | z_prior = torch.unsqueeze(z_prior, 2)
270 | z_prior = self.tile(z_prior, 2, feats[:, :, rr, :, :].shape[self.spatial_axes[0]])
271 | z_prior = torch.unsqueeze(z_prior, 3)
272 | z_prior = self.tile(z_prior, 3, feats[:, :, rr, :, :].shape[self.spatial_axes[1]])
273 | f_prior = torch.cat((feats[:, :, rr, :, :], z_prior), 1)
274 | feats_prior.append(self.noise_conv_prior(f_prior))
275 |
276 | preds_prior = []
277 | for i, frame in enumerate(seq):
278 | seg_prior = self.backbone_dec_prior(feats_seq[i][0], feats_seq[i][1], feats_seq[i][2], feats_prior[i],
279 | frame)
280 | preds_prior.append(seg_prior)
281 |
282 | return preds_prior
283 | else:
284 | # CVAE
285 | KLD, feats_prior, feats_post = [], [], []
286 | for rr in range(len(seq)):
287 | mu_prior, logvar_prior, dist_prior = self.enc_mm_x(seq[rr], feats_audio)
288 | z_prior = self.reparametrize(mu_prior, logvar_prior) # instantiate latent variable
289 | z_prior = torch.unsqueeze(z_prior, 2)
290 | z_prior = self.tile(z_prior, 2, feats[:, :, rr, :, :].shape[self.spatial_axes[0]])
291 | z_prior = torch.unsqueeze(z_prior, 3)
292 | z_prior = self.tile(z_prior, 3, feats[:, :, rr, :, :].shape[self.spatial_axes[1]])
293 | f_prior = torch.cat((feats[:, :, rr, :, :], z_prior), 1)
294 | feats_prior.append(self.noise_conv_prior(f_prior))
295 |
296 | mu_post, logvar_post, dist_post = self.enc_mm_xy(torch.cat((seq[rr], gt[rr]), 1), feats_audio)
297 | z_post = self.reparametrize(mu_post, logvar_post) # instantiate latent variable
298 | z_post = torch.unsqueeze(z_post, 2)
299 | z_post = self.tile(z_post, 2, feats[:, :, rr, :, :].shape[self.spatial_axes[0]])
300 | z_post = torch.unsqueeze(z_post, 3)
301 | z_post = self.tile(z_post, 3, feats[:, :, rr, :, :].shape[self.spatial_axes[1]])
302 | f_post = torch.cat((feats[:, :, rr, :, :], z_post), 1)
303 | feats_post.append(self.noise_conv_post(f_post))
304 |
305 | KLD.append(torch.mean(self.kl_divergence(dist_post, dist_prior)))
306 |
307 | preds_prior, preds_post = [], []
308 | for i, frame in enumerate(seq):
309 | seg_prior = self.backbone_dec_prior(feats_seq[i][0], feats_seq[i][1], feats_seq[i][2], feats_prior[i],
310 | frame)
311 | preds_prior.append(seg_prior)
312 | seg_post = self.backbone_dec_post(feats_seq[i][0], feats_seq[i][1], feats_seq[i][2], feats_post[i],
313 | frame)
314 | preds_post.append(seg_post)
315 |
316 | return preds_prior, preds_post, KLD
317 |
318 |
319 | class RCRNet_decoder(nn.Module):
320 | def __init__(self, n_classes=1):
321 | super(RCRNet_decoder, self).__init__()
322 | self.decoder = nn.Sequential(
323 | OrderedDict(
324 | [
325 | ("conv1", _ConvBatchNormReLU(128, 256, 3, 1, 1, 1)),
326 | ("conv2", nn.Conv2d(256, n_classes, kernel_size=1)),
327 | ]
328 | )
329 | )
330 | self.add_module("refinement1", _RefinementModule(768, 96, 256, 128, 2))
331 | self.add_module("refinement2", _RefinementModule(512, 96, 128, 128, 2))
332 | self.add_module("refinement3", _RefinementModule(256, 96, 128, 128, 2))
333 |
334 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
335 |
336 | if self.training: self.init_pretrain()
337 |
338 | def init_pretrain(self):
339 | backbone_pretrain = torch.load(os.getcwd() + '/pretrain/static_visual_pretrain.pth')
340 |
341 | all_params_dec = {}
342 | for k, v in self.decoder.state_dict().items():
343 | if 'decoder.' + k in backbone_pretrain.keys():
344 | v = backbone_pretrain['decoder.' + k]
345 | all_params_dec[k] = v
346 | self.decoder.load_state_dict(all_params_dec)
347 |
348 | all_params_ref1 = {}
349 | for k, v in self.refinement1.state_dict().items():
350 | if 'refinement1.' + k in backbone_pretrain.keys():
351 | v = backbone_pretrain['refinement1.' + k]
352 | all_params_ref1[k] = v
353 | self.refinement1.load_state_dict(all_params_ref1)
354 |
355 | all_params_ref2 = {}
356 | for k, v in self.refinement2.state_dict().items():
357 | if 'refinement2.' + k in backbone_pretrain.keys():
358 | v = backbone_pretrain['refinement2.' + k]
359 | all_params_ref2[k] = v
360 | self.refinement2.load_state_dict(all_params_ref2)
361 |
362 | all_params_ref3 = {}
363 | for k, v in self.refinement3.state_dict().items():
364 | if 'refinement3.' + k in backbone_pretrain.keys():
365 | v = backbone_pretrain['refinement3.' + k]
366 | all_params_ref3[k] = v
367 | self.refinement3.load_state_dict(all_params_ref3)
368 |
369 | def seg_conv(self, block1, block2, block3, block4, shape):
370 | '''
371 | Pixel-wise classifier
372 | '''
373 | block4 = self.upsample2(block4)
374 |
375 | bu1 = self.refinement1(block3, block4)
376 | bu1 = F.interpolate(bu1, size=block2.shape[2:], mode="bilinear", align_corners=False)
377 | bu2 = self.refinement2(block2, bu1)
378 | bu2 = F.interpolate(bu2, size=block1.shape[2:], mode="bilinear", align_corners=False)
379 | bu3 = self.refinement3(block1, bu2)
380 | bu3 = F.interpolate(bu3, size=shape, mode="bilinear", align_corners=False)
381 | seg = self.decoder(bu3)
382 |
383 | return seg
384 |
385 | def forward(self, block1, block2, block3, block4, x):
386 | seg = self.seg_conv(block1, block2, block3, block4, x.shape[2:])
387 |
388 | return seg
389 |
390 |
--------------------------------------------------------------------------------
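
A rough interface sketch for cavnet above. The 416x832 frame size follows from the 13x26 flattened bottleneck hard-coded in the inference models (five stride-2 convolutions); the audio length, the batch size of 1, and the availability of the checkpoints under ./pretrain/ are assumptions.

import torch
from models.CAVNet import cavnet, device

# constructing the model in training mode loads ./pretrain/static_visual_pretrain.pth
# and ./pretrain/soundnet8.pth, so those files are assumed to exist
model = cavnet().to(device)
model.eval()

seq = [torch.randn(1, 3, 416, 832, device=device) for _ in range(3)]  # 3-frame clip
audio = torch.randn(1, 1, 81920, device=device)  # mono waveform clip (length assumed)

with torch.no_grad():
    preds = model(seq, audio)        # gt=None -> prior branch only, one map per frame
print(len(preds), preds[0].shape)

# during training, passing gt=[...] (one (1, 1, 416, 832) mask per frame) additionally
# returns the posterior predictions and the per-frame KL terms of the CVAE
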
/src/models/VIT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | activations = {}
10 |
11 |
12 | def get_activation(name):
13 | def hook(model, input, output):
14 | activations[name] = output
15 |
16 | return hook
17 |
18 |
19 | attention = {}
20 |
21 |
22 | def get_attention(name):
23 | def hook(module, input, output):
24 | x = input[0]
25 | B, N, C = x.shape
26 | qkv = (
27 | module.qkv(x)
28 | .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29 | .permute(2, 0, 3, 1, 4)
30 | )
31 | q, k, v = (
32 | qkv[0],
33 | qkv[1],
34 | qkv[2],
35 | ) # make torchscript happy (cannot use tensor as tuple)
36 |
37 | attn = (q @ k.transpose(-2, -1)) * module.scale
38 |
39 | attn = attn.softmax(dim=-1) # [:,:,1,1:]
40 | attention[name] = attn
41 |
42 | return hook
43 |
44 |
45 | def get_mean_attention_map(attn, token, shape):
46 | attn = attn[:, :, token, 1:]
47 | attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48 | attn = torch.nn.functional.interpolate(
49 | attn, size=shape[2:], mode="bicubic", align_corners=False
50 | ).squeeze(0)
51 |
52 | all_attn = torch.mean(attn, 0)
53 |
54 | return all_attn
55 |
56 |
57 | class Slice(nn.Module):
58 | def __init__(self, start_index=1):
59 | super(Slice, self).__init__()
60 | self.start_index = start_index
61 |
62 | def forward(self, x):
63 | return x[:, self.start_index :]
64 |
65 |
66 | class AddReadout(nn.Module):
67 | def __init__(self, start_index=1):
68 | super(AddReadout, self).__init__()
69 | self.start_index = start_index
70 |
71 | def forward(self, x):
72 | if self.start_index == 2:
73 | readout = (x[:, 0] + x[:, 1]) / 2
74 | else:
75 | readout = x[:, 0]
76 | return x[:, self.start_index :] + readout.unsqueeze(1)
77 |
78 |
79 | class ProjectReadout(nn.Module):
80 | def __init__(self, in_features, start_index=1):
81 | super(ProjectReadout, self).__init__()
82 | self.start_index = start_index
83 |
84 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85 |
86 | def forward(self, x):
87 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88 | features = torch.cat((x[:, self.start_index :], readout), -1)
89 |
90 | return self.project(features)
91 |
92 |
93 | class Transpose(nn.Module):
94 | def __init__(self, dim0, dim1):
95 | super(Transpose, self).__init__()
96 | self.dim0 = dim0
97 | self.dim1 = dim1
98 |
99 | def forward(self, x):
100 | x = x.transpose(self.dim0, self.dim1)
101 | return x
102 |
103 |
104 | def forward_vit(pretrained, x):
105 | b, c, h, w = x.shape
106 |
107 | glob = pretrained.model.forward_flex(x)
108 |
109 | layer_1 = pretrained.activations["1"]
110 | layer_2 = pretrained.activations["2"]
111 | layer_3 = pretrained.activations["3"]
112 | layer_4 = pretrained.activations["4"]
113 |
114 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118 |
119 | unflatten = nn.Sequential(
120 | nn.Unflatten(
121 | 2,
122 | torch.Size(
123 | [
124 | h // pretrained.model.patch_size[1],
125 | w // pretrained.model.patch_size[0],
126 | ]
127 | ),
128 | )
129 | )
130 |
131 | if layer_1.ndim == 3:
132 | layer_1 = unflatten(layer_1)
133 | if layer_2.ndim == 3:
134 | layer_2 = unflatten(layer_2)
135 | if layer_3.ndim == 3:
136 | layer_3 = unflatten(layer_3)
137 | if layer_4.ndim == 3:
138 | layer_4 = unflatten(layer_4)
139 |
140 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144 |
145 | return layer_1, layer_2, layer_3, layer_4
146 |
147 |
148 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
149 | posemb_tok, posemb_grid = (
150 | posemb[:, : self.start_index],
151 | posemb[0, self.start_index :],
152 | )
153 |
154 | gs_old = int(math.sqrt(len(posemb_grid)))
155 |
156 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159 |
160 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161 |
162 | return posemb
163 |
164 |
165 | def forward_flex(self, x):
166 | b, c, h, w = x.shape
167 |
168 | pos_embed = self._resize_pos_embed(
169 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170 | )
171 |
172 | B = x.shape[0]
173 |
174 | if hasattr(self.patch_embed, "backbone"):
175 | x = self.patch_embed.backbone(x)
176 | if isinstance(x, (list, tuple)):
177 | x = x[-1] # last feature if backbone outputs list/tuple of features
178 |
179 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180 |
181 | #if hasattr(self, "dist_token"):
182 | # cls_tokens = self.cls_token.expand(
183 | # B, -1, -1
184 | # ) # stole cls_tokens impl from Phil Wang, thanks
185 | # dist_token = self.dist_token.expand(B, -1, -1)
186 | # x = torch.cat((cls_tokens, dist_token, x), dim=1)
187 | #else:
188 | cls_tokens = self.cls_token.expand(
189 | B, -1, -1
190 | ) # stole cls_tokens impl from Phil Wang, thanks
191 | x = torch.cat((cls_tokens, x), dim=1)
192 |
193 | x = x + pos_embed
194 | x = self.pos_drop(x)
195 |
196 | for blk in self.blocks:
197 | x = blk(x)
198 |
199 | x = self.norm(x)
200 |
201 | return x
202 |
203 |
204 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
205 | if use_readout == "ignore":
206 | readout_oper = [Slice(start_index)] * len(features)
207 | elif use_readout == "add":
208 | readout_oper = [AddReadout(start_index)] * len(features)
209 | elif use_readout == "project":
210 | readout_oper = [
211 | ProjectReadout(vit_features, start_index) for out_feat in features
212 | ]
213 | else:
214 | assert (
215 | False
216 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217 |
218 | return readout_oper
219 |
220 |
221 | def _make_vit_b16_backbone(
222 | model,
223 | features=[96, 192, 384, 768],
224 | size=[384, 384],
225 | hooks=[2, 5, 8, 11],
226 | vit_features=768,
227 | use_readout="ignore",
228 | start_index=1,
229 | enable_attention_hooks=False,
230 | ):
231 | pretrained = nn.Module()
232 |
233 | pretrained.model = model
234 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238 |
239 | pretrained.activations = activations
240 |
241 | if enable_attention_hooks:
242 | pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243 | get_attention("attn_1")
244 | )
245 | pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246 | get_attention("attn_2")
247 | )
248 | pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249 | get_attention("attn_3")
250 | )
251 | pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252 | get_attention("attn_4")
253 | )
254 | pretrained.attention = attention
255 |
256 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257 |
258 | # 32, 48, 136, 384
259 | pretrained.act_postprocess1 = nn.Sequential(
260 | readout_oper[0],
261 | Transpose(1, 2),
262 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263 | nn.Conv2d(
264 | in_channels=vit_features,
265 | out_channels=features[0],
266 | kernel_size=1,
267 | stride=1,
268 | padding=0,
269 | ),
270 | nn.ConvTranspose2d(
271 | in_channels=features[0],
272 | out_channels=features[0],
273 | kernel_size=4,
274 | stride=4,
275 | padding=0,
276 | bias=True,
277 | dilation=1,
278 | groups=1,
279 | ),
280 | )
281 |
282 | pretrained.act_postprocess2 = nn.Sequential(
283 | readout_oper[1],
284 | Transpose(1, 2),
285 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286 | nn.Conv2d(
287 | in_channels=vit_features,
288 | out_channels=features[1],
289 | kernel_size=1,
290 | stride=1,
291 | padding=0,
292 | ),
293 | nn.ConvTranspose2d(
294 | in_channels=features[1],
295 | out_channels=features[1],
296 | kernel_size=2,
297 | stride=2,
298 | padding=0,
299 | bias=True,
300 | dilation=1,
301 | groups=1,
302 | ),
303 | )
304 |
305 | pretrained.act_postprocess3 = nn.Sequential(
306 | readout_oper[2],
307 | Transpose(1, 2),
308 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309 | nn.Conv2d(
310 | in_channels=vit_features,
311 | out_channels=features[2],
312 | kernel_size=1,
313 | stride=1,
314 | padding=0,
315 | ),
316 | )
317 |
318 | pretrained.act_postprocess4 = nn.Sequential(
319 | readout_oper[3],
320 | Transpose(1, 2),
321 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322 | nn.Conv2d(
323 | in_channels=vit_features,
324 | out_channels=features[3],
325 | kernel_size=1,
326 | stride=1,
327 | padding=0,
328 | ),
329 | nn.Conv2d(
330 | in_channels=features[3],
331 | out_channels=features[3],
332 | kernel_size=3,
333 | stride=2,
334 | padding=1,
335 | ),
336 | )
337 |
338 | pretrained.model.start_index = start_index
339 | pretrained.model.patch_size = [16, 16]
340 |
341 | # We inject this function into the VisionTransformer instances so that
342 | # we can use it with interpolated position embeddings without modifying the library source.
343 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344 | pretrained.model._resize_pos_embed = types.MethodType(
345 | _resize_pos_embed, pretrained.model
346 | )
347 |
348 | return pretrained
349 |
350 |
351 | def _make_vit_b_rn50_backbone(
352 | model,
353 | features=[256, 512, 768, 768],
354 | size=[384, 384],
355 | hooks=[0, 1, 8, 11],
356 | vit_features=768,
357 | use_vit_only=False,
358 | use_readout="ignore",
359 | start_index=1,
360 | enable_attention_hooks=False,
361 | ):
362 | pretrained = nn.Module()
363 |
364 | pretrained.model = model
365 |
366 | if use_vit_only == True:
367 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
368 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
369 | else:
370 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
371 | get_activation("1")
372 | )
373 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
374 | get_activation("2")
375 | )
376 |
377 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
378 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
379 |
380 | if enable_attention_hooks:
381 | pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
382 | pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
383 | pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
384 | pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
385 | pretrained.attention = attention
386 |
387 | pretrained.activations = activations
388 |
389 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
390 |
391 | if use_vit_only == True:
392 | pretrained.act_postprocess1 = nn.Sequential(
393 | readout_oper[0],
394 | Transpose(1, 2),
395 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
396 | nn.Conv2d(
397 | in_channels=vit_features,
398 | out_channels=features[0],
399 | kernel_size=1,
400 | stride=1,
401 | padding=0,
402 | ),
403 | nn.ConvTranspose2d(
404 | in_channels=features[0],
405 | out_channels=features[0],
406 | kernel_size=4,
407 | stride=4,
408 | padding=0,
409 | bias=True,
410 | dilation=1,
411 | groups=1,
412 | ),
413 | )
414 |
415 | pretrained.act_postprocess2 = nn.Sequential(
416 | readout_oper[1],
417 | Transpose(1, 2),
418 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
419 | nn.Conv2d(
420 | in_channels=vit_features,
421 | out_channels=features[1],
422 | kernel_size=1,
423 | stride=1,
424 | padding=0,
425 | ),
426 | nn.ConvTranspose2d(
427 | in_channels=features[1],
428 | out_channels=features[1],
429 | kernel_size=2,
430 | stride=2,
431 | padding=0,
432 | bias=True,
433 | dilation=1,
434 | groups=1,
435 | ),
436 | )
437 | else:
438 | pretrained.act_postprocess1 = nn.Sequential(
439 | nn.Identity(), nn.Identity(), nn.Identity()
440 | )
441 | pretrained.act_postprocess2 = nn.Sequential(
442 | nn.Identity(), nn.Identity(), nn.Identity()
443 | )
444 |
445 | pretrained.act_postprocess3 = nn.Sequential(
446 | readout_oper[2],
447 | Transpose(1, 2),
448 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
449 | nn.Conv2d(
450 | in_channels=vit_features,
451 | out_channels=features[2],
452 | kernel_size=1,
453 | stride=1,
454 | padding=0,
455 | ),
456 | )
457 |
458 | pretrained.act_postprocess4 = nn.Sequential(
459 | readout_oper[3],
460 | Transpose(1, 2),
461 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
462 | nn.Conv2d(
463 | in_channels=vit_features,
464 | out_channels=features[3],
465 | kernel_size=1,
466 | stride=1,
467 | padding=0,
468 | ),
469 | nn.Conv2d(
470 | in_channels=features[3],
471 | out_channels=features[3],
472 | kernel_size=3,
473 | stride=2,
474 | padding=1,
475 | ),
476 | )
477 |
478 | pretrained.model.start_index = start_index
479 | pretrained.model.patch_size = [16, 16]
480 |
481 | # We inject this function into the VisionTransformer instances so that
482 | # we can use it with interpolated position embeddings without modifying the library source.
483 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
484 |
485 | # We inject this function into the VisionTransformer instances so that
486 | # we can use it with interpolated position embeddings without modifying the library source.
487 | pretrained.model._resize_pos_embed = types.MethodType(
488 | _resize_pos_embed, pretrained.model
489 | )
490 |
491 | return pretrained
492 |
493 |
494 | def _make_pretrained_vitb_rn50_384(
495 | pretrained,
496 | use_readout="ignore",
497 | hooks=None,
498 | use_vit_only=False,
499 | enable_attention_hooks=False,
500 | ):
501 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
502 |
503 | hooks = [0, 1, 8, 11] if hooks == None else hooks
504 | return _make_vit_b_rn50_backbone(
505 | model,
506 | features=[256, 512, 768, 768],
507 | size=[384, 384],
508 | hooks=hooks,
509 | use_vit_only=use_vit_only,
510 | use_readout=use_readout,
511 | enable_attention_hooks=enable_attention_hooks,
512 | )
513 |
514 |
515 | def _make_pretrained_vitl16_384(
516 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
517 | ):
518 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
519 |
520 | hooks = [5, 11, 17, 23] if hooks == None else hooks
521 | return _make_vit_b16_backbone(
522 | model,
523 | features=[256, 512, 1024, 1024],
524 | hooks=hooks,
525 | vit_features=1024,
526 | use_readout=use_readout,
527 | enable_attention_hooks=enable_attention_hooks,
528 | )
529 |
530 |
531 | def _make_pretrained_vitb16_384(
532 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
533 | ):
534 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
535 |
536 | hooks = [2, 5, 8, 11] if hooks == None else hooks
537 | return _make_vit_b16_backbone(
538 | model,
539 | features=[96, 192, 384, 768],
540 | hooks=hooks,
541 | use_readout=use_readout,
542 | enable_attention_hooks=enable_attention_hooks,
543 | )
544 |
545 |
546 | def _make_pretrained_deitb16_384(
547 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
548 | ):
549 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
550 |
551 | hooks = [2, 5, 8, 11] if hooks == None else hooks
552 | return _make_vit_b16_backbone(
553 | model,
554 | features=[96, 192, 384, 768],
555 | hooks=hooks,
556 | use_readout=use_readout,
557 | enable_attention_hooks=enable_attention_hooks,
558 | )
559 |
560 |
561 | def _make_pretrained_deitb16_distil_384(
562 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
563 | ):
564 | model = timm.create_model(
565 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
566 | )
567 |
568 | hooks = [2, 5, 8, 11] if hooks == None else hooks
569 | return _make_vit_b16_backbone(
570 | model,
571 | features=[96, 192, 384, 768],
572 | hooks=hooks,
573 | use_readout=use_readout,
574 | start_index=2,
575 | enable_attention_hooks=enable_attention_hooks,
576 | )
577 |
--------------------------------------------------------------------------------
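
A short sketch of the attention hooks defined above: with enable_attention_hooks=True the per-block attention maps are cached in the module-level attention dict, and get_mean_attention_map reduces the CLS-token attention of one hooked block to a single spatial map. The 384x384 input and the plain ViT-B/16 backbone are illustrative choices.

import torch
from models.VIT import (
    _make_pretrained_vitb16_384,
    forward_vit,
    get_mean_attention_map,
    attention,
)

pretrained = _make_pretrained_vitb16_384(
    True, use_readout="ignore", enable_attention_hooks=True
)  # downloads timm weights on first use

x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    forward_vit(pretrained, x)       # populates attention["attn_1"] .. attention["attn_4"]

cls_attn = get_mean_attention_map(attention["attn_4"], 0, x.shape)
print(cls_attn.shape)                # (384, 384): mean CLS-token attention at the last hooked block
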
/src/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import torchaudio
7 | import cv2
8 | import torch
9 |
10 | clip_length = 3
11 |
12 | def split_list(a, n):
13 | k, m = divmod(len(a), n)
14 | return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
15 |
16 |
17 | # dataset for training
18 | class SalObjDataset(data.Dataset):
19 | def __init__(self, root, trainsize):
20 | self.trainsize = trainsize
21 |
22 | # get visual and audio clips (each spanning three consecutive key frames)
23 | with open(root, 'r') as f:
24 | self.seqs = [x.strip() for x in f.readlines()]
25 |
26 | video_clips, gt_clips = [], []
27 | audio_clips_ch1, audio_clips_ch2, audio_clips_ch3, audio_clips_ch4 = [], [], [], [] # ambisonics
28 | for seq_idx in self.seqs:
29 | # get visual clips of each video
30 | frm_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/frame_key/', seq_idx))
31 | frm_list = sorted(frm_list)
32 | gt_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/DATA/train/', seq_idx))
33 | gt_list = sorted(gt_list)
34 | for idx in range(len(frm_list)):
35 | frm_list[idx] = '/home/yzhang1/PythonProjects/AV360/frame_key/' + seq_idx + '/' + frm_list[idx]
36 | gt_list[idx] = '/home/yzhang1/PythonProjects/AV360/DATA/train/' + seq_idx + '/' + gt_list[idx]
37 | frm_list_split = list(split_list(frm_list, int(len(frm_list) / clip_length)))
38 | gt_list_split = list(split_list(gt_list, int(len(gt_list) / clip_length)))
39 | video_clips.append(frm_list_split)
40 | gt_clips.append(gt_list_split)
41 |
42 | # get audio clips of each video
43 | audio_pth = '/home/yzhang1/PythonProjects/AV360/ambisonics_trimmed/' + seq_idx + '.wav'
44 | audio_ori = torchaudio.load(audio_pth)[0]
45 | audio_ori_ch1 = audio_ori[0]
46 | audio_ori_ch2 = audio_ori[1]
47 | audio_ori_ch3 = audio_ori[2]
48 | audio_ori_ch4 = audio_ori[3]
49 | audio_split_size = int(len(audio_ori[1])/(int(len(frm_list) / clip_length)))
50 | audio_ch1_split = torch.split(tensor=audio_ori_ch1, split_size_or_sections=audio_split_size)
51 | audio_ch1_split = list(audio_ch1_split)
52 | audio_ch1_split = audio_ch1_split[:int(len(frm_list) / clip_length)]
53 | audio_clips_ch1.append(audio_ch1_split)
54 | audio_ch2_split = torch.split(tensor=audio_ori_ch2, split_size_or_sections=audio_split_size)
55 | audio_ch2_split = list(audio_ch2_split)
56 | audio_ch2_split = audio_ch2_split[:int(len(frm_list) / clip_length)]
57 | audio_clips_ch2.append(audio_ch2_split)
58 | audio_ch3_split = torch.split(tensor=audio_ori_ch3, split_size_or_sections=audio_split_size)
59 | audio_ch3_split = list(audio_ch3_split)
60 | audio_ch3_split = audio_ch3_split[:int(len(frm_list) / clip_length)]
61 | audio_clips_ch3.append(audio_ch3_split)
62 | audio_ch4_split = torch.split(tensor=audio_ori_ch4, split_size_or_sections=audio_split_size)
63 | audio_ch4_split = list(audio_ch4_split)
64 | audio_ch4_split = audio_ch4_split[:int(len(frm_list) / clip_length)]
65 | audio_clips_ch4.append(audio_ch4_split)
66 |
67 | self.video_clips_flatten = [item for sublist in video_clips for item in sublist]
68 | self.gt_clips_flatten = [item for sublist in gt_clips for item in sublist]
69 |
70 | self.audio_ch1_clips_flatten = [item for sublist in audio_clips_ch1 for item in sublist]
71 | self.audio_ch2_clips_flatten = [item for sublist in audio_clips_ch2 for item in sublist]
72 | self.audio_ch3_clips_flatten = [item for sublist in audio_clips_ch3 for item in sublist]
73 | self.audio_ch4_clips_flatten = [item for sublist in audio_clips_ch4 for item in sublist]
74 |
75 | # tools for data transform
76 | self.img_transform = transforms.Compose([
77 | transforms.Resize((self.trainsize, self.trainsize * 2)),
78 | transforms.ToTensor(),
79 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
80 | self.gt_transform = transforms.Compose([
81 | transforms.Resize((self.trainsize, self.trainsize * 2)),
82 | transforms.ToTensor()])
83 |
84 | # dataset size
85 | self.size = len(self.video_clips_flatten)
86 |
87 | def __getitem__(self, index):
88 | seq_name = self.video_clips_flatten[index][0].split('/')[-2]
89 | imgs, gts, audios = [], [], []
90 | for idx in range(clip_length):  # each clip defaults to 3 consecutive frames
91 | curr_img = self.rgb_loader(self.video_clips_flatten[index][idx])
92 | imgs.append(self.img_transform(curr_img))
93 |
94 | curr_gt = self.binary_loader(self.gt_clips_flatten[index][idx])
95 | gts.append(self.gt_transform(curr_gt))
96 |
97 | # collect an audio clip spanning roughly 15 frames (five times the visual clip length)
98 | if index == 0: a_index = [0, 1, 2, 3, 4]
99 | elif index == 1: a_index = [0, 1, 2, 3, 4]
100 | elif index == self.size - 1:
101 | a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1]
102 | elif index == self.size - 2:
103 | a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1]
104 | else:
105 | if self.video_clips_flatten[index - 2][0].split('/')[-2] != seq_name:
106 | if self.video_clips_flatten[index - 1][0].split('/')[-2] != seq_name:
107 | a_index = [index, index + 1, index + 2, index + 3, index + 4]
108 | else:
109 | a_index = [index - 1, index, index + 1, index + 2, index + 3]
110 | elif self.video_clips_flatten[index + 2][0].split('/')[-2] != seq_name:
111 | if self.video_clips_flatten[index + 1][0].split('/')[-2] != seq_name:
112 | a_index = [index - 4, index - 3, index - 2, index - 1, index]
113 | else:
114 | a_index = [index - 3, index - 2, index - 1, index, index + 1]
115 | else:
116 | a_index = [index - 2, index - 1, index, index + 1, index + 2]
117 |
118 | a_ch1, a_ch2, a_ch3, a_ch4 = [], [], [], []
119 | for aa in range(5):
120 | a_ch1.append(self.audio_ch1_clips_flatten[a_index[aa]])
121 | a_ch2.append(self.audio_ch2_clips_flatten[a_index[aa]])
122 | a_ch3.append(self.audio_ch3_clips_flatten[a_index[aa]])
123 | a_ch4.append(self.audio_ch4_clips_flatten[a_index[aa]])
124 | audios.append([torch.cat(a_ch1), torch.cat(a_ch2), torch.cat(a_ch3), torch.cat(a_ch4)])
125 | audios = torch.stack(audios[0], dim=0)
126 | audios = torch.mean(audios, dim=0).unsqueeze(0) # ambisonics to mono
127 |
128 | return imgs, gts, audios, seq_name
129 |
130 | def rgb_loader(self, path):
131 | with open(path, 'rb') as f:
132 | img = Image.open(f)
133 | return img.convert('RGB')
134 |
135 | def binary_loader(self, path):
136 | with open(path, 'rb') as f:
137 | img = Image.open(f)
138 | return img.convert('L')
139 |
140 | def __len__(self):
141 | return self.size
142 |
143 | # dataloader for training
144 | def get_loader(root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
145 |
146 | dataset = SalObjDataset(root, trainsize)
147 | data_loader = data.DataLoader(dataset=dataset,
148 | batch_size=batchsize,
149 | shuffle=shuffle,
150 | num_workers=num_workers,
151 | pin_memory=pin_memory)
152 | return data_loader
153 |
154 |
155 | # test dataset and loader
156 | class test_dataset:
157 | def __init__(self, root, testsize):
158 | self.testsize = testsize
159 |
160 | with open(root, 'r') as f:
161 | self.seqs = [x.strip() for x in f.readlines()]
162 |
163 | video_clips, gt_clips = [], []
164 | audio_clips_ch1, audio_clips_ch2, audio_clips_ch3, audio_clips_ch4 = [], [], [], []
165 | for seq_idx in self.seqs:
166 | # get visual clips of each video
167 | frm_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/frame_key/', seq_idx))
168 | frm_list = sorted(frm_list)
169 | gt_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/DATA/test/', seq_idx))
170 | gt_list = sorted(gt_list)
171 | for idx in range(len(frm_list)):
172 | frm_list[idx] = '/home/yzhang1/PythonProjects/AV360/frame_key/' + seq_idx + '/' + frm_list[idx]
173 | gt_list[idx] = '/home/yzhang1/PythonProjects/AV360/DATA/test/' + seq_idx + '/' + gt_list[idx]
174 | frm_list_split = list(split_list(frm_list, int(len(frm_list) / clip_length)))
175 | gt_list_split = list(split_list(gt_list, int(len(gt_list) / clip_length)))
176 | video_clips.append(frm_list_split)
177 | gt_clips.append(gt_list_split)
178 |
179 | # get audio clips of each video
180 | audio_pth = '/home/yzhang1/PythonProjects/AV360/ambisonics_trimmed/' + seq_idx + '.wav'
181 | audio_ori = torchaudio.load(audio_pth)[0]
182 | audio_ori_ch1 = audio_ori[0]
183 | audio_ori_ch2 = audio_ori[1]
184 | audio_ori_ch3 = audio_ori[2]
185 | audio_ori_ch4 = audio_ori[3]
186 | audio_split_size = int(len(audio_ori[1]) / (int(len(frm_list) / clip_length)))
187 | audio_ch1_split = torch.split(tensor=audio_ori_ch1, split_size_or_sections=audio_split_size)
188 | audio_ch1_split = list(audio_ch1_split)
189 | audio_ch1_split = audio_ch1_split[:int(len(frm_list) / clip_length)]
190 | audio_clips_ch1.append(audio_ch1_split)
191 | audio_ch2_split = torch.split(tensor=audio_ori_ch2, split_size_or_sections=audio_split_size)
192 | audio_ch2_split = list(audio_ch2_split)
193 | audio_ch2_split = audio_ch2_split[:int(len(frm_list) / clip_length)]
194 | audio_clips_ch2.append(audio_ch2_split)
195 | audio_ch3_split = torch.split(tensor=audio_ori_ch3, split_size_or_sections=audio_split_size)
196 | audio_ch3_split = list(audio_ch3_split)
197 | audio_ch3_split = audio_ch3_split[:int(len(frm_list) / clip_length)]
198 | audio_clips_ch3.append(audio_ch3_split)
199 | audio_ch4_split = torch.split(tensor=audio_ori_ch4, split_size_or_sections=audio_split_size)
200 | audio_ch4_split = list(audio_ch4_split)
201 | audio_ch4_split = audio_ch4_split[:int(len(frm_list) / clip_length)]
202 | audio_clips_ch4.append(audio_ch4_split)
203 |
204 | self.video_clips_flatten = [item for sublist in video_clips for item in sublist]
205 | self.gt_clips_flatten = [item for sublist in gt_clips for item in sublist]
206 |
207 | self.audio_ch1_clips_flatten = [item for sublist in audio_clips_ch1 for item in sublist]
208 | self.audio_ch2_clips_flatten = [item for sublist in audio_clips_ch2 for item in sublist]
209 | self.audio_ch3_clips_flatten = [item for sublist in audio_clips_ch3 for item in sublist]
210 | self.audio_ch4_clips_flatten = [item for sublist in audio_clips_ch4 for item in sublist]
211 |
212 | # data transforms: self.transform resizes frames to testsize x 2*testsize, while self.transform_er_cube keeps a fixed 640 x 1280 equirectangular input
213 | self.transform = transforms.Compose([
214 | transforms.Resize((self.testsize, self.testsize * 2)),
215 | transforms.ToTensor(),
216 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
217 | self.transform_er_cube = transforms.Compose([
218 | transforms.Resize((640, 1280)),
219 | transforms.ToTensor(),
220 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
221 |
222 | self.size = len(self.video_clips_flatten)
223 | self.index = 0
224 |
225 | def load_data(self):
226 | seq_name = self.video_clips_flatten[self.index][0].split('/')[-2]
227 | imgs, ER_imgs, gts, audios, frm_names = [], [], [], [], []
228 | for idx in range(clip_length):  # each clip holds 3 consecutive frames by default
229 | curr_img = self.rgb_loader(self.video_clips_flatten[self.index][idx])
230 | ER_imgs.append(self.transform_er_cube(curr_img).unsqueeze(0))
231 | imgs.append(self.transform(curr_img).unsqueeze(0))
232 | curr_gt = self.binary_loader(self.gt_clips_flatten[self.index][idx])
233 | curr_gt = curr_gt.resize((self.testsize * 2, self.testsize))
234 | gts.append(curr_gt)
235 | frm_names.append(self.video_clips_flatten[self.index][idx].split('/')[-1])
236 | audios.append([self.audio_ch1_clips_flatten[self.index], self.audio_ch2_clips_flatten[self.index],
237 | self.audio_ch3_clips_flatten[self.index], self.audio_ch4_clips_flatten[self.index]])
238 | audios = torch.stack(audios[0], dim=0).unsqueeze(0)  # keep all four ambisonic channels of the current clip: (1, 4, samples)
239 |
240 | self.index += 1
241 | self.index = self.index % self.size
242 |
243 | return imgs, ER_imgs, gts, audios, seq_name, frm_names
244 |
245 | def rgb_loader(self, path):
246 | with open(path, 'rb') as f:
247 | img = Image.open(f)
248 | return img.convert('RGB')
249 |
250 | def binary_loader(self, path):
251 | with open(path, 'rb') as f:
252 | img = Image.open(f)
253 | return img.convert('L')
254 |
255 | def __len__(self):
256 | return self.size
257 |
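# Example usage (sketch; the list path and testsize are illustrative):
#   loader = test_dataset(root='./data/PAVS10K_seqs_test.txt', testsize=256)
#   for _ in range(len(loader)):
#       imgs, ER_imgs, gts, audios, seq_name, frm_names = loader.load_data()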
258 |
259 | class dataset_inference:
260 | def __init__(self, root, testsize):
261 | self.testsize = testsize
262 |
263 | with open(root, 'r') as f:
264 | self.seqs = [x.strip() for x in f.readlines()]
265 |
266 | video_clips, gt_clips = [], []
267 | audio_clips_ch1, audio_clips_ch2, audio_clips_ch3, audio_clips_ch4 = [], [], [], []
268 | for seq_idx in self.seqs:
269 | # get visual clips of each video
270 | frm_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/frame_key/', seq_idx))
271 | frm_list = sorted(frm_list)
272 | gt_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/DATA/test/', seq_idx))
273 | gt_list = sorted(gt_list)
274 | for idx in range(len(frm_list)):
275 | frm_list[idx] = '/home/yzhang1/PythonProjects/AV360/frame_key/' + seq_idx + '/' + frm_list[idx]
276 | gt_list[idx] = '/home/yzhang1/PythonProjects/AV360/DATA/test/' + seq_idx + '/' + gt_list[idx]
277 | frm_list_split = list(split_list(frm_list, int(len(frm_list) / clip_length)))
278 | gt_list_split = list(split_list(gt_list, int(len(gt_list) / clip_length)))
279 | video_clips.append(frm_list_split)
280 | gt_clips.append(gt_list_split)
281 |
282 | # get audio clips of each video
283 | audio_pth = '/home/yzhang1/PythonProjects/AV360/ambisonics_trimmed/' + seq_idx + '.wav'
284 | audio_ori = torchaudio.load(audio_pth)[0]
285 | audio_ori_ch1 = audio_ori[0]
286 | audio_ori_ch2 = audio_ori[1]
287 | audio_ori_ch3 = audio_ori[2]
288 | audio_ori_ch4 = audio_ori[3]
289 | audio_split_size = int(len(audio_ori[1]) / (int(len(frm_list) / clip_length)))  # samples per clip; all four channels share the same length
290 | audio_ch1_split = torch.split(tensor=audio_ori_ch1, split_size_or_sections=audio_split_size)
291 | audio_ch1_split = list(audio_ch1_split)
292 | audio_ch1_split = audio_ch1_split[:int(len(frm_list) / clip_length)]
293 | audio_clips_ch1.append(audio_ch1_split)
294 | audio_ch2_split = torch.split(tensor=audio_ori_ch2, split_size_or_sections=audio_split_size)
295 | audio_ch2_split = list(audio_ch2_split)
296 | audio_ch2_split = audio_ch2_split[:int(len(frm_list) / clip_length)]
297 | audio_clips_ch2.append(audio_ch2_split)
298 | audio_ch3_split = torch.split(tensor=audio_ori_ch3, split_size_or_sections=audio_split_size)
299 | audio_ch3_split = list(audio_ch3_split)
300 | audio_ch3_split = audio_ch3_split[:int(len(frm_list) / clip_length)]
301 | audio_clips_ch3.append(audio_ch3_split)
302 | audio_ch4_split = torch.split(tensor=audio_ori_ch4, split_size_or_sections=audio_split_size)
303 | audio_ch4_split = list(audio_ch4_split)
304 | audio_ch4_split = audio_ch4_split[:int(len(frm_list) / clip_length)]
305 | audio_clips_ch4.append(audio_ch4_split)
306 |
307 | self.video_clips_flatten = [item for sublist in video_clips for item in sublist]
308 | self.gt_clips_flatten = [item for sublist in gt_clips for item in sublist]
309 |
310 | self.audio_ch1_clips_flatten = [item for sublist in audio_clips_ch1 for item in sublist]
311 | self.audio_ch2_clips_flatten = [item for sublist in audio_clips_ch2 for item in sublist]
312 | self.audio_ch3_clips_flatten = [item for sublist in audio_clips_ch3 for item in sublist]
313 | self.audio_ch4_clips_flatten = [item for sublist in audio_clips_ch4 for item in sublist]
314 |
315 | # data transforms: frames are resized to testsize x 2*testsize and normalized with ImageNet statistics
316 | self.transform = transforms.Compose([
317 | transforms.Resize((self.testsize, self.testsize * 2)),
318 | transforms.ToTensor(),
319 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
320 |
321 | self.size = len(self.video_clips_flatten)
322 | self.index = 0
323 |
324 | def load_data(self):
325 | seq_name = self.video_clips_flatten[self.index][0].split('/')[-2]
326 | imgs, gts, audios, frm_names = [], [], [], []
327 | for idx in range(clip_length):  # each clip holds 3 consecutive frames by default
328 | curr_img = self.rgb_loader(self.video_clips_flatten[self.index][idx])
329 | imgs.append(self.transform(curr_img).unsqueeze(0))
330 |
331 | curr_gt = self.binary_loader(self.gt_clips_flatten[self.index][idx])
332 | gts.append(curr_gt)
333 |
334 | frm_names.append(self.video_clips_flatten[self.index][idx].split('/')[-1])
335 |
336 | # collect the audio for this clip: a window of five clips (15 frames at the default clip length of 3)
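# The window is centred on the current clip, clamped at the ends of the flattened clip list,
# and shifted when a neighbouring clip belongs to a different sequence, so that the window
# stays inside the current sequence where possible. (It is recomputed for every frame of the
# loop above; only audios[0] is used after the loop.)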
337 | if self.index == 0:
338 | a_index = [0, 1, 2, 3, 4]
339 | elif self.index == 1:
340 | a_index = [0, 1, 2, 3, 4]
341 | elif self.index == self.size - 1:
342 | a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1]
343 | elif self.index == self.size - 2:
344 | a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1]
345 | else:
346 | if self.video_clips_flatten[self.index - 2][0].split('/')[-2] != seq_name:
347 | if self.video_clips_flatten[self.index - 1][0].split('/')[-2] != seq_name:
348 | a_index = [self.index, self.index + 1, self.index + 2, self.index + 3, self.index + 4]
349 | else:
350 | a_index = [self.index - 1, self.index, self.index + 1, self.index + 2, self.index + 3]
351 | elif self.video_clips_flatten[self.index + 2][0].split('/')[-2] != seq_name:
352 | if self.video_clips_flatten[self.index + 1][0].split('/')[-2] != seq_name:
353 | a_index = [self.index - 4, self.index - 3, self.index - 2, self.index - 1, self.index]
354 | else:
355 | a_index = [self.index - 3, self.index - 2, self.index - 1, self.index, self.index + 1]
356 | else:
357 | a_index = [self.index - 2, self.index - 1, self.index, self.index + 1, self.index + 2]
358 |
359 | a_ch1, a_ch2, a_ch3, a_ch4 = [], [], [], []
360 | for aa in range(5):
361 | a_ch1.append(self.audio_ch1_clips_flatten[a_index[aa]])
362 | a_ch2.append(self.audio_ch2_clips_flatten[a_index[aa]])
363 | a_ch3.append(self.audio_ch3_clips_flatten[a_index[aa]])
364 | a_ch4.append(self.audio_ch4_clips_flatten[a_index[aa]])
365 | audios.append([torch.cat(a_ch1), torch.cat(a_ch2), torch.cat(a_ch3), torch.cat(a_ch4)])
366 | audios = torch.stack(audios[0], dim=0)  # (4, samples): the four first-order ambisonic channels of the window
367 | audios = torch.mean(audios, dim=0).unsqueeze(0).unsqueeze(0)  # ambisonics to mono, shaped (1, 1, samples)
368 |
369 | self.index += 1
370 | self.index = self.index % self.size
371 |
372 | return imgs, gts, audios, seq_name, frm_names
373 |
374 | def rgb_loader(self, path):
375 | with open(path, 'rb') as f:
376 | img = Image.open(f)
377 | return img.convert('RGB')
378 |
379 | def binary_loader(self, path):
380 | with open(path, 'rb') as f:
381 | img = Image.open(f)
382 | return img.convert('L')
383 |
384 | def __len__(self):
385 | return self.size
386 |
--------------------------------------------------------------------------------