├── figures ├── fig_teaser.jpg ├── fig_dataset_examples.jpg ├── fig_dataset_statistics.jpg └── fig_related_datasets.jpg ├── src ├── data │ ├── PAVS10K_seqs_test.txt │ └── PAVS10K_seqs_train.txt ├── options.py ├── models │ ├── convgru.py │ ├── non_local_dot_product.py │ ├── rcrnet_vit.py │ ├── resnet_dilation.py │ ├── blocks.py │ ├── CAVNet.py │ └── VIT.py ├── CAVNet_test.py ├── CAVNet_train.py ├── utils.py └── data.py ├── video_to_frames.py ├── video_seq_link └── README.md /figures/fig_teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_teaser.jpg -------------------------------------------------------------------------------- /figures/fig_dataset_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_dataset_examples.jpg -------------------------------------------------------------------------------- /figures/fig_dataset_statistics.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_dataset_statistics.jpg -------------------------------------------------------------------------------- /figures/fig_related_datasets.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeeZ93/PAV-SOD/HEAD/figures/fig_related_datasets.jpg -------------------------------------------------------------------------------- /src/data/PAVS10K_seqs_test.txt: -------------------------------------------------------------------------------- 1 | _-6QUCaLvQ_3I 2 | _Ellen2 3 | _-Ngj6C_RMK1g_1 4 | _-J0Q4L68o3xE_1 5 | _-I-43DzvhxX8_1 6 | _-J0Q4L68o3xE_2 7 | _-SdGCX2H-_Uk 8 | _-RbgxpagCY_c 9 | _-Oak26yVbibQ 10 | _conversation 11 | _-V3zp7XOGBhs 12 | _-MYmMZxmSc1U 13 | _-Uy5LTocHmoA 14 | _-Bvu9m__ZX60 15 | _-g4fQ5iOVzsI 16 | _-JBrp8aG4lro_1 17 | _-72f3ayGhMEA_6 18 | _-1An41lDIJ6Q 19 | _-ey9J7w98wlI_2 20 | _-ByBF08H-wDA 21 | _conversation2 22 | _-HNQMF7e6IL0 23 | _-72f3ayGhMEA_4 24 | _-o7JEBWV4CmY 25 | _-5h95uTtPeck 26 | _-kZB3KMhqqyI 27 | _-RrGhiInqXhc 28 | -------------------------------------------------------------------------------- /src/data/PAVS10K_seqs_train.txt: -------------------------------------------------------------------------------- 1 | _-RSYbTSTz91g 2 | _-IRG9Z7Y2uS4 3 | _-ey9J7w98wlI 4 | _-MzcdEI-tSUc_4 5 | _-Ngj6C_RMK1g_2 6 | _-72f3ayGhMEA_1 7 | _-P4KyjvsceZQ 8 | _-4fxKBGthpaw 9 | _-0suxwissusc 10 | _-nZJGt3ZVg3g 11 | _-bO43msZTfwA 12 | _Ellen 13 | _-SJIbpqgYWGw 14 | _-0cfJOmUaNNI_1 15 | _-1LM84FSzW0g_3 16 | _band 17 | _-0cfJOmUaNNI_2 18 | _-SZYXQ-6bfiQ 19 | _-jb5YxiXIsjU 20 | _-72f3ayGhMEA_3 21 | _-ZuXCMpVR24I 22 | _-MFVmxoXgeNQ 23 | _-72f3ayGhMEA_2 24 | _-eqmjLZGZ36k 25 | _-gTB1nfK-0Ac 26 | _-Gq4Y4gL3zSg 27 | _-LThm7VYvwxY 28 | _-n524y8uPUaU 29 | _-TCUsegqBZ_M 30 | _-eGGFGota5_A 31 | _-UpFZ2YaKeqM 32 | _-idLVnagjl_s 33 | _-1LM84FSzW0g_1 34 | _-nDu57CGqbLM 35 | _-1LM84FSzW0g_2 36 | _-gy4TI-6j5po 37 | _-4SilhsTuDU0 38 | _-MDwhMMSFkJ8 39 | _-G8pABGosD38 40 | _-dBM3eM9HOoA 41 | -------------------------------------------------------------------------------- /video_to_frames.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | 5 | def Vid2Frm(): 6 | seqs = os.listdir(os.path.join(os.getcwd(), 'Videos')) # please take care of your own path 7 | for seq in seqs: 8 | 
seq_path = os.path.join(os.getcwd(), 'Videos', seq) 9 | save_path = os.path.join(os.getcwd(), 'Frames', seq[:-4]) 10 | if not os.path.exists(save_path): os.makedirs(save_path) 11 | cap = cv2.VideoCapture(seq_path) 12 | frames_num = int(cap.get(7)) 13 | countF = 0 14 | for i in range(frames_num): 15 | ret, frame = cap.read() 16 | cv2.imwrite(os.path.join(save_path, format(str(countF), '0>5s') + '.png'), frame) 17 | print(" {} frames are extracted.".format(countF)) 18 | countF += 1 19 | 20 | 21 | if __name__ == '__main__': 22 | Vid2Frm() -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--epoch', type=int, default=21, help='epoch number') 5 | parser.add_argument('--lr', type=float, default=2.5e-6, help='learning rate') 6 | parser.add_argument('--batchsize', type=int, default=1, help='training batch size') # set as 1 7 | parser.add_argument('--trainsize', type=int, default=416, help='training dataset size') 8 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 9 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 10 | parser.add_argument('--lat_weight', type=int, default=10, help='weighting latent loss') 11 | parser.add_argument('--gpu_id', type=str, default='0', help='train use gpu') 12 | parser.add_argument('--tr_root', type=str, default=os.getcwd() + '/data/PAVS10K_seqs_train.txt', help='') 13 | parser.add_argument('--te_root', type=str, default=os.getcwd() + '/data/PAVS10K_seqs_test.txt', help='') 14 | parser.add_argument('--save_path', type=str, default=os.getcwd() + '/CAVNet_cpts/', help='the path to save models and logs') 15 | opt = parser.parse_args() 16 | -------------------------------------------------------------------------------- /src/models/convgru.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | import torch 8 | from torch import nn 9 | 10 | class ConvGRUCell(nn.Module): 11 | """ 12 | ICLR2016: Delving Deeper into Convolutional Networks for Learning Video Representations 13 | url: https://arxiv.org/abs/1511.06432 14 | """ 15 | def __init__(self, input_channels, hidden_channels, kernel_size, cuda_flag=True): 16 | super(ConvGRUCell, self).__init__() 17 | self.input_channels = input_channels 18 | self.cuda_flag = cuda_flag 19 | self.hidden_channels = hidden_channels 20 | self.kernel_size = kernel_size 21 | 22 | padding = self.kernel_size // 2 23 | self.reset_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, 3, padding=padding) 24 | self.update_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, 3, padding=padding) 25 | self.output_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, 3, padding=padding) 26 | # init 27 | for m in self.state_dict(): 28 | if isinstance(m, nn.Conv2d): 29 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 30 | nn.init.constant_(m.bias, 0) 31 | 32 | def forward(self, x, hidden): 33 | if hidden is None: 34 | size_h = [x.data.size()[0], self.hidden_channels] + list(x.data.size()[2:]) 35 | if self.cuda_flag: 36 | hidden = torch.zeros(size_h).cuda() 37 | else: 38 | hidden = torch.zeros(size_h) 39 | 40 | inputs = torch.cat((x, 
hidden), dim=1) 41 | reset_gate = torch.sigmoid(self.reset_gate(inputs)) 42 | update_gate = torch.sigmoid(self.update_gate(inputs)) 43 | 44 | reset_hidden = reset_gate * hidden 45 | reset_inputs = torch.tanh(self.output_gate(torch.cat((x, reset_hidden), dim=1))) 46 | new_hidden = (1 - update_gate)*reset_inputs + update_gate*hidden 47 | 48 | return new_hidden -------------------------------------------------------------------------------- /video_seq_link: -------------------------------------------------------------------------------- 1 | Class-Sequence link 2 | Speaking-French https://www.youtube.com/watch?v=IRG9Z7Y2uS4 3 | Speaking-WaitingRoom https://www.youtube.com/watch?v=ey9J7w98wlI 4 | Speaking-Cooking https://www.youtube.com/watch?v=MzcdEI-tSUc 5 | Speaking-AudiIntro https://www.youtube.com/watch?v=72f3ayGhMEA 6 | Speaking-Ellen https://www.youtube.com/watch?v=-SbZ4_ir148 7 | Speaking-GroveAction https://www.youtube.com/watch?v=0cfJOmUaNNI 8 | Speaking-Warehouse https://www.youtube.com/watch?v=1LM84FSzW0g 9 | Speaking-GroveConvo https://www.youtube.com/watch?v=0cfJOmUaNNI 10 | Speaking-Surfing https://www.youtube.com/watch?v=SZYXQ-6bfiQ 11 | Speaking-Passageway https://www.youtube.com/watch?v=jb5YxiXIsjU 12 | Speaking-RuralDriving https://www.youtube.com/watch?v=72f3ayGhMEA 13 | Speaking-Lawn https://www.youtube.com/watch?v=ZuXCMpVR24I 14 | Speaking-AudiAd https://www.youtube.com/watch?v=72f3ayGhMEA 15 | Speaking-ScenePlay https://www.youtube.com/watch?v=Gq4Y4gL3zSg 16 | Speaking-UrbanDriving https://www.youtube.com/watch?v=n524y8uPUaU 17 | Speaking-Interview https://www.youtube.com/watch?v=UpFZ2YaKeqM 18 | Speaking-Telephone https://www.youtube.com/watch?v=idLVnagjl_s 19 | Speaking-Walking https://www.youtube.com/watch?v=1LM84FSzW0g 20 | Speaking-Bridge https://www.youtube.com/watch?v=1LM84FSzW0g 21 | Speaking-Breakfast https://www.youtube.com/watch?v=4SilhsTuDU0 22 | Speaking-Debate https://www.youtube.com/watch?v=-SbZ4_ir148 23 | Speaking-BadmintonConvo https://www.youtube.com/watch?v=Ngj6C_RMK1g 24 | Speaking-Director https://www.youtube.com/watch?v=J0Q4L68o3xE 25 | Speaking-ChineseAd https://www.youtube.com/watch?v=J0Q4L68o3xE 26 | Speaking-Exhibition https://www.youtube.com/watch?v=RbgxpagCY_c 27 | Speaking-PianoConvo https://www.youtube.com/watch?v=Pxe920CL0GY 28 | Speaking-FilmingSite https://www.youtube.com/watch?v=MYmMZxmSc1U 29 | Speaking-Brothers https://www.youtube.com/watch?v=Uy5LTocHmoA 30 | Speaking-Rap https://www.youtube.com/watch?v=g4fQ5iOVzsI 31 | Speaking-Spanish https://www.youtube.com/watch?v=JBrp8aG4lro 32 | Speaking-Questions https://www.youtube.com/watch?v=ey9J7w98wlI 33 | Speaking-PianoMono https://www.youtube.com/watch?v=Pxe920CL0GY 34 | Speaking-Snowfield https://www.youtube.com/watch?v=o7JEBWV4CmY 35 | Speaking-Melodrama https://www.youtube.com/watch?v=5h95uTtPeck 36 | Speaking-Gymnasium https://www.youtube.com/watch?v=kZB3KMhqqyI 37 | Music-Guitar https://www.youtube.com/watch?v=4fxKBGthpaw 38 | Music-Subway https://www.youtube.com/watch?v=bO43msZTfwA 39 | Music-Jazz https://www.youtube.com/watch?v=SJIbpqgYWGw 40 | Music-Bass https://www.youtube.com/watch?v=baRj0O4cqgI 41 | Music-Canon https://www.youtube.com/watch?v=MFVmxoXgeNQ 42 | Music-MICOSinging https://www.youtube.com/watch?v=gTB1nfK-0Ac 43 | Music-Clarinet https://www.youtube.com/watch?v=TCUsegqBZ_M 44 | Music-Trumpet https://www.youtube.com/watch?v=eGGFGota5_A 45 | Music-PianoSaxophone https://www.youtube.com/watch?v=MDwhMMSFkJ8 46 | Music-Chorus https://www.youtube.com/watch?v=dBM3eM9HOoA 47 
| Music-Studio https://www.youtube.com/watch?v=6QUCaLvQ_3I 48 | Music-Church https://www.youtube.com/watch?v=SdGCX2H-_Uk 49 | Music-Duet https://www.youtube.com/watch?v=Oak26yVbibQ 50 | Music-Blues https://www.youtube.com/watch?v=V3zp7XOGBhs 51 | Music-Violins https://www.youtube.com/watch?v=Bvu9m__ZX60 52 | Music-SingingDancing https://www.youtube.com/watch?v=1An41lDIJ6Q 53 | Miscellanea-Beach https://www.youtube.com/watch?v=RSYbTSTz91g 54 | Miscellanea-BadmintonGym https://www.youtube.com/watch?v=Ngj6C_RMK1g 55 | Miscellanea-InVehicle https://www.youtube.com/watch?v=P4KyjvsceZQ 56 | Miscellanea-Japanese https://www.youtube.com/watch?v=0suxwissusc 57 | Miscellanea-Tennis https://www.youtube.com/watch?v=nZJGt3ZVg3g 58 | Miscellanea-Diesel https://www.youtube.com/watch?v=eqmjLZGZ36k 59 | Miscellanea-Park https://www.youtube.com/watch?v=LThm7VYvwxY 60 | Miscellanea-Lion https://www.youtube.com/watch?v=nDu57CGqbLM 61 | Miscellanea-Carriage https://www.youtube.com/watch?v=gy4TI-6j5po 62 | Miscellanea-Platform https://www.youtube.com/watch?v=G8pABGosD38 63 | Miscellanea-Dog https://www.youtube.com/watch?v=I-43DzvhxX8 64 | Miscellanea-RacingCar https://www.youtube.com/watch?v=72f3ayGhMEA 65 | Miscellanea-Train https://www.youtube.com/watch?v=ByBF08H-wDA 66 | Miscellanea-Football https://www.youtube.com/watch?v=HNQMF7e6IL0 67 | Miscellanea-ParkingLot https://www.youtube.com/watch?v=72f3ayGhMEA 68 | Miscellanea-Skiing https://www.youtube.com/watch?v=RrGhiInqXhc 69 | -------------------------------------------------------------------------------- /src/CAVNet_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os, argparse 4 | import cv2 5 | from data import dataset_inference, clip_length 6 | 7 | from models.CAVNet import cavnet 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--testsize', type=int, default=416, help='testing size') 11 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 12 | parser.add_argument('--test_path', type=str, default=os.getcwd() + '/data/', help='test dataset path') 13 | parser.add_argument('--forward_iter', type=int, default=5, help='sample times') 14 | opt = parser.parse_args() 15 | 16 | dataset_path = opt.test_path 17 | 18 | # set device for test 19 | if opt.gpu_id == '0': 20 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 21 | print('USE GPU 0') 22 | elif opt.gpu_id == '1': 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 24 | print('USE GPU 1') 25 | 26 | # load the model 27 | model = cavnet() 28 | model.load_state_dict(torch.load(os.getcwd() + '/model_test/CAVNet_final.pth')) 29 | model.cuda() 30 | model.eval() 31 | 32 | #test 33 | TIME = [] 34 | save_path_pred = './predictions/' 35 | save_path_sigma = './uncertainties/' 36 | if not os.path.exists(save_path_pred): os.makedirs(save_path_pred) 37 | if not os.path.exists(save_path_sigma): os.makedirs(save_path_sigma) 38 | root = os.path.join(dataset_path, 'PAVS10K_seqs_test.txt') 39 | test_loader = dataset_inference(root, opt.testsize) 40 | for i in range(test_loader.size): 41 | imgs, gts, audios, seq_name, frm_names = test_loader.load_data() 42 | for idx in range(clip_length): 43 | gts[idx] = np.asarray(gts[idx], np.float32) 44 | gts[idx] /= (gts[idx].max() + 1e-8) 45 | imgs[idx] = imgs[idx].cuda() 46 | audios = audios.cuda() 47 | 48 | # multiple forward 49 | pred1, pred2, pred3, ent1, ent2, ent3 = [], [], [], [], [], [] 50 | with torch.no_grad(): 51 | for ff in range(opt.forward_iter): 52 | 
pred_curr = model(imgs, audios) 53 | p1, p2, p3 = torch.sigmoid(pred_curr[0]), torch.sigmoid(pred_curr[1]), torch.sigmoid(pred_curr[2]) 54 | 55 | pred1.append(p1) 56 | pred2.append(p2) 57 | pred3.append(p3) 58 | ent1.append(-1 * p1 * torch.log(p1 + 1e-8)) 59 | ent2.append(-1 * p2 * torch.log(p2 + 1e-8)) 60 | ent3.append(-1 * p3 * torch.log(p3 + 1e-8)) 61 | 62 | pred1_c, pred2_c, pred3_c = torch.cat(pred1, dim=1), torch.cat(pred2, dim=1), torch.cat(pred3, dim=1) 63 | pred1_mu, pred2_mu, pred3_mu = torch.mean(pred1_c, 1, keepdim=True), torch.mean(pred2_c, 1, keepdim=True), \ 64 | torch.mean(pred3_c, 1, keepdim=True) 65 | ent1_c, ent2_c, ent3_c = torch.cat(ent1, dim=1), torch.cat(ent2, dim=1), torch.cat(ent3, dim=1) 66 | ent1_mu, ent2_mu, ent3_mu = torch.mean(ent1_c, 1, keepdim=True), torch.mean(ent2_c, 1, keepdim=True), \ 67 | torch.mean(ent3_c, 1, keepdim=True) 68 | 69 | # before save results 70 | pred1_mu, pred2_mu, pred3_mu = pred1_mu.data.cpu().numpy().squeeze(), pred2_mu.data.cpu().numpy().squeeze(),\ 71 | pred3_mu.data.cpu().numpy().squeeze() 72 | ent1_mu, ent2_mu, ent3_mu = ent1_mu.data.cpu().numpy().squeeze(), ent2_mu.data.cpu().numpy().squeeze(), \ 73 | ent3_mu.data.cpu().numpy().squeeze() 74 | 75 | pred1_mu = (pred1_mu - pred1_mu.min()) / (pred1_mu.max() - pred1_mu.min() + 1e-8) 76 | pred2_mu = (pred2_mu - pred2_mu.min()) / (pred2_mu.max() - pred2_mu.min() + 1e-8) 77 | pred3_mu = (pred3_mu - pred3_mu.min()) / (pred3_mu.max() - pred3_mu.min() + 1e-8) 78 | ent1_mu = 255 * (ent1_mu - ent1_mu.min()) / (ent1_mu.max() - ent1_mu.min() + 1e-8) 79 | ent2_mu = 255 * (ent2_mu - ent2_mu.min()) / (ent2_mu.max() - ent2_mu.min() + 1e-8) 80 | ent3_mu = 255 * (ent3_mu - ent3_mu.min()) / (ent3_mu.max() - ent3_mu.min() + 1e-8) 81 | 82 | uc1, uc2, uc3 = ent1_mu.astype(np.uint8), ent2_mu.astype(np.uint8), ent3_mu.astype(np.uint8) 83 | uc1, uc2, uc3 = cv2.applyColorMap(uc1, cv2.COLORMAP_JET), cv2.applyColorMap(uc2, cv2.COLORMAP_JET), \ 84 | cv2.applyColorMap(uc3, cv2.COLORMAP_JET) 85 | 86 | # save results 87 | curr_save_pth_pred = os.path.join(save_path_pred, seq_name) 88 | curr_save_pth_sigma = os.path.join(save_path_sigma, seq_name) 89 | if not os.path.exists(curr_save_pth_pred): os.makedirs(curr_save_pth_pred) 90 | if not os.path.exists(curr_save_pth_sigma): os.makedirs(curr_save_pth_sigma) 91 | 92 | print('save prediction to: ', os.path.join(curr_save_pth_pred, frm_names[0])) 93 | cv2.imwrite(os.path.join(curr_save_pth_pred, frm_names[0]), pred1_mu * 255) 94 | print('save prediction to: ', os.path.join(curr_save_pth_pred, frm_names[1])) 95 | cv2.imwrite(os.path.join(curr_save_pth_pred, frm_names[1]), pred2_mu * 255) 96 | print('save prediction to: ', os.path.join(curr_save_pth_pred, frm_names[2])) 97 | cv2.imwrite(os.path.join(curr_save_pth_pred, frm_names[2]), pred3_mu * 255) 98 | 99 | print('save uncertainty to: ', os.path.join(curr_save_pth_sigma, frm_names[0])) 100 | cv2.imwrite(os.path.join(curr_save_pth_sigma, frm_names[0]), uc1) 101 | print('save uncertainty to: ', os.path.join(curr_save_pth_sigma, frm_names[1])) 102 | cv2.imwrite(os.path.join(curr_save_pth_sigma, frm_names[1]), uc2) 103 | print('save uncertainty to: ', os.path.join(curr_save_pth_sigma, frm_names[2])) 104 | cv2.imwrite(os.path.join(curr_save_pth_sigma, frm_names[2]), uc3) 105 | 106 | print('Test Done!') 107 | -------------------------------------------------------------------------------- /src/models/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | 2 | 
#!/usr/bin/env python 3 | # coding: utf-8 4 | # 5 | # Author: AlexHex7 6 | # URL: https://github.com/AlexHex7/Non-local_pytorch 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class _NonLocalBlockND(nn.Module): 14 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 15 | super(_NonLocalBlockND, self).__init__() 16 | 17 | assert dimension in [1, 2, 3] 18 | 19 | self.dimension = dimension 20 | self.sub_sample = sub_sample 21 | 22 | self.in_channels = in_channels 23 | self.inter_channels = inter_channels 24 | 25 | if self.inter_channels is None: 26 | self.inter_channels = in_channels // 2 27 | if self.inter_channels == 0: 28 | self.inter_channels = 1 29 | 30 | if dimension == 3: 31 | conv_nd = nn.Conv3d 32 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 33 | bn = nn.BatchNorm3d 34 | elif dimension == 2: 35 | conv_nd = nn.Conv2d 36 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 37 | bn = nn.BatchNorm2d 38 | else: 39 | conv_nd = nn.Conv1d 40 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 41 | bn = nn.BatchNorm1d 42 | 43 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 44 | kernel_size=1, stride=1, padding=0) 45 | 46 | if bn_layer: 47 | self.W = nn.Sequential( 48 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0), 50 | bn(self.in_channels) 51 | ) 52 | nn.init.constant_(self.W[1].weight, 0) 53 | nn.init.constant_(self.W[1].bias, 0) 54 | else: 55 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 56 | kernel_size=1, stride=1, padding=0) 57 | nn.init.constant_(self.W.weight, 0) 58 | nn.init.constant_(self.W.bias, 0) 59 | 60 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 61 | kernel_size=1, stride=1, padding=0) 62 | 63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | 66 | if sub_sample: 67 | self.g = nn.Sequential(self.g, max_pool_layer) 68 | self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | def forward(self, x): 71 | ''' 72 | :param x: (b, c, t, h, w) 73 | :return: 74 | ''' 75 | 76 | batch_size = x.size(0) 77 | 78 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 79 | g_x = g_x.permute(0, 2, 1) 80 | 81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 82 | theta_x = theta_x.permute(0, 2, 1) 83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 84 | f = torch.matmul(theta_x, phi_x) 85 | N = f.size(-1) 86 | f_div_C = f / N 87 | 88 | y = torch.matmul(f_div_C, g_x) 89 | y = y.permute(0, 2, 1).contiguous() 90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 91 | W_y = self.W(y) 92 | z = W_y + x 93 | 94 | return z 95 | 96 | 97 | class NONLocalBlock1D(_NonLocalBlockND): 98 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 99 | super(NONLocalBlock1D, self).__init__(in_channels, 100 | inter_channels=inter_channels, 101 | dimension=1, sub_sample=sub_sample, 102 | bn_layer=bn_layer) 103 | 104 | 105 | class NONLocalBlock2D(_NonLocalBlockND): 106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 107 | super(NONLocalBlock2D, self).__init__(in_channels, 108 | inter_channels=inter_channels, 109 | dimension=2, sub_sample=sub_sample, 110 | bn_layer=bn_layer) 111 | 112 | 113 | class NONLocalBlock3D(_NonLocalBlockND): 
114 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 115 | super(NONLocalBlock3D, self).__init__(in_channels, 116 | inter_channels=inter_channels, 117 | dimension=3, sub_sample=sub_sample, 118 | bn_layer=bn_layer) 119 | 120 | 121 | if __name__ == '__main__': 122 | import torch 123 | 124 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 125 | img = torch.zeros(2, 3, 20) 126 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 127 | out = net(img) 128 | print(out.size()) 129 | 130 | img = torch.zeros(2, 3, 20, 20) 131 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 132 | out = net(img) 133 | print(out.size()) 134 | 135 | img = torch.randn(2, 3, 8, 20, 20) 136 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 137 | out = net(img) 138 | print(out.size()) -------------------------------------------------------------------------------- /src/CAVNet_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from datetime import datetime 5 | from data import get_loader, test_dataset, clip_length 6 | from utils import clip_gradient 7 | import logging 8 | import torch.backends.cudnn as cudnn 9 | from options import opt 10 | from utils import print_network, structure_loss, linear_annealing, l2_regularisation 11 | import cv2 12 | 13 | from models.CAVNet import cavnet 14 | 15 | 16 | #set the device for training 17 | if opt.gpu_id == '0': 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | print('USE GPU 0') 20 | elif opt.gpu_id == '1': 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 22 | print('USE GPU 1') 23 | cudnn.benchmark = True 24 | 25 | #build the model 26 | model = cavnet() 27 | print_network(model, 'CAVNet') 28 | if(opt.load is not None): 29 | model.load_state_dict(torch.load(opt.load)) 30 | print('load model from ', opt.load) 31 | model.cuda() 32 | 33 | params = model.parameters() 34 | optimizer = torch.optim.Adam(params, opt.lr) 35 | 36 | #set the path 37 | tr_root = opt.tr_root 38 | te_root = opt.te_root 39 | save_path = opt.save_path 40 | 41 | if not os.path.exists(save_path): 42 | os.makedirs(save_path) 43 | 44 | #load data 45 | print('load data...') 46 | train_loader = get_loader(tr_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 47 | test_loader = test_dataset(te_root, opt.trainsize) 48 | total_step = len(train_loader) 49 | 50 | logging.basicConfig(filename=save_path+'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', 51 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p') 52 | logging.info("CAVNet-Train") 53 | logging.info("Config") 54 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};load:{};save_path:{}'. 
55 | format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.load,save_path)) 56 | 57 | step = 0 58 | best_mae = 1 59 | best_epoch = 0 60 | 61 | #train function 62 | def train(train_loader, model, optimizer, epoch, save_path): 63 | global step 64 | model.train() 65 | loss_all = 0 66 | epoch_step = 0 67 | try: 68 | for i, (imgs, gts, audios, seq_name) in enumerate(train_loader, start=1): 69 | optimizer.zero_grad() 70 | 71 | audios = audios.cuda() 72 | for idx in range(clip_length): # default as three consecutive frames 73 | imgs[idx], gts[idx] = imgs[idx].cuda(), gts[idx].cuda() 74 | 75 | # debug 76 | # cv2.imwrite('img.png', imgs[1][0].permute(1, 2, 0).cpu().data.numpy() * 255) 77 | # cv2.imwrite('er_img.png', er_imgs[1][0].permute(1, 2, 0).cpu().data.numpy() * 255) 78 | # cv2.imwrite('gt.png', gts[1][0].permute(1, 2, 0).cpu().data.numpy() * 255) 79 | # cv2.imwrite('aem.png', aems[1][0].permute(1, 2, 0).cpu().data.numpy() * 255) 80 | # cv2.imwrite('cube1.png', cube_gts[1][0].permute(1, 2, 0).cpu().data.numpy() * 255) 81 | # cv2.imwrite('cube2.png', cube_gts[1][1].permute(1, 2, 0).cpu().data.numpy() * 255) 82 | # cv2.imwrite('cube3.png', cube_gts[1][2].permute(1, 2, 0).cpu().data.numpy() * 255) 83 | # cv2.imwrite('cube4.png', cube_gts[1][3].permute(1, 2, 0).cpu().data.numpy() * 255) 84 | # cv2.imwrite('cube5.png', cube_gts[1][4].permute(1, 2, 0).cpu().data.numpy() * 255) 85 | # cv2.imwrite('cube6.png', cube_gts[1][5].permute(1, 2, 0).cpu().data.numpy() * 255) 86 | # torchaudio.save('audio.wav', audios[0].cpu(), 48000) 87 | 88 | preds_prior, preds_post, lat_loss = model(imgs, audios, gts) 89 | 90 | anneal_reg = linear_annealing(0, 1, epoch, opt.epoch) 91 | reg_loss = (l2_regularisation(model.enc_mm_x) + l2_regularisation(model.enc_mm_xy) + \ 92 | l2_regularisation(model.backbone_dec_prior) + l2_regularisation(model.backbone_dec_post)) * 1e-4 93 | loss_list = [] 94 | for rr in range(clip_length): # prior 95 | loss_list.append(structure_loss(preds_prior[rr], gts[rr])) 96 | for pp in range(clip_length): # post 97 | loss_list.append(structure_loss(preds_post[pp], gts[pp])) 98 | for cc in range(clip_length): # latent loss 99 | loss_list.append(lat_loss[cc] * anneal_reg * opt.lat_weight) 100 | loss = sum(loss_list) / (clip_length * 3) 101 | loss = loss + reg_loss 102 | 103 | loss_vis_prior, loss_vis_post, loss_vis_lat = sum(loss_list[:3]) / clip_length, \ 104 | sum(loss_list[3:6]) / clip_length, sum(loss_list[6:]) / clip_length 105 | 106 | loss.backward() 107 | 108 | clip_gradient(optimizer, opt.clip) 109 | optimizer.step() 110 | step += 1 111 | epoch_step += 1 112 | loss_all += loss.data 113 | if i % 500 == 0 or i == total_step or i == 1: 114 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, ' 115 | 'LossPri: {:.4f}, LossPos: {:.4f}, LossLat: {:.4f}'. 116 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss.data, 117 | loss_vis_prior.data, loss_vis_post.data, loss_vis_lat.data)) 118 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, ' 119 | 'LossPri: {:.4f}, LossPos: {:.4f}, LossLat: {:.4f}'. 
120 | format(epoch, opt.epoch, i, total_step, loss.data, 121 | loss_vis_prior.data, loss_vis_post.data, loss_vis_lat.data)) 122 | 123 | loss_all /= epoch_step 124 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all)) 125 | if (epoch) % 5 == 0: 126 | torch.save(model.state_dict(), save_path+'CAVNet_epoch_{}.pth'.format(epoch)) 127 | except KeyboardInterrupt: 128 | print('Keyboard Interrupt: save model and exit.') 129 | if not os.path.exists(save_path): 130 | os.makedirs(save_path) 131 | torch.save(model.state_dict(), save_path+'CAVNet_epoch_{}.pth'.format(epoch+1)) 132 | print('save checkpoints successfully!') 133 | raise 134 | 135 | 136 | if __name__ == '__main__': 137 | print("Start train...") 138 | for epoch in range(1, opt.epoch): 139 | train(train_loader, model, optimizer, epoch, save_path) 140 | -------------------------------------------------------------------------------- /src/models/rcrnet_vit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from models.resnet_dilation import Bottleneck, conv1x1 14 | from models.blocks import ( 15 | _make_encoder, 16 | forward_vit, 17 | ) 18 | 19 | class _ConvBatchNormReLU(nn.Sequential): 20 | def __init__(self, 21 | in_channels, 22 | out_channels, 23 | kernel_size, 24 | stride, 25 | padding, 26 | dilation, 27 | relu=True, 28 | ): 29 | super(_ConvBatchNormReLU, self).__init__() 30 | self.add_module( 31 | "conv", 32 | nn.Conv2d( 33 | in_channels=in_channels, 34 | out_channels=out_channels, 35 | kernel_size=kernel_size, 36 | stride=stride, 37 | padding=padding, 38 | dilation=dilation, 39 | bias=False, 40 | ), 41 | ) 42 | #self.add_module( 43 | # "bn", 44 | # nn.BatchNorm2d(out_channels), 45 | #) 46 | 47 | if relu: 48 | self.add_module("relu", nn.ReLU()) 49 | 50 | def forward(self, x): 51 | return super(_ConvBatchNormReLU, self).forward(x) 52 | 53 | class _ASPPModule(nn.Module): 54 | """Atrous Spatial Pyramid Pooling with image pool""" 55 | 56 | def __init__(self, in_channels, out_channels, output_stride): 57 | super(_ASPPModule, self).__init__() 58 | if output_stride == 8: 59 | pyramids = [12, 24, 36] 60 | elif output_stride == 16: 61 | pyramids = [6, 12, 18] 62 | self.stages = nn.Module() 63 | self.stages.add_module( 64 | "c0", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1) 65 | ) 66 | for i, (dilation, padding) in enumerate(zip(pyramids, pyramids)): 67 | self.stages.add_module( 68 | "c{}".format(i + 1), 69 | _ConvBatchNormReLU(in_channels, out_channels, 3, 1, padding, dilation), 70 | ) 71 | self.imagepool = nn.Sequential( 72 | OrderedDict( 73 | [ 74 | ("pool", nn.AdaptiveAvgPool2d((1,1))), 75 | ("conv", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)), 76 | ] 77 | ) 78 | ) 79 | self.fire = nn.Sequential( 80 | OrderedDict( 81 | [ 82 | ("conv", _ConvBatchNormReLU(out_channels * 5, out_channels, 3, 1, 1, 1)), 83 | ("dropout", nn.Dropout2d(0.1)) 84 | ] 85 | ) 86 | ) 87 | 88 | def forward(self, x): 89 | h = self.imagepool(x) 90 | h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)] 91 | for stage in self.stages.children(): 92 | h += [stage(x)] 93 | h = torch.cat(h, dim=1) 94 | h = self.fire(h) 95 | 96 | return h 97 | 98 | class _RefinementModule(nn.Module): 99 | """ 
Reduce channels and refinment module""" 100 | 101 | def __init__(self, 102 | bottom_up_channels, 103 | reduce_channels, 104 | top_down_channels, 105 | refinement_channels, 106 | expansion=2 107 | ): 108 | super(_RefinementModule, self).__init__() 109 | downsample = None 110 | if bottom_up_channels != reduce_channels: 111 | downsample = nn.Sequential( 112 | conv1x1(bottom_up_channels, reduce_channels), 113 | nn.BatchNorm2d(reduce_channels), 114 | ) 115 | self.skip = Bottleneck(bottom_up_channels, reduce_channels // expansion, 1, 1, downsample, expansion) 116 | self.refine = _ConvBatchNormReLU(reduce_channels + top_down_channels, refinement_channels, 3, 1, 1, 1) 117 | 118 | def forward(self, td, bu): 119 | td = self.skip(td) 120 | x = torch.cat((bu, td), dim=1) 121 | x = self.refine(x) 122 | 123 | return x 124 | 125 | class RCRNet_vit(nn.Module): 126 | 127 | def __init__(self, n_classes, output_stride): 128 | super(RCRNet_vit, self).__init__() 129 | 130 | # vit encoder 131 | hooks = { 132 | "vitb_rn50_384": [0, 1, 8, 11], 133 | "vitb16_384": [2, 5, 8, 11], 134 | "vitl16_384": [5, 11, 17, 23], 135 | } 136 | self.pretrained = _make_encoder( 137 | "vitb_rn50_384", 138 | 256, 139 | False, # the pre-train model will be loaded anyway 140 | groups=1, 141 | expand=False, 142 | exportable=False, 143 | hooks=hooks["vitb_rn50_384"], 144 | use_readout="project", 145 | enable_attention_hooks=False, 146 | ) 147 | 148 | self.aspp = _ASPPModule(768, 256, output_stride) 149 | # Decoder 150 | self.decoder = nn.Sequential( 151 | OrderedDict( 152 | [ 153 | ("conv1", _ConvBatchNormReLU(128, 256, 3, 1, 1, 1)), 154 | ("conv2", nn.Conv2d(256, n_classes, kernel_size=1)), 155 | ] 156 | ) 157 | ) 158 | self.add_module("refinement1", _RefinementModule(768, 96, 256, 128, 2)) 159 | self.add_module("refinement2", _RefinementModule(512, 96, 128, 128, 2)) 160 | self.add_module("refinement3", _RefinementModule(256, 96, 128, 128, 2)) 161 | 162 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 163 | 164 | #if pretrained: 165 | # for key in self.state_dict(): 166 | # if 'resnet' not in key: 167 | # self.init_layer(key) 168 | 169 | # def init_layer(self, key): 170 | # if key.split('.')[-1] == 'weight': 171 | # if 'conv' in key: 172 | # if self.state_dict()[key].ndimension() >= 2: 173 | # nn.init.kaiming_normal_(self.state_dict()[key], mode='fan_out', nonlinearity='relu') 174 | # elif 'bn' in key: 175 | # self.state_dict()[key][...] = 1 176 | # elif key.split('.')[-1] == 'bias': 177 | # self.state_dict()[key][...] 
= 0.001 178 | 179 | def feat_conv(self, x): 180 | ''' 181 | Spatial feature extractor 182 | ''' 183 | block1, block2, block3, block4 = forward_vit(self.pretrained, x) 184 | block4 = self.aspp(block4) 185 | 186 | return block1, block2, block3, block4 187 | 188 | def seg_conv(self, block1, block2, block3, block4, shape): 189 | ''' 190 | Pixel-wise classifer 191 | ''' 192 | block4 = self.upsample2(block4) 193 | 194 | bu1 = self.refinement1(block3, block4) 195 | bu1 = F.interpolate(bu1, size=block2.shape[2:], mode="bilinear", align_corners=False) 196 | bu2 = self.refinement2(block2, bu1) 197 | bu2 = F.interpolate(bu2, size=block1.shape[2:], mode="bilinear", align_corners=False) 198 | bu3 = self.refinement3(block1, bu2) 199 | bu3 = F.interpolate(bu3, size=shape, mode="bilinear", align_corners=False) 200 | seg = self.decoder(bu3) 201 | 202 | return seg 203 | 204 | def forward(self, x): 205 | block1, block2, block3, block4 = self.feat_conv(x) 206 | seg = self.seg_conv(block1, block2, block3, block4, x.shape[2:]) 207 | 208 | return seg -------------------------------------------------------------------------------- /src/models/resnet_dilation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # This code is based on torchvison resnet 5 | # URL: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | 7 | import torch.nn as nn 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=padding, dilation=dilation, bias=False) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride, dilation, downsample=None): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes, 1, dilation, dilation) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride, dilation, downsample=None, expansion=4): 71 | super(Bottleneck, self).__init__() 72 | self.expansion = expansion 73 | self.conv1 = conv1x1(inplanes, planes) 74 
| self.bn1 = nn.BatchNorm2d(planes) 75 | self.conv2 = conv3x3(planes, planes, stride, dilation, dilation) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = conv1x1(planes, planes * self.expansion) 78 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, output_stride, num_classes=1000, input_channels=3): 109 | super(ResNet, self).__init__() 110 | if output_stride == 8: 111 | stride = [1, 2, 1, 1] 112 | dilation = [1, 1, 2, 2] 113 | elif output_stride == 16: 114 | stride = [1, 2, 2, 1] 115 | dilation = [1, 1, 1, 2] 116 | 117 | self.inplanes = 64 118 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, 119 | bias=False) 120 | self.bn1 = nn.BatchNorm2d(64) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0], dilation=dilation[0]) 124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1], dilation=dilation[1]) 125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2], dilation=dilation[2]) 126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3], dilation=dilation[3]) 127 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 128 | self.fc = nn.Linear(512 * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, nn.BatchNorm2d): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride, dilation): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | conv1x1(self.inplanes, planes * block.expansion, stride), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, dilation, downsample)) 147 | self.inplanes = planes * block.expansion 148 | for _ in range(1, blocks): 149 | layers.append(block(self.inplanes, planes, 1, dilation)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | x = self.conv1(x) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | x = self.maxpool(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | x = self.fc(x) 167 | 168 | return x 169 | 170 | 171 | def resnet18(pretrained=False, **kwargs): 172 | """Constructs a ResNet-18 model. 
173 |     Args:
174 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
175 |     """
176 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
177 |     if pretrained:
178 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
179 |     return model
180 | 
181 | 
182 | def resnet34(pretrained=False, **kwargs):
183 |     """Constructs a ResNet-34 model.
184 |     Args:
185 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
186 |     """
187 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
188 |     if pretrained:
189 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
190 |     return model
191 | 
192 | 
193 | def resnet50(pretrained=False, **kwargs):
194 |     """Constructs a ResNet-50 model.
195 |     Args:
196 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
197 |     """
198 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
199 |     if pretrained:
200 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
201 |     return model
202 | 
203 | 
204 | def resnet101(pretrained=False, **kwargs):
205 |     """Constructs a ResNet-101 model.
206 |     Args:
207 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
208 |     """
209 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
210 |     if pretrained:
211 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
212 |     return model
213 | 
214 | 
215 | def resnet152(pretrained=False, **kwargs):
216 |     """Constructs a ResNet-152 model.
217 |     Args:
218 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
219 |     """
220 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
221 |     if pretrained:
222 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
223 |     return model
224 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # [PAV-SOD: A New Task Towards Panoramic Audiovisual Saliency Detection (TOMM 2022)](https://drive.google.com/file/d/1-1RcARcbz4pACFzkjXcp6MP8R9CGScqI/view?usp=sharing)
  2 | 
  3 | Object-level audiovisual saliency detection in 360° panoramic real-life dynamic scenes is important for exploring and modeling human perception in immersive environments, and for aiding the development of virtual, augmented and mixed reality applications in fields such as education, social networking, entertainment and training. To this end, we propose a new task, panoramic audiovisual salient object detection (PAV-SOD), which aims to segment the objects that attract most of the human attention in 360° panoramic videos reflecting real-life daily scenes. To support the task, we collect PAVS10K, the first panoramic video dataset for audiovisual salient object detection, which consists of 67 4K-resolution equirectangular videos with per-video labels, including hierarchical scene categories and associated attributes depicting specific challenges for conducting PAV-SOD, and 10,465 uniformly sampled video frames with manually annotated object-level and instance-level pixel-wise masks. The coarse-to-fine annotations enable multi-perspective analysis of PAV-SOD modeling. We further systematically benchmark 13 state-of-the-art salient object detection (SOD)/video object segmentation (VOS) methods on our PAVS10K. In addition, we propose a new baseline network that takes advantage of both the visual and audio cues of 360° video frames by using a new conditional variational auto-encoder (CVAE). 
Our CVAE-based audiovisual network, namely CAV-Net, consists of a spatial-temporal visual segmentation network, a convolutional audio-encoding network and audiovisual distribution estimation modules. As a result, our CAV-Net outperforms all competing models and is able to estimate the aleatoric uncertainties within PAVS10K. Extensive experimental results yield several findings about the challenges of PAV-SOD and insights into PAV-SOD model interpretability. We hope that our work can serve as a starting point for advancing SOD towards immersive media.
  4 | 
  5 | ------
  6 | 
  7 | # PAVS10K
  8 | 
  9 |

10 |
11 | 
 12 | Figure 1: An example from our PAVS10K, where coarse-to-fine annotations are provided under the guidance of fixations acquired from subjective experiments conducted by multiple (N) subjects wearing Head-Mounted Displays (HMDs) and headphones. Each of the T equirectangular (ER) video frames (e.g., fk, fl and fn, where the random integer values {k, l, n} ∈ [1, T]) of the sequence “Speaking” (super-class)-“Brothers” (sub-class) is manually labeled with both object-level and instance-level pixel-wise masks. According to the features of the defined salient objects within each sequence, multiple attributes, e.g., “multiple objects” (MO), “competing sounds” (CS), “geometrical distortion” (GD), “motion blur” (MB), “occlusions” (OC) and “low resolution” (LR), are further annotated to enable detailed analysis for PAV-SOD modeling.
 13 | 
 14 |

15 | 16 |

17 |
18 | 19 | Figure 2: Summary of widely used salient object detection (SOD)/video object segmentation (VOS) datasets and PAVS10K. #Img: The number of images/video frames. #GT: The number of object-level pixel-wise masks (ground truth for SOD). Pub. = Publication. Obj.-Level = Object-Level Labels. Ins.-Level = Instance-Level Labels. Fix. GT = Fixation Maps. † denotes equirectangular images. 20 | 21 |

22 | 23 |

24 |
25 | 26 | Figure 3: Examples of challenging attributes on equirectangular images from our PAVS10K, with instance-level ground truth and fixations as annotation guidance. {𝑓𝑘, 𝑓𝑙, 𝑓𝑛} denote random frames of a given video. 27 | 28 |

29 | 30 |

31 |
32 | 33 | Figure 4: Statistics of the proposed PAVS10K. (a) Super-/sub-category information. (b) Instance density (labeled frames per sequence) of each sub-class. (c) Sound sources of PAVS10K scenes, such as musical instruments, human instances and animals. 34 | 35 |
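The super-class of every sequence (Speaking, Music or Miscellanea) is encoded directly in the sequence names listed in the repository's `video_seq_link` file. The minimal sketch below (an illustrative addition, assuming the two-column "Class-Sequence link" layout of that file and that it is read from the repository root) produces a rough per-super-class sequence tally in the spirit of Figure 4(a):

    # Count PAVS10K sequences per super-class by parsing video_seq_link.
    from collections import Counter

    counts = Counter()
    with open('video_seq_link') as f:
        next(f)  # skip the "Class-Sequence link" header line
        for line in f:
            parts = line.split()
            if len(parts) >= 2:
                # e.g. "Speaking-French" -> super-class "Speaking"
                counts[parts[0].split('-')[0]] += 1
    print(dict(counts))  # {'Speaking': ..., 'Music': ..., 'Miscellanea': ...}
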

36 | 37 | ------ 38 | 39 | # Benchmark Models 40 | 41 | **No.** | **Year** | **Pub.** | **Title** | **Links** 42 | :-: | :-:| :-: | :- | :-: 43 | 01 | **2019**| **CVPR** | Cascaded Partial Decoder for Fast and Accurate Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_Cascaded_Partial_Decoder_for_Fast_and_Accurate_Salient_Object_Detection_CVPR_2019_paper.pdf)/[Code](https://github.com/wuzhe71/CPD) 44 | 02 | **2019**| **CVPR** | See More, Know More: Unsupervised Video Object Segmentation with Co-Attention Siamese Networks | [Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Lu_See_More_Know_More_Unsupervised_Video_Object_Segmentation_With_Co-Attention_CVPR_2019_paper.pdf)/[Code](https://github.com/carrierlxk/COSNet) 45 | 03 | **2019**| **ICCV** | Stacked Cross Refinement Network for Edge-Aware Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_ICCV_2019/papers/Wu_Stacked_Cross_Refinement_Network_for_Edge-Aware_Salient_Object_Detection_ICCV_2019_paper.pdf)/[Code](https://github.com/wuzhe71/SCRN) 46 | 04 | **2019**| **ICCV** | Semi-Supervised Video Salient Object Detection Using Pseudo-Labels | [Paper](https://openaccess.thecvf.com/content_ICCV_2019/papers/Yan_Semi-Supervised_Video_Salient_Object_Detection_Using_Pseudo-Labels_ICCV_2019_paper.pdf)/[Code](https://github.com/Kinpzz/RCRNet-Pytorch) 47 | 05 | **2020**| **AAAI** | F³Net: Fusion, Feedback and Focus for Salient Object Detection | [Paper](https://ojs.aaai.org/index.php/AAAI/article/download/6916/6770)/[Code](https://github.com/weijun88/F3Net) 48 | 06 | **2020**| **AAAI** | Pyramid Constrained Self-Attention Network for Fast Video Salient Object Detection | [Paper](https://ojs.aaai.org/index.php/AAAI/article/view/6718/6572)/[Code](https://github.com/guyuchao/PyramidCSA) 49 | 07 | **2020**| **CVPR** | Multi-scale Interactive Network for Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Pang_Multi-Scale_Interactive_Network_for_Salient_Object_Detection_CVPR_2020_paper.pdf)/[Code](https://github.com/lartpang/MINet) 50 | 08 | **2020**| **CVPR** | Label Decoupling Framework for Salient Object Detection | [Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Wei_Label_Decoupling_Framework_for_Salient_Object_Detection_CVPR_2020_paper.pdf)/[Code](https://github.com/weijun88/LDF) 51 | 09 | **2020**| **ECCV** | Highly Efficient Salient Object Detection with 100K Parameters | [Paper](http://mftp.mmcheng.net/Papers/20EccvSal100k.pdf)/[Code](https://github.com/ShangHua-Gao/SOD100K) 52 | 10 | **2020**| **ECCV** | Suppress and Balance: A Simple Gated Network for Salient Object Detection | [Paper](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123470035.pdf)/[Code](https://github.com/Xiaoqi-Zhao-DLUT/GateNet-RGB-Saliency) 53 | 11 | **2020**| **BMVC** | Making a Case for 3D Convolutions for Object Segmentation in Videos | [Paper](https://www.bmvc2020-conference.com/assets/papers/0233.pdf)/[Code](https://github.com/sabarim/3DC-Seg) 54 | 12 | **2020**| **SPL** | FANet: Features Adaptation Network for 360° Omnidirectional Salient Object Detection | [Paper](https://ieeexplore.ieee.org/document/9211754)/[Code](https://github.com/DreaMKHuang/FANet) 55 | 13 | **2021**| **CVPR** | Reciprocal Transformations for Unsupervised Video Object Segmentation | 
[Paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ren_Reciprocal_Transformations_for_Unsupervised_Video_Object_Segmentation_CVPR_2021_paper.pdf)/[Code](https://github.com/OliverRensu/RTNet) 56 | 57 | ------ 58 | 59 | # CAV-Net 60 | 61 | The codes are available at [src](https://github.com/PanoAsh/PAV-SOD/tree/main/src). 62 | 63 | The pre-trained models can be downloaded at [Google Drive](https://drive.google.com/file/d/1gNWmgmlBfJqCYE5phDuHTMFIou1TAmXs/view?usp=sharing). 64 | 65 | ------ 66 | 67 | # Dataset Downloads 68 | 69 | The whole object-/instance-level ground truth with default split can be downloaded from [Google Drive](https://drive.google.com/file/d/1Whp_ftuXza8-vkjNtICdxdRebcmzcrFi/view?usp=sharing). 70 | 71 | The videos (with ambisonics) with default split can be downloaded from [Google Drive](https://drive.google.com/file/d/13FEv1yAyMmK4GkiZ2Mce6gJxQuME7vG3/view?usp=sharing). 72 | 73 | The videos (with mono sound) can be downloaded from [Google Drive](https://drive.google.com/file/d/1klJnHSiUM7Ow2LkdaLe-O6CEQ9qmdo2F/view?usp=sharing) 74 | 75 | The audio files (.wav) can be downloaded from [Google Drive](https://drive.google.com/file/d/1-jqDArcm8vBhyku3Xb8HLopG1XmpAS13/view?usp=sharing). 76 | 77 | The head movement and eye fixation data can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1EpWc7GVcGFAn5VigV3c2-ZtIZElfXPX1?usp=sharing). 78 | 79 | To generate video frames, please refer to [video_to_frames.py](https://github.com/PanoAsh/ASOD60K/blob/main/video_to_frames.py). 80 | 81 | To get access to raw videos on YouTube, please refer to [video_seq_link](https://github.com/PanoAsh/ASOD60K/blob/main/video_seq_link). 82 | 83 | > Note: The PAVS10K dataset does not own the copyright of videos. Only researchers and educators who wish to use the videos for non-commercial researches and/or educational purposes, have access to PAVS10K. 84 | 85 | ------ 86 | 87 | # Citation 88 | 89 | @article{zhang2023pav, 90 | title={PAV-SOD: A New Task towards Panoramic Audiovisual Saliency Detection}, 91 | author={Zhang, Yi and Chao, Fang-Yi and Hamidouche, Wassim and Deforges, Olivier}, 92 | journal={ACM Transactions on Multimedia Computing, Communications and Applications}, 93 | volume={19}, 94 | number={3}, 95 | pages={1--26}, 96 | year={2023}, 97 | publisher={ACM New York, NY} 98 | } 99 | 100 | ------ 101 | 102 | # Contact 103 | 104 | yi23zhang.2022@gmail.com 105 | or 106 | fangyichao428@gmail.com (for details of head movement and eye fixation data). 
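For convenience, a minimal frame-extraction sketch is included below as an illustrative variant of the [video_to_frames.py](https://github.com/PanoAsh/ASOD60K/blob/main/video_to_frames.py) script referenced under Dataset Downloads. It assumes the downloaded videos sit in a `Videos/` folder next to the script and writes zero-padded PNG frames to `Frames/<sequence>/`; the helper name `extract_frames` is ours, and unlike the bundled script this loop stops as soon as `cv2.VideoCapture.read()` fails, so only frames that were actually decoded are written:

    import os
    import cv2

    def extract_frames(video_dir='Videos', frame_dir='Frames'):
        # Dump every decodable frame of each downloaded video as a zero-padded PNG.
        for name in sorted(os.listdir(video_dir)):
            seq = os.path.splitext(name)[0]
            out_dir = os.path.join(frame_dir, seq)
            os.makedirs(out_dir, exist_ok=True)
            cap = cv2.VideoCapture(os.path.join(video_dir, name))
            idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:  # end of sequence or unreadable frame
                    break
                cv2.imwrite(os.path.join(out_dir, '{:05d}.png'.format(idx)), frame)
                idx += 1
            cap.release()
            print('{}: {} frames extracted.'.format(seq, idx))

    if __name__ == '__main__':
        extract_frames()
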
107 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch.nn.parameter import Parameter 5 | import scipy.stats as st 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | CE = torch.nn.BCEWithLogitsLoss() 12 | 13 | def l2_regularisation(m): 14 | l2_reg = None 15 | 16 | for W in m.parameters(): 17 | if l2_reg is None: 18 | l2_reg = W.norm(2) 19 | else: 20 | l2_reg = l2_reg + W.norm(2) 21 | 22 | return l2_reg 23 | 24 | # linear annealing to avoid posterior collapse 25 | def linear_annealing(init, fin, step, annealing_steps): 26 | """Linear annealing of a parameter.""" 27 | if annealing_steps == 0: 28 | return fin 29 | assert fin > init 30 | delta = fin - init 31 | annealed = min(init + delta * step / annealing_steps, fin) 32 | 33 | return annealed 34 | 35 | def structure_loss(pred, mask): 36 | # adaptive weighting mask 37 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 38 | 39 | # weighted binary cross entropy loss function 40 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 41 | wbce = ((weit * wbce).sum(dim=(2, 3)) + 1e-8) / (weit.sum(dim=(2, 3)) + 1e-8) 42 | 43 | pred = torch.sigmoid(pred) 44 | 45 | # weighted iou loss function 46 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 47 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 48 | wiou = 1.0 - (inter + 1 + 1e-8) / (union - inter + 1 + 1e-8) 49 | 50 | return (wbce + wiou).mean() 51 | 52 | def weighted_e_loss(pred, mask): 53 | # adaptive weighting mask 54 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 55 | 56 | # weighted binary cross entropy loss function 57 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 58 | wbce = ((weit * wbce).sum(dim=(2, 3)) + 1e-8) / (weit.sum(dim=(2, 3)) + 1e-8) 59 | 60 | # weighted e loss function 61 | pred = torch.sigmoid(pred) 62 | mpred = pred.mean(dim=(2, 3)).view(pred.shape[0], pred.shape[1], 1, 1).repeat(1, 1, pred.shape[2], pred.shape[3]) 63 | phiFM = pred - mpred 64 | 65 | mmask = mask.mean(dim=(2, 3)).view(mask.shape[0], mask.shape[1], 1, 1).repeat(1, 1, mask.shape[2], mask.shape[3]) 66 | phiGT = mask - mmask 67 | 68 | EFM = (2.0 * phiFM * phiGT + 1e-8) / (phiFM * phiFM + phiGT * phiGT + 1e-8) 69 | QFM = (1 + EFM) * (1 + EFM) / 4.0 70 | eloss = 1.0 - QFM.mean(dim=(2, 3)) 71 | 72 | # weighted iou loss function 73 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 74 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 75 | wiou = 1.0 - (inter + 1 + 1e-8) / (union - inter + 1 + 1e-8) 76 | 77 | return (wbce + eloss + wiou).mean() 78 | 79 | def hybrid_e_loss(pred, mask): 80 | # adaptive weighting mask 81 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 82 | 83 | # weighted binary cross entropy loss function 84 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 85 | wbce = ((weit * wbce).sum(dim=(2, 3)) + 1e-8) / (weit.sum(dim=(2, 3)) + 1e-8) 86 | 87 | # weighted e loss function 88 | pred = torch.sigmoid(pred) 89 | mpred = pred.mean(dim=(2, 3)).view(pred.shape[0], pred.shape[1], 1, 1).repeat(1, 1, pred.shape[2], pred.shape[3]) 90 | phiFM = pred - mpred 91 | 92 | mmask = mask.mean(dim=(2, 3)).view(mask.shape[0], mask.shape[1], 1, 1).repeat(1, 1, mask.shape[2], mask.shape[3]) 93 | phiGT = mask 
- mmask 94 | 95 | EFM = (2.0 * phiFM * phiGT + 1e-8) / (phiFM * phiFM + phiGT * phiGT + 1e-8) 96 | QFM = (1 + EFM) * (1 + EFM) / 4.0 97 | eloss = 1.0 - QFM.mean(dim=(2, 3)) 98 | 99 | # weighted iou loss function 100 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 101 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 102 | wiou = 1.0 - (inter + 1 + 1e-8) / (union - inter + 1 + 1e-8) 103 | 104 | return (wbce + eloss + wiou).mean() 105 | 106 | def gkern(kernlen=16, nsig=3): 107 | interval = (2*nsig+1.)/kernlen 108 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1) 109 | kern1d = np.diff(st.norm.cdf(x)) 110 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) 111 | kernel = kernel_raw/kernel_raw.sum() 112 | return kernel 113 | 114 | def min_max_norm(in_): 115 | max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_) 116 | min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_) 117 | in_ = in_ - min_ 118 | return in_.div(max_-min_+1e-8) 119 | 120 | class HA(nn.Module): 121 | # holistic attention module 122 | def __init__(self): 123 | super(HA, self).__init__() 124 | gaussian_kernel = np.float32(gkern(31, 4)) 125 | gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...] 126 | self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel)) 127 | 128 | def forward(self, attention, x): 129 | soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15) 130 | soft_attention = min_max_norm(soft_attention) 131 | x = torch.mul(x, soft_attention.max(attention)) 132 | return x 133 | 134 | 135 | def clip_gradient(optimizer, grad_clip): 136 | for group in optimizer.param_groups: 137 | for param in group['params']: 138 | if param.grad is not None: 139 | param.grad.data.clamp_(-grad_clip, grad_clip) 140 | 141 | 142 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 143 | decay = decay_rate ** (epoch // decay_epoch) 144 | for param_group in optimizer.param_groups: 145 | param_group['lr'] = decay * init_lr 146 | return 147 | 148 | def print_network(model, name): 149 | num_params = 0 150 | for p in model.parameters(): 151 | num_params += p.numel() 152 | print(model) 153 | print(name) 154 | print("The number of parameters: {}".format(num_params)) 155 | 156 | 157 | def calculate_cdf(histogram): 158 | """ 159 | This method calculates the cumulative distribution function 160 | :param array histogram: The values of the histogram 161 | :return: normalized_cdf: The normalized cumulative distribution function 162 | :rtype: array 163 | """ 164 | # Get the cumulative sum of the elements 165 | cdf = histogram.cumsum() 166 | 167 | # Normalize the cdf 168 | normalized_cdf = cdf / float(cdf.max()) 169 | 170 | return normalized_cdf 171 | 172 | def calculate_lookup(src_cdf, ref_cdf): 173 | """ 174 | This method creates the lookup table 175 | :param array src_cdf: The cdf for the source image 176 | :param array ref_cdf: The cdf for the reference image 177 | :return: lookup_table: The lookup table 178 | :rtype: array 179 | """ 180 | lookup_table = np.zeros(256) 181 | lookup_val = 0 182 | for src_pixel_val in range(len(src_cdf)): 183 | lookup_val 184 | for ref_pixel_val in range(len(ref_cdf)): 185 | if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: 186 | lookup_val = ref_pixel_val 187 | break 188 | lookup_table[src_pixel_val] = lookup_val 189 | return lookup_table 190 | 191 | def match_histograms(src_image, ref_image): 192 | """ 193 | This method matches the source image histogram to the 194 | reference signal 195 | :param image src_image: 
The original source image 196 | :param image ref_image: The reference image 197 | :return: image_after_matching 198 | :rtype: image (array) 199 | """ 200 | # Split the images into the different color channels 201 | # b means blue, g means green and r means red 202 | src_b, src_g, src_r = cv2.split(src_image) 203 | ref_b, ref_g, ref_r = cv2.split(ref_image) 204 | 205 | # Compute the b, g, and r histograms separately 206 | # The flatten() Numpy method returns a copy of the array c 207 | # collapsed into one dimension. 208 | src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) 209 | src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) 210 | src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) 211 | ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) 212 | ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) 213 | ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) 214 | 215 | # Compute the normalized cdf for the source and reference image 216 | src_cdf_blue = calculate_cdf(src_hist_blue) 217 | src_cdf_green = calculate_cdf(src_hist_green) 218 | src_cdf_red = calculate_cdf(src_hist_red) 219 | ref_cdf_blue = calculate_cdf(ref_hist_blue) 220 | ref_cdf_green = calculate_cdf(ref_hist_green) 221 | ref_cdf_red = calculate_cdf(ref_hist_red) 222 | 223 | # Make a separate lookup table for each color 224 | blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) 225 | green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) 226 | red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) 227 | 228 | # Use the lookup function to transform the colors of the original 229 | # source image 230 | blue_after_transform = cv2.LUT(src_b, blue_lookup_table) 231 | green_after_transform = cv2.LUT(src_g, green_lookup_table) 232 | red_after_transform = cv2.LUT(src_r, red_lookup_table) 233 | 234 | # Put the image back together 235 | image_after_matching = cv2.merge([ 236 | blue_after_transform, green_after_transform, red_after_transform]) 237 | image_after_matching = cv2.convertScaleAbs(image_after_matching) 238 | 239 | return image_after_matching 240 | 241 | def histogram(): 242 | ref_pth = os.getcwd() + '/Ref/0013.png' 243 | ori_pth = os.getcwd() + '/ori/' 244 | fin_pth = os.getcwd() + '/prep/' 245 | ori_list = os.listdir(ori_pth) 246 | count = 0 247 | for item in ori_list: 248 | ori_img = cv2.imread(os.path.join(ori_pth, item)) 249 | ref_img = cv2.imread(ref_pth) 250 | fin_img = match_histograms(ori_img, ref_img) 251 | cv2.imwrite(os.path.join(fin_pth, item), fin_img) 252 | count += 1 253 | print(count) 254 | 255 | 256 | if __name__ == '__main__': 257 | histogram() 258 | -------------------------------------------------------------------------------- /src/models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.VIT import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "vitl16_384": 25 | pretrained = _make_pretrained_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 
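# --- Editorial note (hedged): _make_encoder only dispatches on the `backbone` string and
# returns a timm vision transformer wrapped with forward hooks at the chosen blocks
# (see VIT.py); the `features` argument is accepted but not consumed here, since the
# decoder-side "scratch" layers are built separately via _make_scratch. A minimal usage
# sketch, assuming the hybrid "vitb_rn50_384" backbone that the refinement channel sizes
# in CAVNet.py suggest:
#     pretrained = _make_encoder("vitb_rn50_384", 256, use_pretrained=True,
#                                hooks=[0, 1, 8, 11], use_readout="ignore")
#     layer_1, layer_2, layer_3, layer_4 = forward_vit(pretrained, frame)  # frame: (B, 3, H, W)
# The hook indices and readout mode above are the defaults used further down in this file,
# not a statement about the authors' training configuration.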
31 | 32 | elif backbone == "vitb_rn50_384": 33 | pretrained = _make_pretrained_vitb_rn50_384( 34 | use_pretrained, 35 | hooks=hooks, 36 | use_vit_only=use_vit_only, 37 | use_readout=use_readout, 38 | enable_attention_hooks=enable_attention_hooks, 39 | ) 40 | 41 | elif backbone == "vitb16_384": 42 | pretrained = _make_pretrained_vitb16_384( 43 | use_pretrained, 44 | hooks=hooks, 45 | use_readout=use_readout, 46 | enable_attention_hooks=enable_attention_hooks, 47 | ) 48 | 49 | elif backbone == "resnext101_wsl": 50 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 51 | 52 | else: 53 | print(f"Backbone '{backbone}' not implemented") 54 | assert False 55 | 56 | return pretrained 57 | 58 | 59 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 60 | scratch = nn.Module() 61 | 62 | out_shape1 = out_shape 63 | out_shape2 = out_shape 64 | out_shape3 = out_shape 65 | out_shape4 = out_shape 66 | if expand == True: 67 | out_shape1 = out_shape 68 | out_shape2 = out_shape * 2 69 | out_shape3 = out_shape * 4 70 | out_shape4 = out_shape * 8 71 | 72 | scratch.layer1_rn = nn.Conv2d( 73 | in_shape[0], 74 | out_shape1, 75 | kernel_size=3, 76 | stride=1, 77 | padding=1, 78 | bias=False, 79 | groups=groups, 80 | ) 81 | scratch.layer2_rn = nn.Conv2d( 82 | in_shape[1], 83 | out_shape2, 84 | kernel_size=3, 85 | stride=1, 86 | padding=1, 87 | bias=False, 88 | groups=groups, 89 | ) 90 | scratch.layer3_rn = nn.Conv2d( 91 | in_shape[2], 92 | out_shape3, 93 | kernel_size=3, 94 | stride=1, 95 | padding=1, 96 | bias=False, 97 | groups=groups, 98 | ) 99 | scratch.layer4_rn = nn.Conv2d( 100 | in_shape[3], 101 | out_shape4, 102 | kernel_size=3, 103 | stride=1, 104 | padding=1, 105 | bias=False, 106 | groups=groups, 107 | ) 108 | 109 | return scratch 110 | 111 | 112 | def _make_resnet_backbone(resnet): 113 | pretrained = nn.Module() 114 | pretrained.layer1 = nn.Sequential( 115 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 116 | ) 117 | 118 | pretrained.layer2 = resnet.layer2 119 | pretrained.layer3 = resnet.layer3 120 | pretrained.layer4 = resnet.layer4 121 | 122 | return pretrained 123 | 124 | 125 | def _make_pretrained_resnext101_wsl(use_pretrained): 126 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 127 | return _make_resnet_backbone(resnet) 128 | 129 | 130 | class Interpolate(nn.Module): 131 | """Interpolation module.""" 132 | 133 | def __init__(self, scale_factor, mode, align_corners=False): 134 | """Init. 135 | 136 | Args: 137 | scale_factor (float): scaling 138 | mode (str): interpolation mode 139 | """ 140 | super(Interpolate, self).__init__() 141 | 142 | self.interp = nn.functional.interpolate 143 | self.scale_factor = scale_factor 144 | self.mode = mode 145 | self.align_corners = align_corners 146 | 147 | def forward(self, x): 148 | """Forward pass. 149 | 150 | Args: 151 | x (tensor): input 152 | 153 | Returns: 154 | tensor: interpolated data 155 | """ 156 | 157 | x = self.interp( 158 | x, 159 | scale_factor=self.scale_factor, 160 | mode=self.mode, 161 | align_corners=self.align_corners, 162 | ) 163 | 164 | return x 165 | 166 | 167 | class ResidualConvUnit(nn.Module): 168 | """Residual convolution module.""" 169 | 170 | def __init__(self, features): 171 | """Init. 
172 | 173 | Args: 174 | features (int): number of features 175 | """ 176 | super().__init__() 177 | 178 | self.conv1 = nn.Conv2d( 179 | features, features, kernel_size=3, stride=1, padding=1, bias=True 180 | ) 181 | 182 | self.conv2 = nn.Conv2d( 183 | features, features, kernel_size=3, stride=1, padding=1, bias=True 184 | ) 185 | 186 | self.relu = nn.ReLU(inplace=True) 187 | 188 | def forward(self, x): 189 | """Forward pass. 190 | 191 | Args: 192 | x (tensor): input 193 | 194 | Returns: 195 | tensor: output 196 | """ 197 | out = self.relu(x) 198 | out = self.conv1(out) 199 | out = self.relu(out) 200 | out = self.conv2(out) 201 | 202 | return out + x 203 | 204 | 205 | class FeatureFusionBlock(nn.Module): 206 | """Feature fusion block.""" 207 | 208 | def __init__(self, features): 209 | """Init. 210 | 211 | Args: 212 | features (int): number of features 213 | """ 214 | super(FeatureFusionBlock, self).__init__() 215 | 216 | self.resConfUnit1 = ResidualConvUnit(features) 217 | self.resConfUnit2 = ResidualConvUnit(features) 218 | 219 | def forward(self, *xs): 220 | """Forward pass. 221 | 222 | Returns: 223 | tensor: output 224 | """ 225 | output = xs[0] 226 | 227 | if len(xs) == 2: 228 | output += self.resConfUnit1(xs[1]) 229 | 230 | output = self.resConfUnit2(output) 231 | 232 | output = nn.functional.interpolate( 233 | output, scale_factor=2, mode="bilinear", align_corners=True 234 | ) 235 | 236 | return output 237 | 238 | 239 | class ResidualConvUnit_custom(nn.Module): 240 | """Residual convolution module.""" 241 | 242 | def __init__(self, features, activation, bn): 243 | """Init. 244 | 245 | Args: 246 | features (int): number of features 247 | """ 248 | super().__init__() 249 | 250 | self.bn = bn 251 | 252 | self.groups = 1 253 | 254 | self.conv1 = nn.Conv2d( 255 | features, 256 | features, 257 | kernel_size=3, 258 | stride=1, 259 | padding=1, 260 | bias=not self.bn, 261 | groups=self.groups, 262 | ) 263 | 264 | self.conv2 = nn.Conv2d( 265 | features, 266 | features, 267 | kernel_size=3, 268 | stride=1, 269 | padding=1, 270 | bias=not self.bn, 271 | groups=self.groups, 272 | ) 273 | 274 | if self.bn == True: 275 | self.bn1 = nn.BatchNorm2d(features) 276 | self.bn2 = nn.BatchNorm2d(features) 277 | 278 | self.activation = activation 279 | 280 | self.skip_add = nn.quantized.FloatFunctional() 281 | 282 | def forward(self, x): 283 | """Forward pass. 284 | 285 | Args: 286 | x (tensor): input 287 | 288 | Returns: 289 | tensor: output 290 | """ 291 | 292 | out = self.activation(x) 293 | out = self.conv1(out) 294 | if self.bn == True: 295 | out = self.bn1(out) 296 | 297 | out = self.activation(out) 298 | out = self.conv2(out) 299 | if self.bn == True: 300 | out = self.bn2(out) 301 | 302 | if self.groups > 1: 303 | out = self.conv_merge(out) 304 | 305 | return self.skip_add.add(out, x) 306 | 307 | # return out + x 308 | 309 | 310 | class FeatureFusionBlock_custom(nn.Module): 311 | """Feature fusion block.""" 312 | 313 | def __init__( 314 | self, 315 | features, 316 | activation, 317 | deconv=False, 318 | bn=False, 319 | expand=False, 320 | align_corners=True, 321 | ): 322 | """Init. 
323 | 324 | Args: 325 | features (int): number of features 326 | """ 327 | super(FeatureFusionBlock_custom, self).__init__() 328 | 329 | self.deconv = deconv 330 | self.align_corners = align_corners 331 | 332 | self.groups = 1 333 | 334 | self.expand = expand 335 | out_features = features 336 | if self.expand == True: 337 | out_features = features // 2 338 | 339 | self.out_conv = nn.Conv2d( 340 | features, 341 | out_features, 342 | kernel_size=1, 343 | stride=1, 344 | padding=0, 345 | bias=True, 346 | groups=1, 347 | ) 348 | 349 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 350 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 351 | 352 | self.skip_add = nn.quantized.FloatFunctional() 353 | 354 | def forward(self, *xs): 355 | """Forward pass. 356 | 357 | Returns: 358 | tensor: output 359 | """ 360 | output = xs[0] 361 | 362 | if len(xs) == 2: 363 | res = self.resConfUnit1(xs[1]) 364 | output = self.skip_add.add(output, res) 365 | # output += res 366 | 367 | output = self.resConfUnit2(output) 368 | 369 | output = nn.functional.interpolate( 370 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 371 | ) 372 | 373 | output = self.out_conv(output) 374 | 375 | return output 376 | -------------------------------------------------------------------------------- /src/models/CAVNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | 4 | from models.rcrnet_vit import RCRNet_vit, _ConvBatchNormReLU, _RefinementModule 5 | from models.blocks import forward_vit 6 | from models.convgru import ConvGRUCell 7 | from models.non_local_dot_product import NONLocalBlock3D 8 | 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import torch 12 | import torch.nn as nn 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | import torch.nn.functional as F 15 | import numpy as np 16 | from torch.distributions import Normal, Independent, kl 17 | 18 | 19 | class InferenceModel_mm_x(nn.Module): 20 | def __init__(self, input_channels, channels, latent_size): 21 | super(InferenceModel_mm_x, self).__init__() 22 | self.contracting_path = nn.ModuleList() 23 | self.input_channels = input_channels 24 | self.relu = nn.ReLU(inplace=True) 25 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 26 | self.bn1 = nn.BatchNorm2d(channels) 27 | self.layer2 = nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2, padding=1) 28 | self.bn2 = nn.BatchNorm2d(channels * 2) 29 | self.layer3 = nn.Conv2d(2 * channels, 4 * channels, kernel_size=4, stride=2, padding=1) 30 | self.bn3 = nn.BatchNorm2d(channels * 4) 31 | self.layer4 = nn.Conv2d(4 * channels, 8 * channels, kernel_size=4, stride=2, padding=1) 32 | self.bn4 = nn.BatchNorm2d(channels * 8) 33 | self.layer5 = nn.Conv2d(8 * channels, 8 * channels, kernel_size=4, stride=2, padding=1) 34 | self.bn5 = nn.BatchNorm2d(channels * 8) 35 | self.channel = channels 36 | 37 | self.fc1 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size 38 | self.fc2 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size 39 | 40 | self.a_fc1 = nn.Linear(1024, latent_size) 41 | self.a_fc2 = nn.Linear(1024, latent_size) 42 | 43 | self.av_fc1 = nn.Bilinear(latent_size, latent_size, latent_size) 44 | self.av_fc2 = nn.Bilinear(latent_size, latent_size, latent_size) 45 | 46 | self.leakyrelu = nn.LeakyReLU() 47 | 48 | def forward(self, 
input, aux_input): 49 | output = self.leakyrelu(self.bn1(self.layer1(input))) 50 | # print(output.size()) 51 | output = self.leakyrelu(self.bn2(self.layer2(output))) 52 | # print(output.size()) 53 | output = self.leakyrelu(self.bn3(self.layer3(output))) 54 | # print(output.size()) 55 | output = self.leakyrelu(self.bn4(self.layer4(output))) 56 | # print(output.size()) 57 | output = self.leakyrelu(self.bn5(self.layer5(output))) 58 | output = output.view(-1, self.channel * 8 * 13 * 26) # adjust according to input size 59 | # print(output.size()) 60 | # output = self.tanh(output) 61 | 62 | # audio visual fusion 63 | mu = self.fc1(output) 64 | a_mu = self.a_fc1(aux_input) 65 | av_mu = self.av_fc1(mu, a_mu) 66 | 67 | logvar = self.fc2(output) 68 | a_logvar = self.a_fc2(aux_input) 69 | av_logvar = self.av_fc2(logvar, a_logvar) 70 | dist = Independent(Normal(loc=av_mu, scale=torch.exp(av_logvar)), 1) 71 | 72 | return av_mu, av_logvar, dist 73 | 74 | 75 | class InferenceModel_mm_xy(nn.Module): 76 | def __init__(self, input_channels, channels, latent_size): 77 | super(InferenceModel_mm_xy, self).__init__() 78 | self.contracting_path = nn.ModuleList() 79 | self.input_channels = input_channels 80 | self.relu = nn.ReLU(inplace=True) 81 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 82 | self.bn1 = nn.BatchNorm2d(channels) 83 | self.layer2 = nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2, padding=1) 84 | self.bn2 = nn.BatchNorm2d(channels * 2) 85 | self.layer3 = nn.Conv2d(2 * channels, 4 * channels, kernel_size=4, stride=2, padding=1) 86 | self.bn3 = nn.BatchNorm2d(channels * 4) 87 | self.layer4 = nn.Conv2d(4 * channels, 8 * channels, kernel_size=4, stride=2, padding=1) 88 | self.bn4 = nn.BatchNorm2d(channels * 8) 89 | self.layer5 = nn.Conv2d(8 * channels, 8 * channels, kernel_size=4, stride=2, padding=1) 90 | self.bn5 = nn.BatchNorm2d(channels * 8) 91 | self.channel = channels 92 | 93 | self.fc1 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size 94 | self.fc2 = nn.Linear(channels * 8 * 13 * 26, latent_size) # adjust according to input size 95 | 96 | self.a_fc1 = nn.Linear(1024, latent_size) 97 | self.a_fc2 = nn.Linear(1024, latent_size) 98 | 99 | self.av_fc1 = nn.Bilinear(latent_size, latent_size, latent_size) 100 | self.av_fc2 = nn.Bilinear(latent_size, latent_size, latent_size) 101 | 102 | self.leakyrelu = nn.LeakyReLU() 103 | 104 | def forward(self, input, aux_input): 105 | output = self.leakyrelu(self.bn1(self.layer1(input))) 106 | # print(output.size()) 107 | output = self.leakyrelu(self.bn2(self.layer2(output))) 108 | # print(output.size()) 109 | output = self.leakyrelu(self.bn3(self.layer3(output))) 110 | # print(output.size()) 111 | output = self.leakyrelu(self.bn4(self.layer4(output))) 112 | # print(output.size()) 113 | output = self.leakyrelu(self.bn5(self.layer5(output))) 114 | output = output.view(-1, self.channel * 8 * 13 * 26) # adjust according to input size 115 | # print(output.size()) 116 | # output = self.tanh(output) 117 | 118 | # audio visual fusion 119 | mu = self.fc1(output) 120 | a_mu = self.a_fc1(aux_input) 121 | av_mu = self.av_fc1(mu, a_mu) 122 | 123 | logvar = self.fc2(output) 124 | a_logvar = self.a_fc2(aux_input) 125 | av_logvar = self.av_fc2(logvar, a_logvar) 126 | dist = Independent(Normal(loc=av_mu, scale=torch.exp(av_logvar)), 1) 127 | 128 | return av_mu, av_logvar, dist 129 | 130 | 131 | class cavnet(nn.Module): 132 | def __init__(self, output_stride=16, lat_channel=16, lat_dim=32): 133 
| super(cavnet, self).__init__() 134 | # for visual backbone 135 | # encoder 136 | self.backbone_enc_vit = RCRNet_vit(n_classes=1, output_stride=output_stride).pretrained 137 | self.backbone_enc_aspp = RCRNet_vit(n_classes=1, output_stride=output_stride).aspp 138 | 139 | # decoder 140 | self.backbone_dec_prior = RCRNet_decoder() 141 | self.backbone_dec_post = RCRNet_decoder() 142 | 143 | # for temporal module 144 | self.convgru_forward = ConvGRUCell(256, 256, 3) 145 | self.convgru_backward = ConvGRUCell(256, 256, 3) 146 | self.bidirection_conv = nn.Conv2d(512, 256, 3, 1, 1) 147 | self.non_local_block = NONLocalBlock3D(256, sub_sample=False, bn_layer=False) 148 | self.non_local_block2 = NONLocalBlock3D(256, sub_sample=False, bn_layer=False) 149 | 150 | # for CVAE 151 | self.enc_mm_x = InferenceModel_mm_x(3, lat_channel, lat_dim) 152 | self.enc_mm_xy = InferenceModel_mm_xy(4, lat_channel, lat_dim) 153 | self.spatial_axes = [2, 3] 154 | self.noise_conv_prior = nn.Conv2d(256 + lat_dim, 256, kernel_size=1, padding=0) 155 | self.noise_conv_post = nn.Conv2d(256 + lat_dim, 256, kernel_size=1, padding=0) 156 | 157 | # for audio 158 | # enocder 159 | self.soundnet = nn.Sequential( # 7 layers used 160 | nn.Conv2d(1, 16, (1, 64), (1, 2), (0, 32)), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d((1, 8), (1, 8)), 161 | nn.Conv2d(16, 32, (1, 32), (1, 2), (0, 16)), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d((1, 8), (1, 8)), 162 | nn.Conv2d(32, 64, (1, 16), (1, 2), (0, 8)), nn.BatchNorm2d(64), nn.ReLU(), 163 | nn.Conv2d(64, 128, (1, 8), (1, 2), (0, 4)), nn.BatchNorm2d(128), nn.ReLU(), 164 | nn.Conv2d(128, 256, (1, 4), (1, 2), (0, 2)), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d((1, 4), (1, 4)), 165 | nn.Conv2d(256, 512, (1, 4), (1, 2), (0, 2)), nn.BatchNorm2d(512), nn.ReLU(), 166 | nn.Conv2d(512, 1024, (1, 4), (1, 2), (0, 2)), nn.BatchNorm2d(1024), nn.ReLU(), nn.MaxPool2d((1, 2)) 167 | ) 168 | 169 | # load pre-training model 170 | if self.training: 171 | self.initialize_pretrain() # load visual pretrain (static visual part pre-trained on DUTS-tr) 172 | self.initialize_soundnet() # load audio pretrain (soundnet) 173 | 174 | # may freeze auditory branch 175 | for param in self.soundnet.parameters(): param.requires_grad = True 176 | 177 | def initialize_pretrain(self): 178 | backbone_pretrain = torch.load(os.getcwd() + '/pretrain/static_visual_pretrain.pth') 179 | 180 | all_params_enc_vit = {} 181 | for k, v in self.backbone_enc_vit.state_dict().items(): 182 | if 'pretrained.' + k in backbone_pretrain.keys(): 183 | v = backbone_pretrain['pretrained.' + k] 184 | all_params_enc_vit[k] = v 185 | self.backbone_enc_vit.load_state_dict(all_params_enc_vit) 186 | 187 | all_params_enc_aspp = {} 188 | for k, v in self.backbone_enc_aspp.state_dict().items(): 189 | if 'aspp.' + k in backbone_pretrain.keys(): 190 | v = backbone_pretrain['aspp.' + k] 191 | all_params_enc_aspp[k] = v 192 | self.backbone_enc_aspp.load_state_dict(all_params_enc_aspp) 193 | 194 | def initialize_soundnet(self): 195 | audio_pretrain_weights = torch.load(os.getcwd() + '/pretrain/soundnet8.pth') 196 | 197 | all_params = {} 198 | for k, v in self.soundnet.state_dict().items(): 199 | if 'module.soundnet8.' + k in audio_pretrain_weights.keys(): 200 | v = audio_pretrain_weights['module.soundnet8.' 
+ k] 201 | all_params[k] = v 202 | self.soundnet.load_state_dict(all_params) 203 | 204 | def reparametrize(self, mu, logvar): 205 | std = logvar.mul(0.5).exp_() 206 | eps = torch.cuda.FloatTensor(std.size()).normal_() 207 | eps = Variable(eps) 208 | 209 | return eps.mul(std).add_(mu) 210 | 211 | def kl_divergence(self, posterior_latent_space, prior_latent_space): 212 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space) 213 | 214 | return kl_div 215 | 216 | def tile(self, a, dim, n_tile): 217 | init_dim = a.size(dim) 218 | repeat_idx = [1] * a.dim() 219 | repeat_idx[dim] = n_tile 220 | a = a.repeat(*(repeat_idx)) 221 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to( 222 | device) 223 | return torch.index_select(a, dim, order_index) 224 | 225 | def forward(self, seq, audio_clip, gt=None): 226 | # -------------------------------------------------------------------------------------------------------------- 227 | # encoder for mono sound 228 | feats_audio = self.soundnet(audio_clip.unsqueeze(0)) 229 | feats_audio = torch.mean(feats_audio, dim=-1).squeeze().unsqueeze(0) 230 | 231 | # -------------------------------------------------------------------------------------------------------------- 232 | # encoder for audio-visual-temporal 233 | feats_seq = [forward_vit(self.backbone_enc_vit, frame) for frame in seq] 234 | feats_time = [self.backbone_enc_aspp(feats[-1]) for feats in feats_seq] # bottleneck features 235 | feats_time = torch.stack(feats_time, dim=2) 236 | feats_time = self.non_local_block(feats_time) 237 | 238 | # Deep Bidirectional ConvGRU 239 | frame = seq[0] 240 | feat = feats_time[:, :, 0, :, :] 241 | feats_forward = [] 242 | # forward 243 | for i in range(len(seq)): 244 | feat = self.convgru_forward(feats_time[:, :, i, :, :], feat) 245 | feats_forward.append(feat) 246 | # backward 247 | feat = feats_forward[-1] 248 | feats_backward = [] 249 | for i in range(len(seq)): 250 | feat = self.convgru_backward(feats_forward[len(seq)-1-i], feat) 251 | feats_backward.append(feat) 252 | 253 | feats_backward = feats_backward[::-1] 254 | feats = [] 255 | for i in range(len(seq)): 256 | feat = torch.tanh(self.bidirection_conv(torch.cat((feats_forward[i], feats_backward[i]), dim=1))) 257 | feats.append(feat) 258 | feats = torch.stack(feats, dim=2) 259 | 260 | feats = self.non_local_block2(feats) # spatial-temporal bottleneck features 261 | 262 | # -------------------------------------------------------------------------------------------------------------- 263 | if gt == None: 264 | # model inference 265 | feats_prior = [] 266 | for rr in range(len(seq)): 267 | mu_prior, logvar_prior, _ = self.enc_mm_x(seq[rr], feats_audio) 268 | z_prior = self.reparametrize(mu_prior, logvar_prior) # instantiate latent variable 269 | z_prior = torch.unsqueeze(z_prior, 2) 270 | z_prior = self.tile(z_prior, 2, feats[:, :, rr, :, :].shape[self.spatial_axes[0]]) 271 | z_prior = torch.unsqueeze(z_prior, 3) 272 | z_prior = self.tile(z_prior, 3, feats[:, :, rr, :, :].shape[self.spatial_axes[1]]) 273 | f_prior = torch.cat((feats[:, :, rr, :, :], z_prior), 1) 274 | feats_prior.append(self.noise_conv_prior(f_prior)) 275 | 276 | preds_prior = [] 277 | for i, frame in enumerate(seq): 278 | seg_prior = self.backbone_dec_prior(feats_seq[i][0], feats_seq[i][1], feats_seq[i][2], feats_prior[i], 279 | frame) 280 | preds_prior.append(seg_prior) 281 | 282 | return preds_prior 283 | else: 284 | # CVAE 285 | KLD, feats_prior, feats_post = [], [], [] 
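# --- Editorial note (hedged sketch): each frame in the clip contributes one Monte-Carlo
# sample of the latent code. The prior network (enc_mm_x) sees only the RGB frame plus the
# audio embedding, while the posterior network (enc_mm_xy) additionally sees the
# ground-truth map; both return (mu, logvar) defining a diagonal Gaussian. Sampling uses
# the standard reparameterization trick implemented by self.reparametrize above, i.e. roughly
#     std = torch.exp(0.5 * logvar)
#     z   = mu + std * torch.randn_like(std)
# The sampled z is then tiled to the spatial size of the spatio-temporal bottleneck
# features, concatenated along the channel axis, and fused by a 1x1 conv
# (noise_conv_prior / noise_conv_post). The per-frame KL(posterior || prior) values
# collected in KLD are presumably the latent loss that the training script weights with
# --lat_weight (see options.py).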
286 | for rr in range(len(seq)): 287 | mu_prior, logvar_prior, dist_prior = self.enc_mm_x(seq[rr], feats_audio) 288 | z_prior = self.reparametrize(mu_prior, logvar_prior) # instantiate latent variable 289 | z_prior = torch.unsqueeze(z_prior, 2) 290 | z_prior = self.tile(z_prior, 2, feats[:, :, rr, :, :].shape[self.spatial_axes[0]]) 291 | z_prior = torch.unsqueeze(z_prior, 3) 292 | z_prior = self.tile(z_prior, 3, feats[:, :, rr, :, :].shape[self.spatial_axes[1]]) 293 | f_prior = torch.cat((feats[:, :, rr, :, :], z_prior), 1) 294 | feats_prior.append(self.noise_conv_prior(f_prior)) 295 | 296 | mu_post, logvar_post, dist_post = self.enc_mm_xy(torch.cat((seq[rr], gt[rr]), 1), feats_audio) 297 | z_post = self.reparametrize(mu_post, logvar_post) # instantiate latent variable 298 | z_post = torch.unsqueeze(z_post, 2) 299 | z_post = self.tile(z_post, 2, feats[:, :, rr, :, :].shape[self.spatial_axes[0]]) 300 | z_post = torch.unsqueeze(z_post, 3) 301 | z_post = self.tile(z_post, 3, feats[:, :, rr, :, :].shape[self.spatial_axes[1]]) 302 | f_post = torch.cat((feats[:, :, rr, :, :], z_post), 1) 303 | feats_post.append(self.noise_conv_post(f_post)) 304 | 305 | KLD.append(torch.mean(self.kl_divergence(dist_post, dist_prior))) 306 | 307 | preds_prior, preds_post = [], [] 308 | for i, frame in enumerate(seq): 309 | seg_prior = self.backbone_dec_prior(feats_seq[i][0], feats_seq[i][1], feats_seq[i][2], feats_prior[i], 310 | frame) 311 | preds_prior.append(seg_prior) 312 | seg_post = self.backbone_dec_post(feats_seq[i][0], feats_seq[i][1], feats_seq[i][2], feats_post[i], 313 | frame) 314 | preds_post.append(seg_post) 315 | 316 | return preds_prior, preds_post, KLD 317 | 318 | 319 | class RCRNet_decoder(nn.Module): 320 | def __init__(self, n_classes=1): 321 | super(RCRNet_decoder, self).__init__() 322 | self.decoder = nn.Sequential( 323 | OrderedDict( 324 | [ 325 | ("conv1", _ConvBatchNormReLU(128, 256, 3, 1, 1, 1)), 326 | ("conv2", nn.Conv2d(256, n_classes, kernel_size=1)), 327 | ] 328 | ) 329 | ) 330 | self.add_module("refinement1", _RefinementModule(768, 96, 256, 128, 2)) 331 | self.add_module("refinement2", _RefinementModule(512, 96, 128, 128, 2)) 332 | self.add_module("refinement3", _RefinementModule(256, 96, 128, 128, 2)) 333 | 334 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 335 | 336 | if self.training: self.init_pretrain() 337 | 338 | def init_pretrain(self): 339 | backbone_pretrain = torch.load(os.getcwd() + '/pretrain/static_visual_pretrain.pth') 340 | 341 | all_params_dec = {} 342 | for k, v in self.decoder.state_dict().items(): 343 | if 'decoder.' + k in backbone_pretrain.keys(): 344 | v = backbone_pretrain['decoder.' + k] 345 | all_params_dec[k] = v 346 | self.decoder.load_state_dict(all_params_dec) 347 | 348 | all_params_ref1 = {} 349 | for k, v in self.refinement1.state_dict().items(): 350 | if 'refinement1.' + k in backbone_pretrain.keys(): 351 | v = backbone_pretrain['refinement1.' + k] 352 | all_params_ref1[k] = v 353 | self.refinement1.load_state_dict(all_params_ref1) 354 | 355 | all_params_ref2 = {} 356 | for k, v in self.refinement2.state_dict().items(): 357 | if 'refinement2.' + k in backbone_pretrain.keys(): 358 | v = backbone_pretrain['refinement2.' + k] 359 | all_params_ref2[k] = v 360 | self.refinement2.load_state_dict(all_params_ref2) 361 | 362 | all_params_ref3 = {} 363 | for k, v in self.refinement3.state_dict().items(): 364 | if 'refinement3.' + k in backbone_pretrain.keys(): 365 | v = backbone_pretrain['refinement3.' 
+ k] 366 | all_params_ref3[k] = v 367 | self.refinement3.load_state_dict(all_params_ref3) 368 | 369 | def seg_conv(self, block1, block2, block3, block4, shape): 370 | ''' 371 | Pixel-wise classifer 372 | ''' 373 | block4 = self.upsample2(block4) 374 | 375 | bu1 = self.refinement1(block3, block4) 376 | bu1 = F.interpolate(bu1, size=block2.shape[2:], mode="bilinear", align_corners=False) 377 | bu2 = self.refinement2(block2, bu1) 378 | bu2 = F.interpolate(bu2, size=block1.shape[2:], mode="bilinear", align_corners=False) 379 | bu3 = self.refinement3(block1, bu2) 380 | bu3 = F.interpolate(bu3, size=shape, mode="bilinear", align_corners=False) 381 | seg = self.decoder(bu3) 382 | 383 | return seg 384 | 385 | def forward(self, block1, block2, block3, block4, x): 386 | seg = self.seg_conv(block1, block2, block3, block4, x.shape[2:]) 387 | 388 | return seg 389 | 390 | -------------------------------------------------------------------------------- /src/models/VIT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | activations = {} 10 | 11 | 12 | def get_activation(name): 13 | def hook(model, input, output): 14 | activations[name] = output 15 | 16 | return hook 17 | 18 | 19 | attention = {} 20 | 21 | 22 | def get_attention(name): 23 | def hook(module, input, output): 24 | x = input[0] 25 | B, N, C = x.shape 26 | qkv = ( 27 | module.qkv(x) 28 | .reshape(B, N, 3, module.num_heads, C // module.num_heads) 29 | .permute(2, 0, 3, 1, 4) 30 | ) 31 | q, k, v = ( 32 | qkv[0], 33 | qkv[1], 34 | qkv[2], 35 | ) # make torchscript happy (cannot use tensor as tuple) 36 | 37 | attn = (q @ k.transpose(-2, -1)) * module.scale 38 | 39 | attn = attn.softmax(dim=-1) # [:,:,1,1:] 40 | attention[name] = attn 41 | 42 | return hook 43 | 44 | 45 | def get_mean_attention_map(attn, token, shape): 46 | attn = attn[:, :, token, 1:] 47 | attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() 48 | attn = torch.nn.functional.interpolate( 49 | attn, size=shape[2:], mode="bicubic", align_corners=False 50 | ).squeeze(0) 51 | 52 | all_attn = torch.mean(attn, 0) 53 | 54 | return all_attn 55 | 56 | 57 | class Slice(nn.Module): 58 | def __init__(self, start_index=1): 59 | super(Slice, self).__init__() 60 | self.start_index = start_index 61 | 62 | def forward(self, x): 63 | return x[:, self.start_index :] 64 | 65 | 66 | class AddReadout(nn.Module): 67 | def __init__(self, start_index=1): 68 | super(AddReadout, self).__init__() 69 | self.start_index = start_index 70 | 71 | def forward(self, x): 72 | if self.start_index == 2: 73 | readout = (x[:, 0] + x[:, 1]) / 2 74 | else: 75 | readout = x[:, 0] 76 | return x[:, self.start_index :] + readout.unsqueeze(1) 77 | 78 | 79 | class ProjectReadout(nn.Module): 80 | def __init__(self, in_features, start_index=1): 81 | super(ProjectReadout, self).__init__() 82 | self.start_index = start_index 83 | 84 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 85 | 86 | def forward(self, x): 87 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 88 | features = torch.cat((x[:, self.start_index :], readout), -1) 89 | 90 | return self.project(features) 91 | 92 | 93 | class Transpose(nn.Module): 94 | def __init__(self, dim0, dim1): 95 | super(Transpose, self).__init__() 96 | self.dim0 = dim0 97 | self.dim1 = dim1 98 | 99 | def forward(self, x): 100 | x = 
x.transpose(self.dim0, self.dim1) 101 | return x 102 | 103 | 104 | def forward_vit(pretrained, x): 105 | b, c, h, w = x.shape 106 | 107 | glob = pretrained.model.forward_flex(x) 108 | 109 | layer_1 = pretrained.activations["1"] 110 | layer_2 = pretrained.activations["2"] 111 | layer_3 = pretrained.activations["3"] 112 | layer_4 = pretrained.activations["4"] 113 | 114 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 115 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 116 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 117 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 118 | 119 | unflatten = nn.Sequential( 120 | nn.Unflatten( 121 | 2, 122 | torch.Size( 123 | [ 124 | h // pretrained.model.patch_size[1], 125 | w // pretrained.model.patch_size[0], 126 | ] 127 | ), 128 | ) 129 | ) 130 | 131 | if layer_1.ndim == 3: 132 | layer_1 = unflatten(layer_1) 133 | if layer_2.ndim == 3: 134 | layer_2 = unflatten(layer_2) 135 | if layer_3.ndim == 3: 136 | layer_3 = unflatten(layer_3) 137 | if layer_4.ndim == 3: 138 | layer_4 = unflatten(layer_4) 139 | 140 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 141 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 142 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 143 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 144 | 145 | return layer_1, layer_2, layer_3, layer_4 146 | 147 | 148 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 149 | posemb_tok, posemb_grid = ( 150 | posemb[:, : self.start_index], 151 | posemb[0, self.start_index :], 152 | ) 153 | 154 | gs_old = int(math.sqrt(len(posemb_grid))) 155 | 156 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 157 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 158 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 159 | 160 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 161 | 162 | return posemb 163 | 164 | 165 | def forward_flex(self, x): 166 | b, c, h, w = x.shape 167 | 168 | pos_embed = self._resize_pos_embed( 169 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 170 | ) 171 | 172 | B = x.shape[0] 173 | 174 | if hasattr(self.patch_embed, "backbone"): 175 | x = self.patch_embed.backbone(x) 176 | if isinstance(x, (list, tuple)): 177 | x = x[-1] # last feature if backbone outputs list/tuple of features 178 | 179 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 180 | 181 | #if hasattr(self, "dist_token"): 182 | # cls_tokens = self.cls_token.expand( 183 | # B, -1, -1 184 | # ) # stole cls_tokens impl from Phil Wang, thanks 185 | # dist_token = self.dist_token.expand(B, -1, -1) 186 | # x = torch.cat((cls_tokens, dist_token, x), dim=1) 187 | #else: 188 | cls_tokens = self.cls_token.expand( 189 | B, -1, -1 190 | ) # stole cls_tokens impl from Phil Wang, thanks 191 | x = torch.cat((cls_tokens, x), dim=1) 192 | 193 | x = x + pos_embed 194 | x = self.pos_drop(x) 195 | 196 | for blk in self.blocks: 197 | x = blk(x) 198 | 199 | x = self.norm(x) 200 | 201 | return x 202 | 203 | 204 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 205 | if use_readout == "ignore": 206 | readout_oper = [Slice(start_index)] * len(features) 207 | elif use_readout == "add": 208 | readout_oper = [AddReadout(start_index)] * len(features) 209 | elif use_readout == "project": 210 | readout_oper = [ 211 | ProjectReadout(vit_features, 
start_index) for out_feat in features 212 | ] 213 | else: 214 | assert ( 215 | False 216 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 217 | 218 | return readout_oper 219 | 220 | 221 | def _make_vit_b16_backbone( 222 | model, 223 | features=[96, 192, 384, 768], 224 | size=[384, 384], 225 | hooks=[2, 5, 8, 11], 226 | vit_features=768, 227 | use_readout="ignore", 228 | start_index=1, 229 | enable_attention_hooks=False, 230 | ): 231 | pretrained = nn.Module() 232 | 233 | pretrained.model = model 234 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 235 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 236 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 237 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 238 | 239 | pretrained.activations = activations 240 | 241 | if enable_attention_hooks: 242 | pretrained.model.blocks[hooks[0]].attn.register_forward_hook( 243 | get_attention("attn_1") 244 | ) 245 | pretrained.model.blocks[hooks[1]].attn.register_forward_hook( 246 | get_attention("attn_2") 247 | ) 248 | pretrained.model.blocks[hooks[2]].attn.register_forward_hook( 249 | get_attention("attn_3") 250 | ) 251 | pretrained.model.blocks[hooks[3]].attn.register_forward_hook( 252 | get_attention("attn_4") 253 | ) 254 | pretrained.attention = attention 255 | 256 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 257 | 258 | # 32, 48, 136, 384 259 | pretrained.act_postprocess1 = nn.Sequential( 260 | readout_oper[0], 261 | Transpose(1, 2), 262 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 263 | nn.Conv2d( 264 | in_channels=vit_features, 265 | out_channels=features[0], 266 | kernel_size=1, 267 | stride=1, 268 | padding=0, 269 | ), 270 | nn.ConvTranspose2d( 271 | in_channels=features[0], 272 | out_channels=features[0], 273 | kernel_size=4, 274 | stride=4, 275 | padding=0, 276 | bias=True, 277 | dilation=1, 278 | groups=1, 279 | ), 280 | ) 281 | 282 | pretrained.act_postprocess2 = nn.Sequential( 283 | readout_oper[1], 284 | Transpose(1, 2), 285 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 286 | nn.Conv2d( 287 | in_channels=vit_features, 288 | out_channels=features[1], 289 | kernel_size=1, 290 | stride=1, 291 | padding=0, 292 | ), 293 | nn.ConvTranspose2d( 294 | in_channels=features[1], 295 | out_channels=features[1], 296 | kernel_size=2, 297 | stride=2, 298 | padding=0, 299 | bias=True, 300 | dilation=1, 301 | groups=1, 302 | ), 303 | ) 304 | 305 | pretrained.act_postprocess3 = nn.Sequential( 306 | readout_oper[2], 307 | Transpose(1, 2), 308 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 309 | nn.Conv2d( 310 | in_channels=vit_features, 311 | out_channels=features[2], 312 | kernel_size=1, 313 | stride=1, 314 | padding=0, 315 | ), 316 | ) 317 | 318 | pretrained.act_postprocess4 = nn.Sequential( 319 | readout_oper[3], 320 | Transpose(1, 2), 321 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 322 | nn.Conv2d( 323 | in_channels=vit_features, 324 | out_channels=features[3], 325 | kernel_size=1, 326 | stride=1, 327 | padding=0, 328 | ), 329 | nn.Conv2d( 330 | in_channels=features[3], 331 | out_channels=features[3], 332 | kernel_size=3, 333 | stride=2, 334 | padding=1, 335 | ), 336 | ) 337 | 338 | pretrained.model.start_index = start_index 339 | pretrained.model.patch_size = [16, 16] 340 | 341 | # We inject this function into the VisionTransformer 
instances so that 342 | # we can use it with interpolated position embeddings without modifying the library source. 343 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 344 | pretrained.model._resize_pos_embed = types.MethodType( 345 | _resize_pos_embed, pretrained.model 346 | ) 347 | 348 | return pretrained 349 | 350 | 351 | def _make_vit_b_rn50_backbone( 352 | model, 353 | features=[256, 512, 768, 768], 354 | size=[384, 384], 355 | hooks=[0, 1, 8, 11], 356 | vit_features=768, 357 | use_vit_only=False, 358 | use_readout="ignore", 359 | start_index=1, 360 | enable_attention_hooks=False, 361 | ): 362 | pretrained = nn.Module() 363 | 364 | pretrained.model = model 365 | 366 | if use_vit_only == True: 367 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 368 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 369 | else: 370 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 371 | get_activation("1") 372 | ) 373 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 374 | get_activation("2") 375 | ) 376 | 377 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 378 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 379 | 380 | if enable_attention_hooks: 381 | pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) 382 | pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) 383 | pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) 384 | pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) 385 | pretrained.attention = attention 386 | 387 | pretrained.activations = activations 388 | 389 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 390 | 391 | if use_vit_only == True: 392 | pretrained.act_postprocess1 = nn.Sequential( 393 | readout_oper[0], 394 | Transpose(1, 2), 395 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 396 | nn.Conv2d( 397 | in_channels=vit_features, 398 | out_channels=features[0], 399 | kernel_size=1, 400 | stride=1, 401 | padding=0, 402 | ), 403 | nn.ConvTranspose2d( 404 | in_channels=features[0], 405 | out_channels=features[0], 406 | kernel_size=4, 407 | stride=4, 408 | padding=0, 409 | bias=True, 410 | dilation=1, 411 | groups=1, 412 | ), 413 | ) 414 | 415 | pretrained.act_postprocess2 = nn.Sequential( 416 | readout_oper[1], 417 | Transpose(1, 2), 418 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 419 | nn.Conv2d( 420 | in_channels=vit_features, 421 | out_channels=features[1], 422 | kernel_size=1, 423 | stride=1, 424 | padding=0, 425 | ), 426 | nn.ConvTranspose2d( 427 | in_channels=features[1], 428 | out_channels=features[1], 429 | kernel_size=2, 430 | stride=2, 431 | padding=0, 432 | bias=True, 433 | dilation=1, 434 | groups=1, 435 | ), 436 | ) 437 | else: 438 | pretrained.act_postprocess1 = nn.Sequential( 439 | nn.Identity(), nn.Identity(), nn.Identity() 440 | ) 441 | pretrained.act_postprocess2 = nn.Sequential( 442 | nn.Identity(), nn.Identity(), nn.Identity() 443 | ) 444 | 445 | pretrained.act_postprocess3 = nn.Sequential( 446 | readout_oper[2], 447 | Transpose(1, 2), 448 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 449 | nn.Conv2d( 450 | in_channels=vit_features, 451 | out_channels=features[2], 452 | kernel_size=1, 453 | stride=1, 454 | padding=0, 455 | ), 456 | ) 457 | 458 | pretrained.act_postprocess4 = 
nn.Sequential( 459 | readout_oper[3], 460 | Transpose(1, 2), 461 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 462 | nn.Conv2d( 463 | in_channels=vit_features, 464 | out_channels=features[3], 465 | kernel_size=1, 466 | stride=1, 467 | padding=0, 468 | ), 469 | nn.Conv2d( 470 | in_channels=features[3], 471 | out_channels=features[3], 472 | kernel_size=3, 473 | stride=2, 474 | padding=1, 475 | ), 476 | ) 477 | 478 | pretrained.model.start_index = start_index 479 | pretrained.model.patch_size = [16, 16] 480 | 481 | # We inject this function into the VisionTransformer instances so that 482 | # we can use it with interpolated position embeddings without modifying the library source. 483 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 484 | 485 | # We inject this function into the VisionTransformer instances so that 486 | # we can use it with interpolated position embeddings without modifying the library source. 487 | pretrained.model._resize_pos_embed = types.MethodType( 488 | _resize_pos_embed, pretrained.model 489 | ) 490 | 491 | return pretrained 492 | 493 | 494 | def _make_pretrained_vitb_rn50_384( 495 | pretrained, 496 | use_readout="ignore", 497 | hooks=None, 498 | use_vit_only=False, 499 | enable_attention_hooks=False, 500 | ): 501 | model = timm.create_model("vit_base_resnet50_384", pretrained='True') 502 | 503 | hooks = [0, 1, 8, 11] if hooks == None else hooks 504 | return _make_vit_b_rn50_backbone( 505 | model, 506 | features=[256, 512, 768, 768], 507 | size=[384, 384], 508 | hooks=hooks, 509 | use_vit_only=use_vit_only, 510 | use_readout=use_readout, 511 | enable_attention_hooks=enable_attention_hooks, 512 | ) 513 | 514 | 515 | def _make_pretrained_vitl16_384( 516 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 517 | ): 518 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 519 | 520 | hooks = [5, 11, 17, 23] if hooks == None else hooks 521 | return _make_vit_b16_backbone( 522 | model, 523 | features=[256, 512, 1024, 1024], 524 | hooks=hooks, 525 | vit_features=1024, 526 | use_readout=use_readout, 527 | enable_attention_hooks=enable_attention_hooks, 528 | ) 529 | 530 | 531 | def _make_pretrained_vitb16_384( 532 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 533 | ): 534 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 535 | 536 | hooks = [2, 5, 8, 11] if hooks == None else hooks 537 | return _make_vit_b16_backbone( 538 | model, 539 | features=[96, 192, 384, 768], 540 | hooks=hooks, 541 | use_readout=use_readout, 542 | enable_attention_hooks=enable_attention_hooks, 543 | ) 544 | 545 | 546 | def _make_pretrained_deitb16_384( 547 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 548 | ): 549 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 550 | 551 | hooks = [2, 5, 8, 11] if hooks == None else hooks 552 | return _make_vit_b16_backbone( 553 | model, 554 | features=[96, 192, 384, 768], 555 | hooks=hooks, 556 | use_readout=use_readout, 557 | enable_attention_hooks=enable_attention_hooks, 558 | ) 559 | 560 | 561 | def _make_pretrained_deitb16_distil_384( 562 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 563 | ): 564 | model = timm.create_model( 565 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 566 | ) 567 | 568 | hooks = [2, 5, 8, 11] if hooks == None else hooks 569 | return _make_vit_b16_backbone( 570 | model, 571 | 
features=[96, 192, 384, 768], 572 | hooks=hooks, 573 | use_readout=use_readout, 574 | start_index=2, 575 | enable_attention_hooks=enable_attention_hooks, 576 | ) 577 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import torchaudio 7 | import cv2 8 | import torch 9 | 10 | clip_length = 3 11 | 12 | def split_list(a, n): 13 | k, m = divmod(len(a), n) 14 | return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)) 15 | 16 | 17 | # dataset for training 18 | class SalObjDataset(data.Dataset): 19 | def __init__(self, root, trainsize): 20 | self.trainsize = trainsize 21 | 22 | # get visual and audio clips (duration of three consecutive key frames) 23 | with open(root, 'r') as f: 24 | self.seqs = [x.strip() for x in f.readlines()] 25 | 26 | video_clips, gt_clips = [], [] 27 | audio_clips_ch1, audio_clips_ch2, audio_clips_ch3, audio_clips_ch4 = [], [], [], [] # ambisonics 28 | for seq_idx in self.seqs: 29 | # get visual clips of each video 30 | frm_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/frame_key/', seq_idx)) 31 | frm_list = sorted(frm_list) 32 | gt_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/DATA/train/', seq_idx)) 33 | gt_list = sorted(gt_list) 34 | for idx in range(len(frm_list)): 35 | frm_list[idx] = '/home/yzhang1/PythonProjects/AV360/frame_key/' + seq_idx + '/' + frm_list[idx] 36 | gt_list[idx] = '/home/yzhang1/PythonProjects/AV360/DATA/train/' + seq_idx + '/' + gt_list[idx] 37 | frm_list_split = list(split_list(frm_list, int(len(frm_list) / clip_length))) 38 | gt_list_split = list(split_list(gt_list, int(len(gt_list) / clip_length))) 39 | video_clips.append(frm_list_split) 40 | gt_clips.append(gt_list_split) 41 | 42 | # get audio clips of each video 43 | audio_pth = '/home/yzhang1/PythonProjects/AV360/ambisonics_trimmed/' + seq_idx + '.wav' 44 | audio_ori = torchaudio.load(audio_pth)[0] 45 | audio_ori_ch1 = audio_ori[0] 46 | audio_ori_ch2 = audio_ori[1] 47 | audio_ori_ch3 = audio_ori[2] 48 | audio_ori_ch4 = audio_ori[3] 49 | audio_split_size = int(len(audio_ori[1])/(int(len(frm_list) / clip_length))) 50 | audio_ch1_split = torch.split(tensor=audio_ori_ch1, split_size_or_sections=audio_split_size) 51 | audio_ch1_split = list(audio_ch1_split) 52 | audio_ch1_split = audio_ch1_split[:int(len(frm_list) / clip_length)] 53 | audio_clips_ch1.append(audio_ch1_split) 54 | audio_ch2_split = torch.split(tensor=audio_ori_ch2, split_size_or_sections=audio_split_size) 55 | audio_ch2_split = list(audio_ch2_split) 56 | audio_ch2_split = audio_ch2_split[:int(len(frm_list) / clip_length)] 57 | audio_clips_ch2.append(audio_ch2_split) 58 | audio_ch3_split = torch.split(tensor=audio_ori_ch3, split_size_or_sections=audio_split_size) 59 | audio_ch3_split = list(audio_ch3_split) 60 | audio_ch3_split = audio_ch3_split[:int(len(frm_list) / clip_length)] 61 | audio_clips_ch3.append(audio_ch3_split) 62 | audio_ch4_split = torch.split(tensor=audio_ori_ch4, split_size_or_sections=audio_split_size) 63 | audio_ch4_split = list(audio_ch4_split) 64 | audio_ch4_split = audio_ch4_split[:int(len(frm_list) / clip_length)] 65 | audio_clips_ch4.append(audio_ch4_split) 66 | 67 | self.video_clips_flatten = [item for sublist in video_clips for item in sublist] 68 | self.gt_clips_flatten = [item for 
sublist in gt_clips for item in sublist] 69 | 70 | self.audio_ch1_clips_flatten = [item for sublist in audio_clips_ch1 for item in sublist] 71 | self.audio_ch2_clips_flatten = [item for sublist in audio_clips_ch2 for item in sublist] 72 | self.audio_ch3_clips_flatten = [item for sublist in audio_clips_ch3 for item in sublist] 73 | self.audio_ch4_clips_flatten = [item for sublist in audio_clips_ch4 for item in sublist] 74 | 75 | # tools for data transform 76 | self.img_transform = transforms.Compose([ 77 | transforms.Resize((self.trainsize, self.trainsize * 2)), 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 80 | self.gt_transform = transforms.Compose([ 81 | transforms.Resize((self.trainsize, self.trainsize * 2)), 82 | transforms.ToTensor()]) 83 | 84 | # dataset size 85 | self.size = len(self.video_clips_flatten) 86 | 87 | def __getitem__(self, index): 88 | seq_name = self.video_clips_flatten[index][0].split('/')[-2] 89 | imgs, gts, audios = [], [], [] 90 | for idx in range(clip_length): # default as 3 consecutive frames of each clip 91 | curr_img = self.rgb_loader(self.video_clips_flatten[index][idx]) 92 | imgs.append(self.img_transform(curr_img)) 93 | 94 | curr_gt = self.binary_loader(self.gt_clips_flatten[index][idx]) 95 | gts.append(self.gt_transform(curr_gt)) 96 | 97 | # collect audio clip (default as 15 frames equivalently; five times of clip length) 98 | if index == 0: a_index = [0, 1, 2, 3, 4] 99 | elif index == 1: a_index = [0, 1, 2, 3, 4] 100 | elif index == self.size - 1: 101 | a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1] 102 | elif index == self.size - 2: 103 | a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1] 104 | else: 105 | if self.video_clips_flatten[index - 2][0].split('/')[-2] != seq_name: 106 | if self.video_clips_flatten[index - 1][0].split('/')[-2] != seq_name: 107 | a_index = [index, index + 1, index + 2, index + 3, index + 4] 108 | else: 109 | a_index = [index - 1, index, index + 1, index + 2, index + 3] 110 | elif self.video_clips_flatten[index + 2][0].split('/')[-2] != seq_name: 111 | if self.video_clips_flatten[index + 1][0].split('/')[-2] != seq_name: 112 | a_index = [index - 4, index - 3, index - 2, index - 1, index] 113 | else: 114 | a_index = [index - 3, index - 2, index - 1, index, index + 1] 115 | else: 116 | a_index = [index - 2, index - 1, index, index + 1, index + 2] 117 | 118 | a_ch1, a_ch2, a_ch3, a_ch4 = [], [], [], [] 119 | for aa in range(5): 120 | a_ch1.append(self.audio_ch1_clips_flatten[a_index[aa]]) 121 | a_ch2.append(self.audio_ch2_clips_flatten[a_index[aa]]) 122 | a_ch3.append(self.audio_ch3_clips_flatten[a_index[aa]]) 123 | a_ch4.append(self.audio_ch4_clips_flatten[a_index[aa]]) 124 | audios.append([torch.cat(a_ch1), torch.cat(a_ch2), torch.cat(a_ch3), torch.cat(a_ch4)]) 125 | audios = torch.stack(audios[0], dim=0) 126 | audios = torch.mean(audios, dim=0).unsqueeze(0) # ambisonics to mono 127 | 128 | return imgs, gts, audios, seq_name 129 | 130 | def rgb_loader(self, path): 131 | with open(path, 'rb') as f: 132 | img = Image.open(f) 133 | return img.convert('RGB') 134 | 135 | def binary_loader(self, path): 136 | with open(path, 'rb') as f: 137 | img = Image.open(f) 138 | return img.convert('L') 139 | 140 | def __len__(self): 141 | return self.size 142 | 143 | #dataloader for training 144 | def get_loader(root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True): 145 | 146 | dataset = 
SalObjDataset(root, trainsize) 147 | data_loader = data.DataLoader(dataset=dataset, 148 | batch_size=batchsize, 149 | shuffle=shuffle, 150 | num_workers=num_workers, 151 | pin_memory=pin_memory) 152 | return data_loader 153 | 154 | 155 | #test dataset and loader 156 | class test_dataset: 157 | def __init__(self, root, testsize): 158 | self.testsize = testsize 159 | 160 | with open(root, 'r') as f: 161 | self.seqs = [x.strip() for x in f.readlines()] 162 | 163 | video_clips, gt_clips = [], [] 164 | audio_clips_ch1, audio_clips_ch2, audio_clips_ch3, audio_clips_ch4 = [], [], [], [] 165 | for seq_idx in self.seqs: 166 | # get visual clips of each video 167 | frm_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/frame_key/', seq_idx)) 168 | frm_list = sorted(frm_list) 169 | gt_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/DATA/test/', seq_idx)) 170 | gt_list = sorted(gt_list) 171 | for idx in range(len(frm_list)): 172 | frm_list[idx] = '/home/yzhang1/PythonProjects/AV360/frame_key/' + seq_idx + '/' + frm_list[idx] 173 | gt_list[idx] = '/home/yzhang1/PythonProjects/AV360/DATA/test/' + seq_idx + '/' + gt_list[idx] 174 | frm_list_split = list(split_list(frm_list, int(len(frm_list) / clip_length))) 175 | gt_list_split = list(split_list(gt_list, int(len(gt_list) / clip_length))) 176 | video_clips.append(frm_list_split) 177 | gt_clips.append(gt_list_split) 178 | 179 | # get audio clips of each video 180 | audio_pth = '/home/yzhang1/PythonProjects/AV360/ambisonics_trimmed/' + seq_idx + '.wav' 181 | audio_ori = torchaudio.load(audio_pth)[0] 182 | audio_ori_ch1 = audio_ori[0] 183 | audio_ori_ch2 = audio_ori[1] 184 | audio_ori_ch3 = audio_ori[2] 185 | audio_ori_ch4 = audio_ori[3] 186 | audio_split_size = int(len(audio_ori[1]) / (int(len(frm_list) / clip_length))) 187 | audio_ch1_split = torch.split(tensor=audio_ori_ch1, split_size_or_sections=audio_split_size) 188 | audio_ch1_split = list(audio_ch1_split) 189 | audio_ch1_split = audio_ch1_split[:int(len(frm_list) / clip_length)] 190 | audio_clips_ch1.append(audio_ch1_split) 191 | audio_ch2_split = torch.split(tensor=audio_ori_ch2, split_size_or_sections=audio_split_size) 192 | audio_ch2_split = list(audio_ch2_split) 193 | audio_ch2_split = audio_ch2_split[:int(len(frm_list) / clip_length)] 194 | audio_clips_ch2.append(audio_ch2_split) 195 | audio_ch3_split = torch.split(tensor=audio_ori_ch3, split_size_or_sections=audio_split_size) 196 | audio_ch3_split = list(audio_ch3_split) 197 | audio_ch3_split = audio_ch3_split[:int(len(frm_list) / clip_length)] 198 | audio_clips_ch3.append(audio_ch3_split) 199 | audio_ch4_split = torch.split(tensor=audio_ori_ch4, split_size_or_sections=audio_split_size) 200 | audio_ch4_split = list(audio_ch4_split) 201 | audio_ch4_split = audio_ch4_split[:int(len(frm_list) / clip_length)] 202 | audio_clips_ch4.append(audio_ch4_split) 203 | 204 | self.video_clips_flatten = [item for sublist in video_clips for item in sublist] 205 | self.gt_clips_flatten = [item for sublist in gt_clips for item in sublist] 206 | 207 | self.audio_ch1_clips_flatten = [item for sublist in audio_clips_ch1 for item in sublist] 208 | self.audio_ch2_clips_flatten = [item for sublist in audio_clips_ch2 for item in sublist] 209 | self.audio_ch3_clips_flatten = [item for sublist in audio_clips_ch3 for item in sublist] 210 | self.audio_ch4_clips_flatten = [item for sublist in audio_clips_ch4 for item in sublist] 211 | 212 | # tools for data transform 213 | self.transform = transforms.Compose([ 214 | 
transforms.Resize((self.testsize, self.testsize * 2)), 215 | transforms.ToTensor(), 216 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 217 | self.transform_er_cube = transforms.Compose([ 218 | transforms.Resize((640, 1280)), 219 | transforms.ToTensor(), 220 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 221 | 222 | self.size = len(self.video_clips_flatten) 223 | self.index = 0 224 | 225 | def load_data(self): 226 | seq_name = self.video_clips_flatten[self.index][0].split('/')[-2] 227 | imgs, ER_imgs, gts, audios, frm_names = [], [], [], [], [] 228 | for idx in range(clip_length): # default as 3 consecutive frames of each clip 229 | curr_img = self.rgb_loader(self.video_clips_flatten[self.index][idx]) 230 | ER_imgs.append(self.transform_er_cube(curr_img).unsqueeze(0)) 231 | imgs.append(self.transform(curr_img).unsqueeze(0)) 232 | curr_gt = self.binary_loader(self.gt_clips_flatten[self.index][idx]) 233 | curr_gt = curr_gt.resize((self.testsize * 2, self.testsize)) 234 | gts.append(curr_gt) 235 | frm_names.append(self.video_clips_flatten[self.index][idx].split('/')[-1]) 236 | audios.append([self.audio_ch1_clips_flatten[self.index], self.audio_ch2_clips_flatten[self.index], 237 | self.audio_ch3_clips_flatten[self.index], self.audio_ch4_clips_flatten[self.index]]) 238 | audios = torch.stack(audios[0], dim=0).unsqueeze(0) 239 | 240 | self.index += 1 241 | self.index = self.index % self.size 242 | 243 | return imgs, ER_imgs, gts, audios, seq_name, frm_names 244 | 245 | def rgb_loader(self, path): 246 | with open(path, 'rb') as f: 247 | img = Image.open(f) 248 | return img.convert('RGB') 249 | 250 | def binary_loader(self, path): 251 | with open(path, 'rb') as f: 252 | img = Image.open(f) 253 | return img.convert('L') 254 | 255 | def __len__(self): 256 | return self.size 257 | 258 | 259 | class dataset_inference: 260 | def __init__(self, root, testsize): 261 | self.testsize = testsize 262 | 263 | with open(root, 'r') as f: 264 | self.seqs = [x.strip() for x in f.readlines()] 265 | 266 | video_clips, gt_clips = [], [] 267 | audio_clips_ch1, audio_clips_ch2, audio_clips_ch3, audio_clips_ch4 = [], [], [], [] 268 | for seq_idx in self.seqs: 269 | # get visual clips of each video 270 | frm_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/frame_key/', seq_idx)) 271 | frm_list = sorted(frm_list) 272 | gt_list = os.listdir(os.path.join('/home/yzhang1/PythonProjects/AV360/DATA/test/', seq_idx)) 273 | gt_list = sorted(gt_list) 274 | for idx in range(len(frm_list)): 275 | frm_list[idx] = '/home/yzhang1/PythonProjects/AV360/frame_key/' + seq_idx + '/' + frm_list[idx] 276 | gt_list[idx] = '/home/yzhang1/PythonProjects/AV360/DATA/test/' + seq_idx + '/' + gt_list[idx] 277 | frm_list_split = list(split_list(frm_list, int(len(frm_list) / clip_length))) 278 | gt_list_split = list(split_list(gt_list, int(len(gt_list) / clip_length))) 279 | video_clips.append(frm_list_split) 280 | gt_clips.append(gt_list_split) 281 | 282 | # get audio clips of each video 283 | audio_pth = '/home/yzhang1/PythonProjects/AV360/ambisonics_trimmed/' + seq_idx + '.wav' 284 | audio_ori = torchaudio.load(audio_pth)[0] 285 | audio_ori_ch1 = audio_ori[0] 286 | audio_ori_ch2 = audio_ori[1] 287 | audio_ori_ch3 = audio_ori[2] 288 | audio_ori_ch4 = audio_ori[3] 289 | audio_split_size = int(len(audio_ori[1]) / (int(len(frm_list) / clip_length))) 290 | audio_ch1_split = torch.split(tensor=audio_ori_ch1, split_size_or_sections=audio_split_size) 291 | audio_ch1_split = 
            audio_ch1_split = audio_ch1_split[:int(len(frm_list) / clip_length)]
            audio_clips_ch1.append(audio_ch1_split)
            audio_ch2_split = torch.split(tensor=audio_ori_ch2, split_size_or_sections=audio_split_size)
            audio_ch2_split = list(audio_ch2_split)
            audio_ch2_split = audio_ch2_split[:int(len(frm_list) / clip_length)]
            audio_clips_ch2.append(audio_ch2_split)
            audio_ch3_split = torch.split(tensor=audio_ori_ch3, split_size_or_sections=audio_split_size)
            audio_ch3_split = list(audio_ch3_split)
            audio_ch3_split = audio_ch3_split[:int(len(frm_list) / clip_length)]
            audio_clips_ch3.append(audio_ch3_split)
            audio_ch4_split = torch.split(tensor=audio_ori_ch4, split_size_or_sections=audio_split_size)
            audio_ch4_split = list(audio_ch4_split)
            audio_ch4_split = audio_ch4_split[:int(len(frm_list) / clip_length)]
            audio_clips_ch4.append(audio_ch4_split)

        self.video_clips_flatten = [item for sublist in video_clips for item in sublist]
        self.gt_clips_flatten = [item for sublist in gt_clips for item in sublist]

        self.audio_ch1_clips_flatten = [item for sublist in audio_clips_ch1 for item in sublist]
        self.audio_ch2_clips_flatten = [item for sublist in audio_clips_ch2 for item in sublist]
        self.audio_ch3_clips_flatten = [item for sublist in audio_clips_ch3 for item in sublist]
        self.audio_ch4_clips_flatten = [item for sublist in audio_clips_ch4 for item in sublist]

        # tools for data transform
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize * 2)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        self.size = len(self.video_clips_flatten)
        self.index = 0

    def load_data(self):
        seq_name = self.video_clips_flatten[self.index][0].split('/')[-2]
        imgs, gts, audios, frm_names = [], [], [], []
        for idx in range(clip_length):  # default as 3 consecutive frames of each clip
            curr_img = self.rgb_loader(self.video_clips_flatten[self.index][idx])
            imgs.append(self.transform(curr_img).unsqueeze(0))

            curr_gt = self.binary_loader(self.gt_clips_flatten[self.index][idx])
            gts.append(curr_gt)

            frm_names.append(self.video_clips_flatten[self.index][idx].split('/')[-1])

        # collect audio clip (default as 15 frames equivalently; five times of clip length)
        if self.index == 0:
            a_index = [0, 1, 2, 3, 4]
        elif self.index == 1:
            a_index = [0, 1, 2, 3, 4]
        elif self.index == self.size - 1:
            a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1]
        elif self.index == self.size - 2:
            a_index = [self.size - 5, self.size - 4, self.size - 3, self.size - 2, self.size - 1]
        else:
            if self.video_clips_flatten[self.index - 2][0].split('/')[-2] != seq_name:
                if self.video_clips_flatten[self.index - 1][0].split('/')[-2] != seq_name:
                    a_index = [self.index, self.index + 1, self.index + 2, self.index + 3, self.index + 4]
                else:
                    a_index = [self.index - 1, self.index, self.index + 1, self.index + 2, self.index + 3]
            elif self.video_clips_flatten[self.index + 2][0].split('/')[-2] != seq_name:
                if self.video_clips_flatten[self.index + 1][0].split('/')[-2] != seq_name:
                    a_index = [self.index - 4, self.index - 3, self.index - 2, self.index - 1, self.index]
                else:
                    a_index = [self.index - 3, self.index - 2, self.index - 1, self.index, self.index + 1]
            else:
                a_index = [self.index - 2, self.index - 1, self.index, self.index + 1, self.index + 2]

        a_ch1, a_ch2, a_ch3, a_ch4 = [], [], [], []
        for aa in range(5):
            a_ch1.append(self.audio_ch1_clips_flatten[a_index[aa]])
            a_ch2.append(self.audio_ch2_clips_flatten[a_index[aa]])
            a_ch3.append(self.audio_ch3_clips_flatten[a_index[aa]])
            a_ch4.append(self.audio_ch4_clips_flatten[a_index[aa]])
        audios.append([torch.cat(a_ch1), torch.cat(a_ch2), torch.cat(a_ch3), torch.cat(a_ch4)])
        audios = torch.stack(audios[0], dim=0)
        audios = torch.mean(audios, dim=0).unsqueeze(0).unsqueeze(0)

        self.index += 1
        self.index = self.index % self.size

        return imgs, gts, audios, seq_name, frm_names

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def __len__(self):
        return self.size
--------------------------------------------------------------------------------
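A minimal usage sketch for the loaders above. This script is not part of the repository: it assumes the hard-coded dataset roots in data.py have been adapted, that src/ is the working directory, and that clip_length and split_list are defined near the top of data.py; the shapes in the comments follow from the transforms in test_dataset with the default testsize of 416 from options.py.

# run_loader_sketch.py -- hypothetical example, for illustration only
import torch

from options import opt          # provides opt.te_root and opt.trainsize (default 416)
from data import test_dataset

# iterate over the PAVS10K test split, one clip at a time
loader = test_dataset(root=opt.te_root, testsize=opt.trainsize)
for _ in range(len(loader)):
    imgs, ER_imgs, gts, audios, seq_name, frm_names = loader.load_data()
    # imgs: clip_length tensors of shape (1, 3, 416, 832); ER_imgs: (1, 3, 640, 1280)
    # audios: (1, 4, T) chunk of the four ambisonic channels aligned with this clip
    # gts: PIL masks resized to 832x416 (width x height)
    # a trained CAVNet would consume imgs/audios here and save predictions per frm_names
    print(seq_name, frm_names, audios.shape)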