├── Achieving Cross Modal Generalization with Multimodal Unified Representation.pdf ├── README.md ├── cnt.pkl ├── code ├── AVSBench_dowmstream │ └── avs_scripts │ │ └── avs_s4 │ │ ├── config.py │ │ ├── dataloader.py │ │ ├── loss.py │ │ ├── model │ │ ├── CLUB.py │ │ ├── CPC.py │ │ ├── Dual_lstm.py │ │ ├── Dual_model.py │ │ ├── PVT_AVSModel.py │ │ ├── ResNet_AVSModel.py │ │ ├── TPAVI.py │ │ ├── UniEncoder.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── PVT_AVSModel.cpython-38.pyc │ │ │ ├── TPAVI.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── main_model_2.cpython-38.pyc │ │ │ ├── mine.cpython-38.pyc │ │ │ ├── models.cpython-38.pyc │ │ │ └── pvt.cpython-38.pyc │ │ ├── main_model_2.py │ │ ├── mine.py │ │ ├── models.py │ │ ├── pvt.py │ │ ├── resnet.py │ │ ├── rnn_att_2.py │ │ ├── test.py │ │ ├── train.py │ │ └── transformer.py │ │ ├── test.sh │ │ ├── test_at.py │ │ ├── test_ta.py │ │ ├── torchvggish │ │ ├── __pycache__ │ │ │ ├── mel_features.cpython-38.pyc │ │ │ ├── vggish.cpython-38.pyc │ │ │ ├── vggish_input.cpython-38.pyc │ │ │ └── vggish_params.cpython-38.pyc │ │ ├── mel_features.py │ │ ├── vggish.py │ │ ├── vggish_input.py │ │ └── vggish_params.py │ │ ├── train.sh │ │ ├── train_at.py │ │ ├── train_ta.py │ │ └── utils │ │ ├── __pycache__ │ │ ├── pyutils.cpython-38.pyc │ │ ├── system.cpython-38.pyc │ │ └── utility.cpython-38.pyc │ │ ├── pyutils.py │ │ ├── system.py │ │ └── utility.py └── src │ ├── .DS_Store │ ├── .gitignore │ ├── ave.py │ ├── ave.sh │ ├── ave_avvp.py │ ├── ave_avvp.sh │ ├── avvp.py │ ├── avvp.sh │ ├── configs │ ├── default_config.yaml │ └── opts.py │ ├── current_configs.yaml │ ├── dataset │ ├── AVE_AVVP_dataset.py │ ├── AVE_dataset.py │ ├── AVVP_dataset.py │ ├── UCF_VGGSOUND_datasets.py │ ├── VGGSOUND_dataset.py │ ├── VGGSOUND_dataset179k.py │ └── __init__.py │ ├── model │ ├── CLUB.py │ ├── CPC.py │ ├── Dual_lstm.py │ ├── Dual_model.py │ ├── UniEncoder.py │ ├── __init__.py │ ├── main_model_2.py │ ├── mine.py │ ├── models.py │ ├── rnn_att_2.py │ ├── test.py │ └── transformer.py │ ├── pretrain.py │ ├── pretrain.sh │ ├── ucf_vggsound.py │ ├── ucf_vggsound.sh │ └── utils │ ├── Recorder.py │ ├── __init__.py │ ├── container.py │ ├── draw.py │ └── utils.py ├── figs ├── MM_EMA.pdf ├── illustration.pdf └── model.png └── requirements.txt /Achieving Cross Modal Generalization with Multimodal Unified Representation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/Achieving Cross Modal Generalization with Multimodal Unified Representation.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Achieving Cross Modal Generalization with Multimodal Unified Representation, NeurIPS 2023 2 | 3 | 4 | 5 | ![model](figs/model.png) 6 | 7 | This is the Pytorch implementation of our paper: 8 | 9 | Achieving Cross Modal Generalization with Multimodal Unified Representation 10 | 11 | [Yan Xia](https://scholar.google.com/citations?user=6kEbV3IAAAAJ&hl), [Hai Huang](https://scholar.google.com/citations?user=FKvBzQwAAAAJ), [Jieming Zhu](https://scholar.google.com/citations?user=oNKerP8AAAAJ), [Zhou Zhao](https://scholar.google.com.hk/citations?user=IIoFY90AAAAJ) 12 | 13 | In NeurIPS 2023 14 | 15 | ------ 16 | 17 | ### 📝Requirements and Installation 18 | 19 | - ##### Getting Started 20 | **Due to the version conflict between 
bert_embedding's dependency on NumPy and other libraries, directly installing according to requirements.txt may cause issues. For more details, you can refer to this [issue](https://github.com/haihuangcode/CMG/issues/14)."** 21 | ```python 22 | git clone https://github.com/haihuangcode/CMG 23 | cd CMG 24 | # You don't actually have to install all the libraries in the txt file, you can choose to install them as needed. 25 | # It is recommended to use Python 3.7, as some libraries used do not support higher versions of Python. 26 | conda create -n your_env_name python=3.7 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | - ##### Pretrain 31 | ```python 32 | # Before you begin pretraining, please make sure to modify the file paths under `args.dataset_name == 'vggsound_AVT'` in `pretrain.py` to your own paths. 33 | # Additionally, update the `file_path` and `self.label2prompt = pd.read_csv('')` paths in `dataset/VGGSOUND_dataset.py`. 34 | # The model save path is located under `--model_save_path` in `configs/opts.py`. 35 | # Please also remember to modify the paths related to downstream tasks and the corresponding dataset paths to your own paths. 36 | cd CMG/code/src 37 | ./pretrain.sh 38 | ``` 39 | 40 | - ##### AVE_downstream 41 | ```python 42 | cd CMG/code/src 43 | ./ave.sh 44 | ``` 45 | 46 | - ##### AVVP_downstream 47 | ```python 48 | cd CMG/code/src 49 | ./avvp.sh 50 | ``` 51 | 52 | - ##### AVE_AVVP_downstream 53 | ```python 54 | cd CMG/code/src 55 | ./ave_avvp.sh 56 | ``` 57 | 58 | - ##### UCF_VGGSOUND_downstream 59 | ```python 60 | cd CMG/code/src 61 | ./ucf_vggsound.sh 62 | ``` 63 | 64 | - ##### AVS_downstream 65 | ```python 66 | cd CMG/code/AVSBench_downstream/avs_scripts/avs_s4 67 | ./train.sh 68 | ./test.sh 69 | ``` 70 | 71 | ## 🎓Cite 72 | 73 | If you find this work useful, please consider citing it. 74 | 75 | ``` 76 | @article{xia2024achieving, 77 | title={Achieving Cross Modal Generalization with Multimodal Unified Representation}, 78 | author={Xia, Yan and Huang, Hai and Zhu, Jieming and Zhao, Zhou}, 79 | journal={Advances in Neural Information Processing Systems}, 80 | volume={36}, 81 | year={2024} 82 | } 83 | ``` 84 | 85 | ## ✏Model Checkpoints And Date Feature 86 | 87 | [data](https://pan.baidu.com/s/1CTcjMHVeG-8uo4HPWNNL9Q ) (pwd: 1234) 88 | - 2023.11.07 Update https://github.com/haihuangcode/CMG/issues/1 89 | 90 | [patch](https://pan.baidu.com/s/1rjVmRMut39kezw0FDZ7MwQ) (pwd: 1234) 91 | - 2024.12.27 This is a patch for the previous data errors. Please download the complete data from the above and replace the csv files in the patch with the ones in `data/vggsound40k/data`, specifically replacing `vggsound-avel40k.csv` and `video_name_vggsound40k_checked.csv`. The previous https://github.com/haihuangcode/CMG/issues/13 regarding unsatisfactory model training results were caused by the incomplete csv files that were uploaded earlier, which only contained 20k data entries. I apologize for not noticing this earlier /(ㄒoㄒ)/~~ 92 | ## ✏Directory 93 | 94 | ``` 95 | CMG 96 | ├── checkpoint 97 | ├── cnt.pkl 98 | ├── code 99 | ├── data 100 | ├── figs 101 | ├── paper 102 | ├── README.md 103 | └── requirements.txt 104 | ``` 105 | 106 | ## ✏Note 107 | - For the video and audio feature extraction method, please refer to [AVE](https://github.com/YapengTian/AVE-ECCV18), text is based on the label to generate a description-focused statement of approximately 10 words in length. 
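  A minimal sketch of this label-to-prompt step (not the repo's exact script; it assumes a label-to-description CSV such as the `AVSBenchCategories2Prompts.csv` loaded by the AVS dataloader in this repo, and mirrors how `collate_func` embeds the prompt with `bert_embedding`):
  ```python
  # Hedged sketch: map a class label to its ~10-word description and embed it.
  import pandas as pd
  from bert_embedding import BertEmbedding

  # first column: label; second column: the ~10-word description for that label
  label2prompt = pd.read_csv('AVSBenchCategories2Prompts.csv')
  bert = BertEmbedding()

  def prompt_embedding(category):
      # look up the description text for this label (second CSV column)
      prompt = label2prompt.loc[label2prompt['label'] == category].values[0][1]
      # bert_embedding returns one (tokens, per-token 768-d embeddings) pair per sentence
      tokens, embeddings = bert([prompt])[0]
      return tokens, embeddings
  ```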
108 | - There is no validation set for the pre-training process, in this paper it is done by testing the performance of each model on the downstream of the [AVE](https://github.com/YapengTian/AVE-ECCV18), and the model with the best performance tests the rest of the downstream tasks, so the [AVE](https://github.com/YapengTian/AVE-ECCV18) can be regarded as a validation set and the model with the best pre-training appears in the first 5 epochs. 109 | - Pretraining can be performed using just one GPU, such as 4090 or A100. The experimental results in the paper were obtained by running on 4090 or A100. Multi-GPU parallel training yielded poorer model performance, possibly due to issues between the mutual information minimization design in DCID and Pytorch (but this was an early experimental observation, and was not re-verified after the code was finalized, since single GPU pretraining was sufficient). 110 | 111 | ## 👍Acknowledgments 112 | 113 | Our code is based on [AVE](https://github.com/YapengTian/AVE-ECCV18), [AVVP](https://github.com/YapengTian/AVVP-ECCV20), [PSP](https://github.com/jasongief/PSP_CVPR_2021), [CPSP](https://github.com/jasongief/CPSP), [VGGSOUND](https://github.com/hche11/VGGSound), [AVS](https://github.com/OpenNLPLab/AVSBench). 114 | -------------------------------------------------------------------------------- /cnt.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/cnt.pkl -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import yaml 3 | import pdb 4 | 5 | """ 6 | default config 7 | """ 8 | cfg = edict() 9 | cfg.BATCH_SIZE = 4 10 | cfg.LAMBDA_1 = 50 11 | 12 | ############################## 13 | # TRAIN 14 | cfg.TRAIN = edict() 15 | # TRAIN.SCHEDULER 16 | cfg.TRAIN.FREEZE_AUDIO_EXTRACTOR = True 17 | cfg.TRAIN.PRETRAINED_VGGISH_MODEL_PATH = "../../pretrained_backbones/vggish-10086976.pth" 18 | cfg.TRAIN.PREPROCESS_AUDIO_TO_LOG_MEL = False 19 | cfg.TRAIN.POSTPROCESS_LOG_MEL_WITH_PCA = False 20 | cfg.TRAIN.PRETRAINED_PCA_PARAMS_PATH = "../../pretrained_backbones/vggish_pca_params-970ea276.pth" 21 | cfg.TRAIN.FREEZE_VISUAL_EXTRACTOR = False 22 | cfg.TRAIN.PRETRAINED_RESNET50_PATH = "../../pretrained_backbones/resnet50-19c8e357.pth" 23 | cfg.TRAIN.PRETRAINED_PVTV2_PATH = "../../pretrained_backbones/pvt_v2_b5.pth" 24 | 25 | ############################### 26 | # DATA 27 | cfg.DATA = edict() 28 | cfg.DATA.ANNO_CSV = "AVSBench_data/Single-source/s4_meta_data.csv" 29 | cfg.DATA.DIR_IMG = "AVSBench_data/Single-source/s4_data/visual_frames" 30 | # cfg.DATA.DIR_AUDIO_LOG_MEL = "/root/autodl-tmp/AVSBench_data/Single-source/s4_data/audio_log_mel" 31 | cfg.DATA.DIR_AUDIO_FEATURE = "AVSBench/feature/audio" 32 | cfg.DATA.DIR_MASK = "AVSBench_data/Single-source/s4_data/gt_masks" 33 | cfg.DATA.IMG_SIZE = (224, 224) 34 | ############################### 35 | 36 | 37 | 38 | 39 | # def _edict2dict(dest_dict, src_edict): 40 | # if isinstance(dest_dict, dict) and isinstance(src_edict, dict): 41 | # for k, v in src_edict.items(): 42 | # if not isinstance(v, edict): 43 | # dest_dict[k] = v 44 | # else: 45 | # dest_dict[k] = {} 46 | # _edict2dict(dest_dict[k], v) 47 | # else: 48 | # return 49 | 50 | 51 | # def gen_config(config_file): 52 | # cfg_dict = {} 53 | # 
_edict2dict(cfg_dict, cfg) 54 | # with open(config_file, 'w') as f: 55 | # yaml.dump(cfg_dict, f, default_flow_style=False) 56 | 57 | 58 | # def _update_config(base_cfg, exp_cfg): 59 | # if isinstance(base_cfg, dict) and isinstance(exp_cfg, edict): 60 | # for k, v in exp_cfg.items(): 61 | # if k in base_cfg: 62 | # if not isinstance(v, dict): 63 | # base_cfg[k] = v 64 | # else: 65 | # _update_config(base_cfg[k], v) 66 | # else: 67 | # raise ValueError("{} not exist in config.py".format(k)) 68 | # else: 69 | # return 70 | 71 | 72 | # def update_config_from_file(filename): 73 | # exp_config = None 74 | # with open(filename) as f: 75 | # exp_config = edict(yaml.safe_load(f)) 76 | # _update_config(cfg, exp_config) 77 | 78 | if __name__ == "__main__": 79 | print(cfg) 80 | pdb.set_trace() 81 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from wave import _wave_params 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import Dataset 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import pickle 10 | 11 | import cv2 12 | from PIL import Image 13 | from torchvision import transforms 14 | 15 | from config import cfg 16 | import pdb 17 | from bert_embedding import BertEmbedding 18 | import pickle 19 | import zipfile 20 | from io import BytesIO 21 | 22 | def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None): 23 | img_PIL = Image.open(path).convert(mode) 24 | if transform: 25 | img_tensor = transform(img_PIL) 26 | return img_tensor 27 | return img_PIL 28 | 29 | 30 | def load_audio_lm(audio_lm_path): 31 | with open(audio_lm_path, 'rb') as fr: 32 | audio_log_mel = pickle.load(fr) 33 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64] 34 | return audio_log_mel 35 | 36 | bert_embedding = BertEmbedding() 37 | with open('cnt.pkl', 'rb') as fp: 38 | id2idx = pickle.load(fp) 39 | 40 | class S4Dataset(Dataset): 41 | """Dataset for single sound source segmentation""" 42 | def __init__(self, split='train'): 43 | super(S4Dataset, self).__init__() 44 | self.split = split 45 | self.mask_num = 1 if self.split == 'train' else 5 46 | self.label2prompt = pd.read_csv('AVSBenchCategories2Prompts.csv') 47 | df_all = pd.read_csv(cfg.DATA.ANNO_CSV, sep=',') 48 | self.df_split = df_all[df_all['split'] == split] 49 | print("{}/{} videos are used for {}".format(len(self.df_split), len(df_all), self.split)) 50 | self.img_transform = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 53 | ]) 54 | self.mask_transform = transforms.Compose([ 55 | transforms.ToTensor(), 56 | ]) 57 | 58 | 59 | def __getitem__(self, index): 60 | df_one_video = self.df_split.iloc[index] 61 | video_name, category = df_one_video[0], df_one_video[2] 62 | img_base_path = os.path.join(cfg.DATA.DIR_IMG, self.split, category, video_name) 63 | # audio_lm_path = os.path.join(cfg.DATA.DIR_AUDIO_LOG_MEL, self.split, category, video_name + '.pkl') 64 | audio_feature_path = os.path.join(cfg.DATA.DIR_AUDIO_FEATURE, self.split, "zip",category) 65 | mask_base_path = os.path.join(cfg.DATA.DIR_MASK, self.split, category, video_name) 66 | # audio_log_mel = load_audio_lm(audio_lm_path) 67 | audio_feature = self._load_fea(audio_feature_path, video_name) 68 | 69 | if audio_feature.shape[0] < 5: 70 | cur_t = audio_feature.shape[0] 71 | add_arr = np.tile(audio_feature[-1, :], 
(5-cur_t, 1)) 72 | audio_feature = np.concatenate([audio_feature, add_arr], axis=0) 73 | elif audio_feature.shape[0] > 5: 74 | audio_feature = audio_feature[:5, :] 75 | 76 | 77 | text_fea = self.label2prompt.loc[self.label2prompt['label'] == category].values[0][1] 78 | # audio_lm_tensor = torch.from_numpy(audio_log_mel) 79 | imgs, masks = [], [] 80 | for img_id in range(1, 6): 81 | img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s_%d.png"%(video_name, img_id)), transform=self.img_transform) 82 | imgs.append(img) 83 | for mask_id in range(1, self.mask_num + 1): 84 | mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png"%(video_name, mask_id)), transform=self.mask_transform, mode='1') 85 | masks.append(mask) 86 | imgs_tensor = torch.stack(imgs, dim=0) 87 | masks_tensor = torch.stack(masks, dim=0) 88 | 89 | sample = {'imgs_tensor': imgs_tensor, 90 | 'audio_fea': audio_feature, 91 | 'masks_tensor': masks_tensor, 92 | 'category': category, 93 | 'video_name': video_name, 94 | 'text_fea': text_fea} 95 | 96 | return sample 97 | 98 | 99 | def _load_fea(self, fea_base_path, video_id): 100 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 101 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 102 | for name in zfile.namelist(): 103 | if '.pkl' not in name: 104 | continue 105 | with zfile.open(name, mode='r') as fea_file: 106 | content = BytesIO(fea_file.read()) 107 | fea = pickle.load(content) 108 | return fea 109 | 110 | def __len__(self): 111 | return len(self.df_split) 112 | 113 | 114 | def collate_func(self, samples): 115 | 116 | bsz = len(samples) 117 | result = bert_embedding([sample['text_fea'] for sample in samples]) 118 | query = [] 119 | query_words = [] 120 | for a, b in result: 121 | words = [] 122 | words_emb = [] 123 | for word, emb in zip(a, b): 124 | idx = bert_embedding.vocab.token_to_idx[word] 125 | if idx in id2idx and idx != 0: 126 | words_emb.append(emb) 127 | words.append(id2idx[idx]) 128 | query.append(np.asarray(words_emb)) 129 | query_words.append(words) 130 | 131 | query_len = [] 132 | for i, sample in enumerate(query): 133 | # query_len.append(min(len(sample), 10))#max_num_words:10 134 | query_len.append(10)#max_num_words:10 135 | query1 = np.zeros([bsz, max(query_len), 768]).astype(np.float32) 136 | query_idx = np.zeros([bsz, max(query_len)]).astype(np.float32) 137 | for i, sample in enumerate(query): 138 | keep = min(sample.shape[0], query1.shape[1]) 139 | query1[i, :keep] = sample[:keep] 140 | query_idx[i, :keep] = query_words[i][:keep] 141 | query_len = np.asarray(query_len) 142 | query, query_len = torch.from_numpy(query1).float(), torch.from_numpy(query_len).long() 143 | query_idx = torch.from_numpy(query_idx).long() 144 | 145 | image_tensors = [sample['imgs_tensor'] for sample in samples] 146 | stacked_images = np.stack(image_tensors) 147 | imgs_tensor = torch.from_numpy(stacked_images).float() 148 | 149 | maskeds_tensors = [sample['masks_tensor'] for sample in samples] 150 | stacked_masks = np.stack(maskeds_tensors) 151 | masks_tensor = torch.from_numpy(stacked_masks).float() 152 | 153 | categorys = [sample['category'] for sample in samples] 154 | video_names = [sample['video_name'] for sample in samples] 155 | 156 | return { 157 | 'query': query, 158 | 'imgs_tensor':imgs_tensor, 159 | 'audio_fea': torch.from_numpy(np.asarray([sample['audio_fea'] for sample in samples])).float(), 160 | 'masks_tensor':masks_tensor, 161 | 'category':categorys, 162 | 'video_name': video_names, 163 | } 164 | 165 | 166 | if __name__ == 
"__main__": 167 | train_dataset = S4Dataset('train') 168 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 169 | batch_size=2, 170 | shuffle=False, 171 | num_workers=8, 172 | pin_memory=True) 173 | 174 | for n_iter, batch_data in enumerate(train_dataloader): 175 | imgs, audio, mask = batch_data # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 1, 1, 224, 224] 176 | pdb.set_trace() 177 | print('n_iter', n_iter) 178 | pdb.set_trace() 179 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | 6 | 7 | def F1_IoU_BCELoss(pred_masks, first_gt_mask): 8 | """ 9 | binary cross entropy loss (iou loss) of the first frame for single sound source segmentation 10 | 11 | Args: 12 | pred_masks: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224] 13 | first_gt_mask: ground truth mask of the first frame, shape: [bs, 1, 1, 224, 224] 14 | """ 15 | assert len(pred_masks.shape) == 4 16 | pred_masks = torch.sigmoid(pred_masks) # [bs*5, 1, 224, 224] 17 | indices = torch.tensor(list(range(0, len(pred_masks), 5))) 18 | indices = indices.cuda() 19 | 20 | first_pred = torch.index_select(pred_masks, dim=0, index=indices) # [bs, 1, 224, 224] 21 | # assert first_pred.requires_grad == True, "Error when indexing predited masks" 22 | if len(first_gt_mask.shape) == 5: 23 | first_gt_mask = first_gt_mask.squeeze(1) # [bs, 1, 224, 224] 24 | first_bce_loss = nn.BCELoss()(first_pred, first_gt_mask) 25 | 26 | return first_bce_loss 27 | 28 | 29 | 30 | def A_MaskedV_SimmLoss(pred_masks, a_fea_list, v_map_list, \ 31 | count_stages=[], \ 32 | mask_pooling_type='avg', norm_fea=True): 33 | """ 34 | [audio] - [masked visual feature map] matching loss, Loss_AVM_AV reported in the paper 35 | 36 | Args: 37 | pred_masks: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224] 38 | a_fea_list: audio feature list, lenth = nl_stages, each of shape: [bs, T, C], C is equal to [256] 39 | v_map_list: feature map list of the encoder or decoder output, each of shape: [bs*5, C, H, W], C is equal to [256] 40 | count_stages: loss is computed in these stages 41 | """ 42 | assert len(pred_masks.shape) == 4 43 | pred_masks = torch.sigmoid(pred_masks) # [B*5, 1, 224, 224] 44 | total_loss = 0 45 | for stage in count_stages: 46 | a_fea, v_map = a_fea_list[stage], v_map_list[stage] 47 | a_fea = a_fea.view(-1, a_fea.shape[-1]) # [B*5, C] 48 | 49 | C, H, W = v_map.shape[1], v_map.shape[-2], v_map.shape[-1] 50 | assert C == a_fea.shape[-1], 'Error: dimensions of audio and visual features are not equal' 51 | 52 | if mask_pooling_type == "avg": 53 | downsample_pred_masks = nn.AdaptiveAvgPool2d((H, W))(pred_masks) # [bs*5, 1, H, W] 54 | elif mask_pooling_type == 'max': 55 | downsample_pred_masks = nn.AdaptiveMaxPool2d((H, W))(pred_masks) # [bs*5, 1, H, W] 56 | downsample_pred_masks = (downsample_pred_masks > 0.5).float() # [bs*5, 1, H, W] 57 | 58 | obj_pixel_num = downsample_pred_masks.sum(-1).sum(-1) # [bs*5, 1] 59 | 60 | masked_v_map = torch.mul(v_map, downsample_pred_masks) # [bs*5, C, H, W] 61 | # masked_v_fea = masked_v_map.mean(-1).mean(-1) # [bs*5, C] 62 | masked_v_fea = masked_v_map.sum(-1).sum(-1) / (obj_pixel_num + 1e-6)# [bs*5, C] 63 | 64 | if norm_fea: 65 | a_fea = F.normalize(a_fea, dim=-1) 66 | masked_v_fea = F.normalize(masked_v_fea, dim=-1) 67 | 68 | cos_simm_va 
= torch.sum(torch.mul(masked_v_fea, a_fea), dim=-1) # [bs*5] 69 | cos_simm_va = F.relu(cos_simm_va) + 1e-6 70 | cos_simm_va = (-1) * cos_simm_va.log() 71 | loss = cos_simm_va.mean() 72 | total_loss += loss 73 | 74 | total_loss /= len(count_stages) 75 | 76 | return total_loss 77 | 78 | 79 | 80 | def IouSemanticAwareLoss(pred_masks, first_gt_mask, \ 81 | a_fea_list, v_map_list, \ 82 | lambda_1=0, count_stages=[], \ 83 | sa_loss_flag=False, mask_pooling_type='avg'): 84 | """ 85 | loss for single sound source segmentation 86 | 87 | Args: 88 | pred_masks: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224] 89 | first_gt_mask: ground truth mask of the first frame, shape: [bs, 1, 1, 224, 224] 90 | a_fea_list: feature list of audio features 91 | v_map_list: feature map list of the encoder or decoder output, each of shape: [bs*5, C, H, W] 92 | count_stages: additional constraint loss on which stages' visual-audio features 93 | """ 94 | total_loss = 0 95 | f1_iou_loss = F1_IoU_BCELoss(pred_masks, first_gt_mask) 96 | total_loss += f1_iou_loss 97 | # pdb.set_trace() 98 | 99 | if sa_loss_flag: 100 | sa_loss = A_MaskedV_SimmLoss(pred_masks, a_fea_list, v_map_list, count_stages, mask_pooling_type) 101 | total_loss += lambda_1 * sa_loss 102 | else: 103 | sa_loss = torch.zeros(1) 104 | 105 | # pdb.set_trace() 106 | loss_dict = {} 107 | loss_dict['iou_loss'] = f1_iou_loss.item() 108 | loss_dict['sa_loss'] = sa_loss.item() 109 | loss_dict['lambda_1'] = lambda_1 110 | 111 | return total_loss, loss_dict 112 | 113 | 114 | if __name__ == "__main__": 115 | 116 | pdb.set_trace() 117 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/CLUB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CLUBSample_group(nn.Module): # Sampled version of the CLUB estimator 7 | def __init__(self, x_dim, y_dim, hidden_size): 8 | super(CLUBSample_group, self).__init__() 9 | self.x_dim = x_dim 10 | self.y_dim = y_dim 11 | self.hidden_size = hidden_size 12 | self.p_mu = nn.Sequential(nn.Linear(self.x_dim, self.hidden_size // 2), 13 | nn.ReLU(), 14 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 15 | nn.ReLU(), 16 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 17 | nn.ReLU(), 18 | nn.Linear(self.hidden_size // 2, self.y_dim)) 19 | 20 | self.p_logvar = nn.Sequential(nn.Linear(self.x_dim, self.hidden_size // 2), 21 | nn.ReLU(), 22 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 23 | nn.ReLU(), 24 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 25 | nn.ReLU(), 26 | nn.Linear(self.hidden_size // 2, self.y_dim), 27 | nn.Tanh()) 28 | 29 | def get_mu_logvar(self, x_samples): 30 | mu = self.p_mu(x_samples) 31 | logvar = self.p_logvar(x_samples) 32 | return mu, logvar 33 | 34 | def loglikeli(self, x_samples, y_samples): # unnormalized loglikelihood 35 | mu, logvar = self.get_mu_logvar(x_samples) # mu/logvar: (bs, y_dim) 36 | # mu = mu.unsqueeze(1).expand(-1, y_samples.shape[1], -1).reshape(-1, mu.shape[ 37 | # -1]) # (bs, y_dim) -> (bs, 1, y_dim) -> (bs, T, y_dim) -> (bs*T, y_dim) 38 | mu = mu.reshape(-1, mu.shape[-1]) 39 | #logvar = logvar.unsqueeze(1).expand(-1, y_samples.shape[1], -1).reshape(-1, logvar.shape[-1]) 40 | logvar = logvar.reshape(-1, logvar.shape[-1]) 41 | y_samples = y_samples.reshape(-1, y_samples.shape[-1]) # (bs, T, y_dim) -> (bs*T, y_dim) 42 | return (-(mu - 
y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) / 2 43 | 44 | def mi_est(self, x_samples, y_samples): # x_samples: (bs, x_dim); y_samples: (bs, T, y_dim) 45 | 46 | mu, logvar = self.get_mu_logvar(x_samples) 47 | 48 | sample_size = x_samples.shape[0] 49 | # random_index = torch.randint(sample_size, (sample_size,)).long() 50 | random_index = torch.randperm(sample_size).long() 51 | 52 | # log of conditional probability of positive sample pairs 53 | #mu_exp1 = mu.unsqueeze(1).expand(-1, y_samples.shape[1], -1) # (bs, y_dim) -> (bs, T, y_dim) 54 | mu_exp1 = mu 55 | 56 | # logvar_exp1 = logvar.unqueeze(1).expand(-1, y_samples.shape[1], -1).reshape(-1, logvar.shape[-1]) 57 | positive = - ((mu_exp1 - y_samples) ** 2).mean(dim=1) / logvar.mean(dim=1).exp() # mean along T 58 | negative = - ((mu_exp1 - y_samples[random_index]) ** 2).mean(dim=1) / logvar.mean(dim=1).exp() # mean along T 59 | 60 | return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() / 2 61 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/Dual_lstm.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import copy 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.nn import Module 7 | from torch.nn import MultiheadAttention 8 | from torch.nn import ModuleList 9 | from torch.nn.init import xavier_uniform_ 10 | from torch.nn import Dropout 11 | from torch.nn import Linear 12 | from torch.nn import LayerNorm 13 | import math 14 | from torch.autograd import Variable 15 | 16 | 17 | 18 | class Dual_lstm_cell(nn.Module): 19 | def __init__(self, visual_input_dim, audio_input_dim, hidden_dim, alph=0.5, bias=True): 20 | super(Dual_lstm_cell, self).__init__() 21 | 22 | self.visual_input_dim = visual_input_dim 23 | self.audio_input_dim = audio_input_dim 24 | self.hidden_dim = hidden_dim 25 | self.alph = alph 26 | self.vs_linear = nn.Linear(self.visual_input_dim, 4 * self.hidden_dim, bias=bias) 27 | self.vh_linear = nn.Linear(self.hidden_dim, 4* self.hidden_dim, bias=bias) 28 | self.as_linear = nn.Linear(self.audio_input_dim, 4 * self.hidden_dim, bias=bias) 29 | self.ah_linear = nn.Linear(self.hidden_dim, 4 * self.hidden_dim, bias=bias) 30 | 31 | self.as_linear2 = nn.Linear(self.audio_input_dim, 4*self.hidden_dim, bias=bias) 32 | self.ah_linear2 = nn.Linear(self.hidden_dim, 4*self.hidden_dim, bias=bias) 33 | self.vs_linear2 = nn.Linear(self.visual_input_dim, 4*self.hidden_dim, bias=bias) 34 | self.vh_linear2 = nn.Linear(self.hidden_dim, 4*self.hidden_dim, bias=bias) 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | std = 1.0 / math.sqrt(self.hidden_dim) 39 | for w in self.parameters(): 40 | w.data.uniform_(-std, std) 41 | 42 | def forward(self, visual_state, visual_hidden, visual_cell, audio_state, audio_hidden, audio_cell): 43 | visual_gates = self.vs_linear(visual_state) + self.vh_linear(visual_hidden) 44 | #self.alph*self.as_linear(audio_state) + self.alph*self.ah_linear(audio_hidden) 45 | 46 | 47 | audio_gates = self.as_linear2(audio_state) + self.ah_linear2(audio_hidden) 48 | #self.alph*self.vs_linear2(visual_state) + self.alph*self.vh_linear2(visual_hidden) 49 | 50 | visual_i_gate, visual_f_gate, visual_c_gate, visual_o_gate = visual_gates.chunk(4,1) 51 | audio_i_gate, audio_f_gate, audio_c_gate, audio_o_gate = audio_gates.chunk(4,1) 52 | 53 | visual_i_gate = F.sigmoid(visual_i_gate) 54 | visual_f_gate = F.sigmoid(visual_f_gate) 55 
| visual_c_gate = F.tanh(visual_c_gate) 56 | visual_o_gate = F.sigmoid(visual_o_gate) 57 | 58 | visual_cell = visual_f_gate * visual_cell + visual_i_gate * visual_c_gate 59 | visual_output = visual_o_gate * torch.tanh(visual_cell) 60 | 61 | audio_i_gate = F.sigmoid(audio_i_gate) 62 | audio_f_gate = F.sigmoid(audio_f_gate) 63 | audio_c_gate = F.tanh(audio_c_gate) 64 | audio_o_gate = F.sigmoid(audio_o_gate) 65 | 66 | audio_cell = audio_f_gate * audio_cell + audio_i_gate * audio_c_gate 67 | audio_output = audio_o_gate * torch.tanh(audio_cell) 68 | 69 | return visual_output, visual_cell, audio_output, audio_cell 70 | 71 | class Dual_lstm(nn.Module): 72 | def __init__(self): 73 | 74 | super(Dual_lstm, self).__init__() 75 | 76 | self.video_input_dim = 512 77 | self.video_fc_dim = 512 78 | self.d_model = 256 79 | self.v_fc = nn.Linear(self.video_input_dim, self.video_fc_dim) 80 | self.LSTM_cell = Dual_lstm_cell(visual_input_dim=512, audio_input_dim=128, hidden_dim=256) 81 | #self.LSTM_cell_r = Dual_lstm_cell(visual_input_dim=512, audio_input_dim=128, hidden_dim=256) 82 | 83 | 84 | self.relu = nn.ReLU() 85 | self.dropout = nn.Dropout(0.2) 86 | 87 | 88 | def forward(self, audio_feature, visual_feature): 89 | audio_rnn_input = audio_feature 90 | 91 | visual_rnn_input = visual_feature 92 | 93 | if torch.cuda.is_available(): 94 | visual_hidden = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 95 | visual_hidden_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 96 | else: 97 | visual_hidden = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 98 | visual_hidden_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 99 | 100 | if torch.cuda.is_available(): 101 | visual_cell = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 102 | visual_cell_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 103 | else: 104 | visual_cell = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 105 | visual_cell_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 106 | 107 | if torch.cuda.is_available(): 108 | audio_hidden = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 109 | audio_hidden_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 110 | else: 111 | audio_hidden = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 112 | audio_hidden_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 113 | 114 | if torch.cuda.is_available(): 115 | audio_cell = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 116 | audio_cell_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 117 | else: 118 | audio_cell = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 119 | audio_cell_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 120 | 121 | visual_output = [] 122 | audio_output = [] 123 | visual_output_r = [] 124 | audio_output_r = [] 125 | length = visual_rnn_input.size(1) 126 | 127 | visual_hidden = visual_hidden.double() 128 | visual_cell = visual_cell.double() 129 | audio_hidden = audio_hidden.double() 130 | audio_cell = audio_cell.double() 131 | visual_hidden_r = visual_hidden_r.double() 132 | visual_cell_r = visual_cell_r.double() 133 | audio_hidden_r = audio_hidden_r.double() 134 | audio_cell_r = audio_cell_r.double() 135 | 136 | 137 | for i in range(length): 138 | visual_hidden, visual_cell, audio_hidden, audio_cell = self.LSTM_cell(visual_rnn_input[:,i,:], visual_hidden, 
visual_cell, 139 | audio_rnn_input[:,i,:], audio_hidden, audio_cell) 140 | visual_output.append(visual_hidden) 141 | audio_output.append(audio_hidden) 142 | 143 | visual_output = torch.stack(visual_output,dim=1) 144 | audio_output = torch.stack(audio_output, dim=1) 145 | 146 | 147 | # for i in range(length): 148 | # visual_hidden_r, visual_cell_r, audio_hidden_r, audio_cell_r = self.LSTM_cell_r(visual_rnn_input[:,length-1-i,:], visual_hidden_r, 149 | # visual_cell_r, audio_rnn_input[:,length-1-i,:], 150 | # audio_hidden_r, audio_cell_r) 151 | # visual_output_r.append(visual_hidden_r) 152 | # audio_output_r.append(audio_hidden_r) 153 | 154 | # visual_output_r = torch.stack(visual_output_r, dim=1) 155 | # visual_output_r = torch.flip(visual_output_r, dims=[1]) 156 | # audio_output_r = torch.stack(audio_output_r, dim=1) 157 | # audio_output_r = torch.flip(audio_output_r, dims=[1]) 158 | # visual_output = torch.cat((visual_output, visual_output_r), dim=2) 159 | # audio_output = torch.cat((audio_output, audio_output_r), dim=2) 160 | return audio_output, visual_output 161 | 162 | 163 | # model = Dual_lstm() 164 | # visual_feature = torch.randn(32, 10,512) 165 | # audio_feature = torch.randn(32, 10, 128) 166 | # model(audio_feature, visual_feature) 167 | # 168 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/PVT_AVSModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from model.pvt import pvt_v2_b5 5 | from model.TPAVI import TPAVIModule 6 | import pdb 7 | 8 | 9 | class Classifier_Module(nn.Module): 10 | def __init__(self, dilation_series, padding_series, NoLabels, input_channel): 11 | super(Classifier_Module, self).__init__() 12 | self.conv2d_list = nn.ModuleList() 13 | for dilation, padding in zip(dilation_series, padding_series): 14 | self.conv2d_list.append(nn.Conv2d(input_channel, NoLabels, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True)) 15 | for m in self.conv2d_list: 16 | m.weight.data.normal_(0, 0.01) 17 | 18 | def forward(self, x): 19 | out = self.conv2d_list[0](x) 20 | for i in range(len(self.conv2d_list)-1): 21 | out += self.conv2d_list[i+1](x) 22 | return out 23 | 24 | 25 | class BasicConv2d(nn.Module): 26 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 27 | super(BasicConv2d, self).__init__() 28 | self.conv_bn = nn.Sequential( 29 | nn.Conv2d(in_planes, out_planes, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, bias=False), 32 | nn.BatchNorm2d(out_planes) 33 | ) 34 | 35 | def forward(self, x): 36 | x = self.conv_bn(x) 37 | return x 38 | 39 | 40 | class ResidualConvUnit(nn.Module): 41 | """Residual convolution module. 42 | """ 43 | 44 | def __init__(self, features): 45 | """Init. 46 | Args: 47 | features (int): number of features 48 | """ 49 | super().__init__() 50 | 51 | self.conv1 = nn.Conv2d( 52 | features, features, kernel_size=3, stride=1, padding=1, bias=True 53 | ) 54 | self.conv2 = nn.Conv2d( 55 | features, features, kernel_size=3, stride=1, padding=1, bias=True 56 | ) 57 | self.relu = nn.ReLU(inplace=True) 58 | 59 | def forward(self, x): 60 | """Forward pass. 
61 | Args: 62 | x (tensor): input 63 | Returns: 64 | tensor: output 65 | """ 66 | out = self.relu(x) 67 | out = self.conv1(out) 68 | out = self.relu(out) 69 | out = self.conv2(out) 70 | 71 | return out + x 72 | 73 | class FeatureFusionBlock(nn.Module): 74 | """Feature fusion block. 75 | """ 76 | 77 | def __init__(self, features): 78 | """Init. 79 | Args: 80 | features (int): number of features 81 | """ 82 | super(FeatureFusionBlock, self).__init__() 83 | 84 | self.resConfUnit1 = ResidualConvUnit(features) 85 | self.resConfUnit2 = ResidualConvUnit(features) 86 | 87 | def forward(self, *xs): 88 | """Forward pass. 89 | Returns: 90 | tensor: output 91 | """ 92 | output = xs[0] 93 | 94 | if len(xs) == 2: 95 | output += self.resConfUnit1(xs[1]) 96 | 97 | output = self.resConfUnit2(output) 98 | 99 | output = nn.functional.interpolate( 100 | output, scale_factor=2, mode="bilinear", align_corners=True 101 | ) 102 | 103 | return output 104 | 105 | 106 | class Interpolate(nn.Module): 107 | """Interpolation module. 108 | """ 109 | 110 | def __init__(self, scale_factor, mode, align_corners=False): 111 | """Init. 112 | Args: 113 | scale_factor (float): scaling 114 | mode (str): interpolation mode 115 | """ 116 | super(Interpolate, self).__init__() 117 | 118 | self.interp = nn.functional.interpolate 119 | self.scale_factor = scale_factor 120 | self.mode = mode 121 | self.align_corners = align_corners 122 | 123 | def forward(self, x): 124 | """Forward pass. 125 | Args: 126 | x (tensor): input 127 | Returns: 128 | tensor: interpolated data 129 | """ 130 | 131 | x = self.interp( 132 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 133 | ) 134 | 135 | return x 136 | 137 | class Pred_endecoder(nn.Module): 138 | # pvt-v2 based encoder decoder 139 | def __init__(self, channel=256, config=None, vis_dim=[64, 128, 320, 512], tpavi_stages=[], tpavi_vv_flag=False, tpavi_va_flag=True): 140 | super(Pred_endecoder, self).__init__() 141 | self.cfg = config 142 | self.tpavi_stages = tpavi_stages 143 | self.tpavi_vv_flag = tpavi_vv_flag 144 | self.tpavi_va_flag = tpavi_va_flag 145 | self.vis_dim = vis_dim 146 | 147 | self.encoder_backbone = pvt_v2_b5() 148 | self.relu = nn.ReLU(inplace=True) 149 | 150 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 151 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 152 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 153 | self.upsample05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 154 | self.upsample025 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True) 155 | 156 | self.conv4 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, self.vis_dim[3]) 157 | self.conv3 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, self.vis_dim[2]) 158 | self.conv2 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, self.vis_dim[1]) 159 | self.conv1 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, self.vis_dim[0]) 160 | 161 | self.path4 = FeatureFusionBlock(channel) 162 | self.path3 = FeatureFusionBlock(channel) 163 | self.path2 = FeatureFusionBlock(channel) 164 | self.path1 = FeatureFusionBlock(channel) 165 | 166 | for i in self.tpavi_stages: 167 | setattr(self, f"tpavi_b{i+1}", TPAVIModule(in_channels=channel, mode='dot')) 168 | print("==> Build TPAVI block...") 169 | 170 | self.output_conv = 
nn.Sequential( 171 | nn.Conv2d(channel, 128, kernel_size=3, stride=1, padding=1), 172 | Interpolate(scale_factor=2, mode="bilinear"), 173 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 174 | nn.ReLU(True), 175 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 176 | ) 177 | 178 | if self.training: 179 | self.initialize_pvt_weights() 180 | 181 | 182 | def pre_reshape_for_tpavi(self, x): 183 | # x: [B*5, C, H, W] 184 | _, C, H, W = x.shape 185 | x = x.reshape(-1, 5, C, H, W) 186 | x = x.permute(0, 2, 1, 3, 4).contiguous() # [B, C, T, H, W] 187 | return x 188 | 189 | def post_reshape_for_tpavi(self, x): 190 | # x: [B, C, T, H, W] 191 | # return: [B*T, C, H, W] 192 | _, C, _, H, W = x.shape 193 | x = x.permute(0, 2, 1, 3, 4) # [B, T, C, H, W] 194 | x = x.view(-1, C, H, W) 195 | return x 196 | 197 | def tpavi_vv(self, x, stage): 198 | # x: visual, [B*T, C=256, H, W] 199 | tpavi_b = getattr(self, f'tpavi_b{stage+1}') 200 | x = self.pre_reshape_for_tpavi(x) # [B, C, T, H, W] 201 | x, _ = tpavi_b(x) # [B, C, T, H, W] 202 | x = self.post_reshape_for_tpavi(x) # [B*T, C, H, W] 203 | return x 204 | 205 | def tpavi_va(self, x, audio, stage): 206 | # x: visual, [B*T, C=256, H, W] 207 | # audio: [B*T, 128] 208 | # ra_flag: return audio feature list or not 209 | tpavi_b = getattr(self, f'tpavi_b{stage+1}') 210 | # print(audio.size()) 211 | audio = audio.reshape(-1, 5, audio.shape[-1]) # [B, T, 128] 212 | x = self.pre_reshape_for_tpavi(x) # [B, C, T, H, W] 213 | x, a = tpavi_b(x, audio) # [B, C, T, H, W], [B, T, C] 214 | x = self.post_reshape_for_tpavi(x) # [B*T, C, H, W] 215 | return x, a 216 | 217 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): 218 | return block(dilation_series, padding_series, NoLabels, input_channel) 219 | 220 | def forward(self, x, audio_feature=None): 221 | x1, x2, x3, x4 = self.encoder_backbone(x) 222 | # print(x1.shape, x2.shape, x3.shape, x4.shape) 223 | # shape for pvt-v2-b5 224 | # BF x 64 x 56 x 56 225 | # BF x 128 x 28 x 28 226 | # BF x 320 x 14 x 14 227 | # BF x 512 x 7 x 7 228 | 229 | conv1_feat = self.conv1(x1) # BF x 256 x 56 x 56 230 | conv2_feat = self.conv2(x2) # BF x 256 x 28 x 28 231 | conv3_feat = self.conv3(x3) # BF x 256 x 14 x 14 232 | conv4_feat = self.conv4(x4) # BF x 256 x 7 x 7 233 | # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape) 234 | 235 | feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat] 236 | a_fea_list = [None] * 4 237 | 238 | if len(self.tpavi_stages) > 0: 239 | if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag): 240 | raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, \ 241 | tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)') 242 | for i in self.tpavi_stages: 243 | tpavi_count = 0 244 | conv_feat = torch.zeros_like(feature_map_list[i]).cuda() 245 | if self.tpavi_vv_flag: 246 | conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i) 247 | conv_feat += conv_feat_vv 248 | tpavi_count += 1 249 | if self.tpavi_va_flag: 250 | conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i) 251 | conv_feat += conv_feat_va 252 | tpavi_count += 1 253 | a_fea_list[i] = a_fea 254 | conv_feat /= tpavi_count 255 | feature_map_list[i] = conv_feat # update features of stage-i which conduct non-local 256 | 257 | conv4_feat = self.path4(feature_map_list[3]) # BF x 256 x 14 x 14 258 | conv43 = 
self.path3(conv4_feat, feature_map_list[2]) # BF x 256 x 28 x 28 259 | conv432 = self.path2(conv43, feature_map_list[1]) # BF x 256 x 56 x 56 260 | conv4321 = self.path1(conv432, feature_map_list[0]) # BF x 256 x 112 x 112 261 | 262 | pred = self.output_conv(conv4321) # BF x 1 x 224 x 224 263 | # print(pred.shape) 264 | 265 | return pred, feature_map_list, a_fea_list 266 | 267 | 268 | def initialize_pvt_weights(self,): 269 | pvt_model_dict = self.encoder_backbone.state_dict() 270 | pretrained_state_dicts = torch.load(self.cfg.TRAIN.PRETRAINED_PVTV2_PATH) 271 | # for k, v in pretrained_state_dicts['model'].items(): 272 | # if k in pvt_model_dict.keys(): 273 | # print(k, v.requires_grad) 274 | state_dict = {k : v for k, v in pretrained_state_dicts.items() if k in pvt_model_dict.keys()} 275 | pvt_model_dict.update(state_dict) 276 | self.encoder_backbone.load_state_dict(pvt_model_dict) 277 | print(f'==> Load pvt-v2-b5 parameters pretrained on ImageNet from {self.cfg.TRAIN.PRETRAINED_PVTV2_PATH}') 278 | # pdb.set_trace() 279 | 280 | 281 | if __name__ == "__main__": 282 | imgs = torch.randn(10, 3, 224, 224) 283 | audio = torch.randn(2, 5, 128) 284 | # model = Pred_endecoder(channel=256) 285 | model = Pred_endecoder(channel=256, tpavi_stages=[0,1,2,3], tpavi_va_flag=True,) 286 | # output = model(imgs) 287 | output = model(imgs, audio) 288 | pdb.set_trace() -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/ResNet_AVSModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from model.resnet import B2_ResNet 5 | from model.TPAVI import TPAVIModule 6 | import pdb 7 | 8 | 9 | class Classifier_Module(nn.Module): 10 | def __init__(self, dilation_series, padding_series, NoLabels, input_channel): 11 | super(Classifier_Module, self).__init__() 12 | self.conv2d_list = nn.ModuleList() 13 | for dilation, padding in zip(dilation_series, padding_series): 14 | self.conv2d_list.append(nn.Conv2d(input_channel, NoLabels, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True)) 15 | for m in self.conv2d_list: 16 | m.weight.data.normal_(0, 0.01) 17 | 18 | def forward(self, x): 19 | out = self.conv2d_list[0](x) 20 | for i in range(len(self.conv2d_list)-1): 21 | out += self.conv2d_list[i+1](x) 22 | return out 23 | 24 | 25 | class BasicConv2d(nn.Module): 26 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 27 | super(BasicConv2d, self).__init__() 28 | self.conv_bn = nn.Sequential( 29 | nn.Conv2d(in_planes, out_planes, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, bias=False), 32 | nn.BatchNorm2d(out_planes) 33 | ) 34 | 35 | def forward(self, x): 36 | x = self.conv_bn(x) 37 | return x 38 | 39 | 40 | class ResidualConvUnit(nn.Module): 41 | """Residual convolution module. 42 | """ 43 | 44 | def __init__(self, features): 45 | """Init. 46 | Args: 47 | features (int): number of features 48 | """ 49 | super().__init__() 50 | 51 | self.conv1 = nn.Conv2d( 52 | features, features, kernel_size=3, stride=1, padding=1, bias=True 53 | ) 54 | self.conv2 = nn.Conv2d( 55 | features, features, kernel_size=3, stride=1, padding=1, bias=True 56 | ) 57 | self.relu = nn.ReLU(inplace=True) 58 | 59 | def forward(self, x): 60 | """Forward pass. 
61 | Args: 62 | x (tensor): input 63 | Returns: 64 | tensor: output 65 | """ 66 | out = self.relu(x) 67 | out = self.conv1(out) 68 | out = self.relu(out) 69 | out = self.conv2(out) 70 | 71 | return out + x 72 | 73 | class FeatureFusionBlock(nn.Module): 74 | """Feature fusion block. 75 | """ 76 | 77 | def __init__(self, features): 78 | """Init. 79 | Args: 80 | features (int): number of features 81 | """ 82 | super(FeatureFusionBlock, self).__init__() 83 | 84 | self.resConfUnit1 = ResidualConvUnit(features) 85 | self.resConfUnit2 = ResidualConvUnit(features) 86 | 87 | def forward(self, *xs): 88 | """Forward pass. 89 | Returns: 90 | tensor: output 91 | """ 92 | output = xs[0] 93 | 94 | if len(xs) == 2: 95 | output += self.resConfUnit1(xs[1]) 96 | 97 | output = self.resConfUnit2(output) 98 | 99 | output = nn.functional.interpolate( 100 | output, scale_factor=2, mode="bilinear", align_corners=True 101 | ) 102 | 103 | return output 104 | 105 | 106 | class Interpolate(nn.Module): 107 | """Interpolation module. 108 | """ 109 | 110 | def __init__(self, scale_factor, mode, align_corners=False): 111 | """Init. 112 | Args: 113 | scale_factor (float): scaling 114 | mode (str): interpolation mode 115 | """ 116 | super(Interpolate, self).__init__() 117 | 118 | self.interp = nn.functional.interpolate 119 | self.scale_factor = scale_factor 120 | self.mode = mode 121 | self.align_corners = align_corners 122 | 123 | def forward(self, x): 124 | """Forward pass. 125 | Args: 126 | x (tensor): input 127 | Returns: 128 | tensor: interpolated data 129 | """ 130 | 131 | x = self.interp( 132 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 133 | ) 134 | 135 | return x 136 | 137 | 138 | class Pred_endecoder(nn.Module): 139 | # resnet based encoder decoder 140 | def __init__(self, channel=256, config=None, tpavi_stages=[], tpavi_vv_flag=False, tpavi_va_flag=True): 141 | super(Pred_endecoder, self).__init__() 142 | self.cfg = config 143 | self.tpavi_stages = tpavi_stages 144 | self.tpavi_vv_flag = tpavi_vv_flag 145 | self.tpavi_va_flag = tpavi_va_flag 146 | 147 | self.resnet = B2_ResNet() 148 | self.relu = nn.ReLU(inplace=True) 149 | 150 | self.conv4 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 2048) 151 | self.conv3 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 1024) 152 | self.conv2 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 512) 153 | self.conv1 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 256) 154 | 155 | self.path4 = FeatureFusionBlock(channel) 156 | self.path3 = FeatureFusionBlock(channel) 157 | self.path2 = FeatureFusionBlock(channel) 158 | self.path1 = FeatureFusionBlock(channel) 159 | 160 | for i in self.tpavi_stages: 161 | setattr(self, f"tpavi_b{i+1}", TPAVIModule(in_channels=channel, mode='dot')) 162 | print("==> Build TPAVI block...") 163 | 164 | self.output_conv = nn.Sequential( 165 | nn.Conv2d(channel, 128, kernel_size=3, stride=1, padding=1), 166 | Interpolate(scale_factor=2, mode="bilinear"), 167 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 168 | nn.ReLU(True), 169 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 170 | ) 171 | 172 | if self.training: 173 | self.initialize_weights() 174 | 175 | 176 | def pre_reshape_for_tpavi(self, x): 177 | # x: [B*5, C, H, W] 178 | _, C, H, W = x.shape 179 | x = x.reshape(-1, 5, C, H, W) 180 | x = x.permute(0, 2, 1, 3, 4).contiguous() # [B, C, T, H, W] 181 | 
return x 182 | 183 | def post_reshape_for_tpavi(self, x): 184 | # x: [B, C, T, H, W] 185 | # return: [B*T, C, H, W] 186 | _, C, _, H, W = x.shape 187 | x = x.permute(0, 2, 1, 3, 4) # [B, T, C, H, W] 188 | x = x.view(-1, C, H, W) 189 | return x 190 | 191 | def tpavi_vv(self, x, stage): 192 | # x: visual, [B*T, C=256, H, W] 193 | tpavi_b = getattr(self, f'tpavi_b{stage+1}') 194 | x = self.pre_reshape_for_tpavi(x) # [B, C, T, H, W] 195 | x, _ = tpavi_b(x) # [B, C, T, H, W] 196 | x = self.post_reshape_for_tpavi(x) # [B*T, C, H, W] 197 | return x 198 | 199 | def tpavi_va(self, x, audio, stage): 200 | # x: visual, [B*T, C=256, H, W] 201 | # audio: [B*T, 128] 202 | # ra_flag: return audio feature list or not 203 | tpavi_b = getattr(self, f'tpavi_b{stage+1}') 204 | audio = audio.view(-1, 5, audio.shape[-1]) # [B, T, 128] 205 | x = self.pre_reshape_for_tpavi(x) # [B, C, T, H, W] 206 | x, a = tpavi_b(x, audio) # [B, C, T, H, W], [B, T, C] 207 | x = self.post_reshape_for_tpavi(x) # [B*T, C, H, W] 208 | return x, a 209 | 210 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): 211 | return block(dilation_series, padding_series, NoLabels, input_channel) 212 | 213 | def forward(self, x, audio_feature=None): 214 | x = self.resnet.conv1(x) 215 | x = self.resnet.bn1(x) 216 | x = self.resnet.relu(x) 217 | x = self.resnet.maxpool(x) 218 | x1 = self.resnet.layer1(x) # BF x 256 x 56 x 56 219 | x2 = self.resnet.layer2(x1) # BF x 512 x 28 x 28 220 | x3 = self.resnet.layer3_1(x2) # BF x 1024 x 14 x 14 221 | x4 = self.resnet.layer4_1(x3) # BF x 2048 x 7 x 7 222 | # print(x1.shape, x2.shape, x3.shape, x4.shape) 223 | 224 | conv1_feat = self.conv1(x1) # BF x 256 x 56 x 56 225 | conv2_feat = self.conv2(x2) # BF x 256 x 28 x 28 226 | conv3_feat = self.conv3(x3) # BF x 256 x 14 x 14 227 | conv4_feat = self.conv4(x4) # BF x 256 x 7 x 7 228 | # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape) 229 | 230 | feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat] 231 | a_fea_list = [None] * 4 232 | 233 | if len(self.tpavi_stages) > 0: 234 | if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag): 235 | raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, \ 236 | tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)') 237 | for i in self.tpavi_stages: 238 | tpavi_count = 0 239 | conv_feat = torch.zeros_like(feature_map_list[i]).cuda() 240 | if self.tpavi_vv_flag: 241 | conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i) 242 | conv_feat += conv_feat_vv 243 | tpavi_count += 1 244 | if self.tpavi_va_flag: 245 | conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i) 246 | conv_feat += conv_feat_va 247 | tpavi_count += 1 248 | a_fea_list[i] = a_fea 249 | conv_feat /= tpavi_count 250 | feature_map_list[i] = conv_feat # update features of stage-i which conduct TPAVI 251 | 252 | conv4_feat = self.path4(feature_map_list[3]) # BF x 256 x 14 x 14 253 | conv43 = self.path3(conv4_feat, feature_map_list[2]) # BF x 256 x 28 x 28 254 | conv432 = self.path2(conv43, feature_map_list[1]) # BF x 256 x 56 x 56 255 | conv4321 = self.path1(conv432, feature_map_list[0]) # BF x 256 x 112 x 112 256 | # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape) 257 | 258 | pred = self.output_conv(conv4321) # BF x 1 x 224 x 224 259 | # print(pred.shape) 260 | 261 | return pred, feature_map_list, a_fea_list 262 | 
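    # Note (descriptive comment): B2_ResNet in model/resnet.py duplicates the deeper
    # ResNet stages (layer3_1/layer4_1 are used in forward() above, and the '_2'
    # variants are handled below). initialize_weights() copies each pretrained
    # torchvision ResNet-50 weight to those duplicated keys by stripping the
    # '_1'/'_2' suffix from the key name before loading the state dict.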
263 | 264 | def initialize_weights(self): 265 | res50 = models.resnet50(pretrained=False) 266 | resnet50_dict = torch.load(self.cfg.TRAIN.PRETRAINED_RESNET50_PATH) 267 | res50.load_state_dict(resnet50_dict) 268 | pretrained_dict = res50.state_dict() 269 | # print(pretrained_dict.keys()) 270 | all_params = {} 271 | for k, v in self.resnet.state_dict().items(): 272 | if k in pretrained_dict.keys(): 273 | v = pretrained_dict[k] 274 | all_params[k] = v 275 | elif '_1' in k: 276 | name = k.split('_1')[0] + k.split('_1')[1] 277 | v = pretrained_dict[name] 278 | all_params[k] = v 279 | elif '_2' in k: 280 | name = k.split('_2')[0] + k.split('_2')[1] 281 | v = pretrained_dict[name] 282 | all_params[k] = v 283 | assert len(all_params.keys()) == len(self.resnet.state_dict().keys()) 284 | self.resnet.load_state_dict(all_params) 285 | print(f'==> Load pretrained ResNet50 parameters from {self.cfg.TRAIN.PRETRAINED_RESNET50_PATH}') 286 | 287 | 288 | if __name__ == "__main__": 289 | imgs = torch.randn(10, 3, 224, 224) 290 | model = Pred_endecoder(channel=256, tpavi_stages=[0,1,2,3], tpavi_va_flag=True) 291 | output = model(imgs) 292 | pdb.set_trace() -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/TPAVI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class TPAVIModule(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, mode='dot', 8 | dimension=3, bn_layer=True): 9 | """ 10 | args: 11 | in_channels: original channel size (1024 in the paper) 12 | inter_channels: channel size inside the block if not specifed reduced to half (512 in the paper) 13 | mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation 14 | dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal) 15 | bn_layer: whether to add batch norm 16 | """ 17 | super(TPAVIModule, self).__init__() 18 | 19 | assert dimension in [1, 2, 3] 20 | 21 | if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']: 22 | raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`') 23 | 24 | self.mode = mode 25 | self.dimension = dimension 26 | 27 | self.in_channels = in_channels 28 | self.inter_channels = inter_channels 29 | 30 | # the channel size is reduced to half inside the block 31 | if self.inter_channels is None: 32 | self.inter_channels = in_channels // 2 33 | if self.inter_channels == 0: 34 | self.inter_channels = 1 35 | 36 | ## add align channel 37 | self.align_channel = nn.Linear(256, in_channels)#origin:128 38 | self.norm_layer=nn.LayerNorm(in_channels) 39 | 40 | # assign appropriate convolutional, max pool, and batch norm layers for different dimensions 41 | if dimension == 3: 42 | conv_nd = nn.Conv3d 43 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 44 | bn = nn.BatchNorm3d 45 | elif dimension == 2: 46 | conv_nd = nn.Conv2d 47 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 48 | bn = nn.BatchNorm2d 49 | else: 50 | conv_nd = nn.Conv1d 51 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 52 | bn = nn.BatchNorm1d 53 | 54 | # function g in the paper which goes through conv. 
with kernel size 1 55 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 56 | 57 | if bn_layer: 58 | self.W_z = nn.Sequential( 59 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), 60 | bn(self.in_channels) 61 | ) 62 | nn.init.constant_(self.W_z[1].weight, 0) 63 | nn.init.constant_(self.W_z[1].bias, 0) 64 | else: 65 | self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1) 66 | 67 | nn.init.constant_(self.W_z.weight, 0) 68 | nn.init.constant_(self.W_z.bias, 0) 69 | 70 | # define theta and phi for all operations except gaussian 71 | if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate": 72 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 73 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 74 | 75 | if self.mode == "concatenate": 76 | self.W_f = nn.Sequential( 77 | nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1), 78 | nn.ReLU() 79 | ) 80 | 81 | 82 | def forward(self, x, audio=None): 83 | """ 84 | args: 85 | x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1 86 | audio: (N, T, C) 87 | """ 88 | 89 | audio_temp = 0 90 | batch_size, C = x.size(0), x.size(1) 91 | if audio is not None: 92 | # print('==> audio.shape', audio.shape) 93 | H, W = x.shape[-2], x.shape[-1] 94 | audio_temp = self.align_channel(audio) # [bs, T, C] 95 | audio = audio_temp.permute(0, 2, 1) # [bs, C, T] 96 | audio = audio.unsqueeze(-1).unsqueeze(-1) # [bs, C, T, 1, 1] 97 | audio = audio.repeat(1, 1, 1, H, W) # [bs, C, T, H, W] 98 | else: 99 | audio = x 100 | 101 | # (N, C, THW) 102 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) # [bs, C, THW] 103 | # print('g_x.shape', g_x.shape) 104 | # g_x = x.view(batch_size, C, -1) # [bs, C, THW] 105 | g_x = g_x.permute(0, 2, 1) # [bs, THW, C] 106 | 107 | if self.mode == "gaussian": 108 | theta_x = x.view(batch_size, self.in_channels, -1) 109 | phi_x = audio.view(batch_size, self.in_channels, -1) 110 | theta_x = theta_x.permute(0, 2, 1) 111 | f = torch.matmul(theta_x, phi_x) 112 | 113 | elif self.mode == "embedded" or self.mode == "dot": 114 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) # [bs, C', THW] 115 | phi_x = self.phi(audio).view(batch_size, self.inter_channels, -1) # [bs, C', THW] 116 | theta_x = theta_x.permute(0, 2, 1) # [bs, THW, C'] 117 | f = torch.matmul(theta_x, phi_x) # [bs, THW, THW] 118 | 119 | elif self.mode == "concatenate": 120 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 121 | phi_x = self.phi(audio).view(batch_size, self.inter_channels, 1, -1) 122 | 123 | h = theta_x.size(2) 124 | w = phi_x.size(3) 125 | theta_x = theta_x.repeat(1, 1, 1, w) 126 | phi_x = phi_x.repeat(1, 1, h, 1) 127 | 128 | concat = torch.cat([theta_x, phi_x], dim=1) 129 | f = self.W_f(concat) 130 | f = f.view(f.size(0), f.size(2), f.size(3)) 131 | 132 | if self.mode == "gaussian" or self.mode == "embedded": 133 | f_div_C = F.softmax(f, dim=-1) 134 | elif self.mode == "dot" or self.mode == "concatenate": 135 | N = f.size(-1) # number of position in x 136 | f_div_C = f / N # [bs, THW, THW] 137 | 138 | y = torch.matmul(f_div_C, g_x) # [bs, THW, C] 139 | 140 | # contiguous here just allocates contiguous chunk of memory 141 | y = y.permute(0, 2, 1).contiguous() # [bs, C, THW] 142 | y = y.view(batch_size, 
self.inter_channels, *x.size()[2:]) # [bs, C', T, H, W] 143 | 144 | W_y = self.W_z(y) # [bs, C, T, H, W] 145 | # residual connection 146 | z = W_y + x # # [bs, C, T, H, W] 147 | 148 | # add LayerNorm 149 | z = z.permute(0, 2, 3, 4, 1) # [bs, T, H, W, C] 150 | z = self.norm_layer(z) 151 | z = z.permute(0, 4, 1, 2, 3) # [bs, C, T, H, W] 152 | 153 | return z, audio_temp 154 | 155 | 156 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/UniEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.transformer import TransformerEncoder 4 | 5 | class UNIEncoder(nn.Module): 6 | def __init__(self): 7 | super(UNIEncoder, self).__init__() 8 | # define transformer head 9 | self.tx = TransformerEncoder(d_model=256, 10 | d_kv=64, 11 | d_ff=4096, 12 | num_layers=24, 13 | num_heads=16, 14 | pre_norm=True, 15 | use_bias=True, 16 | activation="gelu", 17 | dropout_rate=0.1, 18 | layer_norm_epsilon=1e-6) 19 | 20 | # define post-tx projection head - it could be logits or embd space 21 | self.post_proj = nn.ModuleDict({# ReLU or GELU 22 | "video": nn.Sequential(nn.Linear(256, 256),nn.GELU()),#d_model=256 d_post_proj=256 23 | "audio": nn.Sequential(nn.Linear(256, 256),nn.GELU()) 24 | }) 25 | 26 | def _flatten_inputs(self, 27 | inputs): 28 | input_shape = inputs.shape 29 | bs = inputs.shape[0] 30 | d_embd = inputs.shape[-1] 31 | inputs = inputs.view(bs, -1, d_embd) 32 | 33 | return inputs, input_shape 34 | 35 | def _append_special_tokens(self, 36 | inputs, 37 | modality): 38 | batch_size = inputs.shape[0] 39 | agg_token = { 40 | "video": torch.nn.Parameter(torch.Tensor(256,)),#d_model 41 | "audio": torch.nn.Parameter(torch.Tensor(256,)), 42 | } 43 | special_embd = agg_token[modality][None, None, :].to(inputs.device) 44 | special_embd = special_embd.repeat(batch_size, 1, 1) 45 | 46 | return torch.cat([special_embd, inputs], dim=1) 47 | 48 | def _extend_attn_mask(self, 49 | attention_mask): 50 | attn_mask_shape = attention_mask.shape 51 | if len(attn_mask_shape) > 2: 52 | raise NotImplementedError 53 | 54 | batch_size = attn_mask_shape[0] 55 | extention_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype) 56 | extended_attention_mask = torch.cat([extention_mask, attention_mask], dim=1) 57 | return extended_attention_mask 58 | 59 | def _modality_call(self, 60 | inputs, 61 | modality, 62 | training=False, 63 | attention_mask=None, 64 | input_shape=None): 65 | embeddings = inputs 66 | if input_shape is None: 67 | embeddings, input_shape = self._flatten_inputs(embeddings) 68 | 69 | # print("pool:",embeddings) 70 | # print(features) 71 | 72 | # append modalities special tokens: [vid, aud, txt] 73 | tx_inputs = self._append_special_tokens(embeddings, modality) 74 | print("pool:",embeddings) 75 | 76 | # extend attention_mask accordingly 77 | if attention_mask is not None: 78 | attention_mask = self._extend_attn_mask(attention_mask) 79 | 80 | # call Transformer 81 | tx_outputs = self.tx(tx_inputs, attention_mask) 82 | 83 | # get last hidden states and perform final linear projection 84 | last_hidden_states = tx_outputs["hidden_states"][-1] 85 | modality_outputs = self.post_proj[modality](last_hidden_states) 86 | output_shape = list(input_shape[:-1]) + [modality_outputs.shape[-1]] 87 | # output_shape = list(256) + [modality_outputs.shape[-1]] 88 | 89 | features_pooled = modality_outputs[:, 0, :] 90 | features = modality_outputs[:, 1:, 
:].reshape(output_shape) 91 | 92 | # print("pool:",features_pooled) 93 | # print(features) 94 | 95 | # add token-level Transformer outputs 96 | outputs = {"features_pooled": features_pooled, 97 | "features": features} 98 | 99 | return outputs 100 | 101 | def forward(self, video_semantic_result, audio_semantic_result): 102 | 103 | """ 104 | outputs = {"features_pooled": features_pooled, 105 | "features": features} 106 | """ 107 | 108 | 109 | 110 | video_outputs = self._modality_call(inputs=video_semantic_result, 111 | modality='video', 112 | training=self.training, 113 | attention_mask=None) 114 | audio_outputs = self._modality_call(inputs=audio_semantic_result, 115 | modality='audio', 116 | training=self.training, 117 | attention_mask=None) 118 | 119 | # print("video_outputs:",video_outputs["features"].size(), video_outputs["features"].dtype) 120 | # print("video_semantic_result:",video_semantic_result.size(), video_semantic_result.dtype) 121 | 122 | """features_pooled可以拿来算infonce,现在还没用上""" 123 | 124 | # print(video_semantic_result) 125 | 126 | return video_outputs["features"], audio_outputs["features"] 127 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/PVT_AVSModel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/PVT_AVSModel.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/TPAVI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/TPAVI.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/main_model_2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/main_model_2.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/mine.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/mine.cpython-38.pyc 
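A minimal, self-contained usage sketch for the TPAVIModule defined in model/TPAVI.py above. The batch size, frame count and 14x14 spatial grid below are illustrative assumptions; in_channels=256 matches the channel width that Pred_endecoder(channel=256, ...) uses elsewhere in this repo, and the audio tensor must have 256 channels because align_channel is nn.Linear(256, in_channels).

```python
import torch
from model.TPAVI import TPAVIModule

# Illustrative shapes only: N=2 clips, T=5 frames, 14x14 feature maps, C=256 channels.
tpavi = TPAVIModule(in_channels=256, mode='dot', dimension=3)
visual = torch.randn(2, 256, 5, 14, 14)  # (N, C, T, H, W) visual feature volume
audio = torch.randn(2, 5, 256)           # (N, T, C) audio features, projected to C inside the module
fused, audio_aligned = tpavi(visual, audio=audio)
print(fused.shape)          # torch.Size([2, 256, 5, 14, 14])
print(audio_aligned.shape)  # torch.Size([2, 5, 256])
```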
-------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/pvt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/model/__pycache__/pvt.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/mine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # class Mine(nn.Module): 7 | # def __init__(self): 8 | # super(Mine, self).__init__() 9 | # self.fc1_x = nn.Linear(2048, 512) 10 | # self.fc1_y = nn.Linear(2048, 512) 11 | # self.fc2 = nn.Linear(512,1) 12 | # def forward(self, x,y): 13 | # h1 = F.leaky_relu(self.fc1_x(x)+self.fc1_y(y)) 14 | # h2 = self.fc2(h1) 15 | # return h2 16 | # 17 | # Mine = Mine() 18 | # def mi_estimator(x, y, y_): 19 | # 20 | # joint, marginal = Mine(x, y), Mine(x, y_) 21 | # return torch.mean(joint) - torch.log(torch.mean(torch.exp(marginal))) 22 | 23 | # x = torch.randn(32, 10, 2048) 24 | # y = torch.randn(32, 10, 2048) 25 | # y_ = torch.randn(32, 10, 2048) 26 | # joint, marginal = Mine(x, y), Mine(x, y_) 27 | # loss = torch.mean(joint) - torch.log(torch.mean(torch.exp(marginal))) 28 | # print(loss) 29 | 30 | # class Mine2(nn.Module): 31 | # def __init__(self, x_dim, y_dim, hidden_dim): 32 | # super(Mine2, self).__init__() 33 | 34 | # 35 | # 36 | # class MINE(nn.Module): 37 | # def __init__(self, hidden_size=256): 38 | # super(MINE, self).__init__() 39 | # self.layers = nn.Sequential(nn.Linear(512, hidden_size), 40 | # nn.ReLU(), 41 | # nn.Linear(hidden_size, 1)) 42 | # 43 | # def forward(self, x, y): 44 | # batch_size = x.size(0) 45 | # tiled_x = torch.cat([x, x, ], dim=0) 46 | # print("tiled_x:",tiled_x.size()) 47 | # idx = torch.randperm(batch_size) 48 | # 49 | # shuffled_y = y[idx] 50 | # concat_y = torch.cat([y, shuffled_y], dim=0) 51 | # print("concat_y:", concat_y.size()) 52 | # 53 | # 54 | # inputs = torch.cat([tiled_x, concat_y], dim=1) 55 | # print("inputs:",inputs.size()) 56 | # logits = self.layers(inputs) 57 | # 58 | # pred_xy = logits[:batch_size] 59 | # pred_x_y = logits[batch_size:] 60 | # loss = -(torch.mean(pred_xy) 61 | # - torch.log(torch.mean(torch.exp(pred_x_y)))) 62 | # 63 | # return loss 64 | # # 65 | 66 | 67 | class MINE(nn.Module): 68 | def __init__(self, x_dim, y_dim, hidden_size): 69 | super(MINE, self).__init__() 70 | self.T_func = nn.Sequential(nn.Linear(x_dim + y_dim, hidden_size), 71 | nn.ReLU(), 72 | nn.Linear(hidden_size, 1)) 73 | 74 | def forward(self, x_samples, y_samples): # samples have shape [sample_size, dim] 75 | # shuffle and concatenate 76 | sample_size = y_samples.shape[0] 77 | random_index = torch.randint(sample_size, (sample_size,)).long() 78 | 79 | y_shuffle = y_samples[random_index] 80 | #print("y_shuffle", y_shuffle.size()) 81 | 82 
| # np returns float64 by default, which trips up F.linear's precision, so .to(torch.float32) is added 83 | T0 = self.T_func(torch.cat([x_samples, y_samples], dim=-1).to(torch.float32)) 84 | #print("T0:",T0.size()) 85 | T1 = self.T_func(torch.cat([x_samples, y_shuffle], dim=-1).to(torch.float32)) 86 | #print("T1:", T1.size()) 87 | 88 | lower_bound = T0.mean() - torch.log(T1.exp().mean()) 89 | 90 | # compute the negative loss (maximise loss == minimise -loss) 91 | return lower_bound 92 | 93 | def learning_loss(self, x_samples, y_samples): 94 | return -self.forward(x_samples, y_samples) 95 | 96 | 97 | class CLUBSample(nn.Module): # Sampled version of the CLUB estimator 98 | def __init__(self, x_dim, y_dim, hidden_size): 99 | super(CLUBSample, self).__init__() 100 | self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size // 2), 101 | nn.ReLU(), 102 | nn.Linear(hidden_size // 2, y_dim)) 103 | 104 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size // 2), 105 | nn.ReLU(), 106 | nn.Linear(hidden_size // 2, y_dim), 107 | nn.Tanh()) 108 | 109 | def get_mu_logvar(self, x_samples): 110 | mu = self.p_mu(x_samples) 111 | logvar = self.p_logvar(x_samples) 112 | return mu, logvar 113 | 114 | def loglikeli(self, x_samples, y_samples): 115 | mu, logvar = self.get_mu_logvar(x_samples) 116 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean() 117 | 118 | def forward(self, x_samples, y_samples): 119 | mu, logvar = self.get_mu_logvar(x_samples) 120 | 121 | sample_size = x_samples.shape[0] 122 | # random_index = torch.randint(sample_size, (sample_size,)).long() 123 | random_index = torch.randperm(sample_size).long() 124 | 125 | positive = - (mu - y_samples) ** 2 / logvar.exp() 126 | negative = - (mu - y_samples[random_index]) ** 2 / logvar.exp() 127 | upper_bound = (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 128 | return upper_bound / 2.
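    # Note: MINE above implements the Donsker-Varadhan lower bound on mutual information,
    #     I(X; Y) >= E_{p(x,y)}[T(x, y)] - log E_{p(x)p(y)}[exp(T(x, y))],
    # with the product of marginals approximated by shuffling y within the batch.
    # CLUBSample gives a sampled *upper* bound instead: forward() scores a variational Gaussian
    # q(y|x) = N(mu(x), diag(exp(logvar(x)))) on paired versus shuffled pairs, while
    # learning_loss() below fits q(y|x) by maximizing loglikeli(). Commented usage sketch
    # (dimensions are illustrative assumptions only):
    #     x, y = torch.randn(32, 256), torch.randn(32, 256)
    #     club = CLUBSample(x_dim=256, y_dim=256, hidden_size=256)
    #     mi_upper = club(x, y)                # sampled CLUB estimate of I(x; y)
    #     fit_loss = club.learning_loss(x, y)  # minimize to fit q(y|x)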
129 | 130 | def learning_loss(self, x_samples, y_samples): 131 | return - self.loglikeli(x_samples, y_samples) 132 | 133 | # x = torch.randn(32, 10, 512) 134 | # y = torch.randn(32, 10, 2048) 135 | # 136 | # model = MINE(x_dim=512, y_dim=2048, hidden_size=256) 137 | # loss = model.learning_loss(x, y) 138 | # print(loss) -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class B2_ResNet(nn.Module): 83 | # ResNet50 with two branches 84 | def __init__(self): 85 | # self.inplanes = 128 86 | self.inplanes = 64 87 | super(B2_ResNet, self).__init__() 88 | 89 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 90 | bias=False) 91 | self.bn1 = nn.BatchNorm2d(64) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 94 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 95 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 96 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2) 97 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2) 98 | 99 | self.inplanes = 512 100 | self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2) 101 | self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, 
nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | downsample = None 113 | if stride != 1 or self.inplanes != planes * block.expansion: 114 | downsample = nn.Sequential( 115 | nn.Conv2d(self.inplanes, planes * block.expansion, 116 | kernel_size=1, stride=stride, bias=False), 117 | nn.BatchNorm2d(planes * block.expansion), 118 | ) 119 | 120 | layers = [] 121 | layers.append(block(self.inplanes, planes, stride, downsample)) 122 | self.inplanes = planes * block.expansion 123 | for i in range(1, blocks): 124 | layers.append(block(self.inplanes, planes)) 125 | 126 | return nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | x = self.conv1(x) 130 | x = self.bn1(x) 131 | x = self.relu(x) 132 | x = self.maxpool(x) 133 | 134 | x = self.layer1(x) 135 | x = self.layer2(x) 136 | x1 = self.layer3_1(x) 137 | x1 = self.layer4_1(x1) 138 | 139 | x2 = self.layer3_2(x) 140 | x2 = self.layer4_2(x2) 141 | 142 | return x1, x2 143 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/test.sh: -------------------------------------------------------------------------------- 1 | 2 | setting='S4' 3 | visual_backbone="pvt" # "resnet" or "pvt" 4 | 5 | python test_ta.py \ 6 | --session_name ${setting}_${visual_backbone} \ 7 | --visual_backbone ${visual_backbone} \ 8 | --weights "S4_pvt_best.pth" \ 9 | --test_batch_size 4 \ 10 | --tpavi_stages 0 1 2 3 \ 11 | --tpavi_va_flag \ 12 | --save_pred_mask 13 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/test_at.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import shutil 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import argparse 9 | import logging 10 | 11 | from config import cfg 12 | from dataloader import S4Dataset 13 | # from torchvggish import vggish 14 | 15 | from utils import pyutils 16 | from utils.utility import logger, mask_iou, Eval_Fmeasure, save_mask 17 | from utils.system import setup_logging 18 | import pdb 19 | 20 | from model.main_model_2 import AT_VQVAE_Encoder,AVT_VQVAE_Encoder 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--session_name", default="S4", type=str, help="the S4 setting") 25 | parser.add_argument("--visual_backbone", default="resnet", type=str, help="use resnet50 or pvt-v2 as the visual backbone") 26 | 27 | parser.add_argument("--test_batch_size", default=1, type=int) 28 | parser.add_argument("--max_epoches", default=15, type=int) 29 | parser.add_argument("--lr", default=0.0001, type=float) 30 | parser.add_argument("--num_workers", default=8, type=int) 31 | parser.add_argument("--wt_dec", default=5e-4, type=float) 32 | 33 | parser.add_argument("--tpavi_stages", default=[], nargs='+', type=int, help='add tpavi block in which stages: [0, 1, 2, 3') 34 | parser.add_argument("--tpavi_vv_flag", action='store_true', default=False, help='visual-visual self-attention') 35 | parser.add_argument("--tpavi_va_flag", action='store_true', default=False, help='visual-audio cross-attention') 36 | 37 | parser.add_argument("--weights",type=str) 38 | parser.add_argument("--save_pred_mask", 
action='store_true', default=False, help="save predited masks or not") 39 | parser.add_argument('--log_dir', default='./test_logs', type=str) 40 | 41 | args = parser.parse_args() 42 | 43 | if (args.visual_backbone).lower() == "resnet": 44 | from model import ResNet_AVSModel as AVSModel 45 | print('==> Use ResNet50 as the visual backbone...') 46 | elif (args.visual_backbone).lower() == "pvt": 47 | from model import PVT_AVSModel as AVSModel 48 | print('==> Use pvt-v2 as the visual backbone...') 49 | else: 50 | raise NotImplementedError("only support the resnet50 and pvt-v2") 51 | 52 | 53 | '''upstream_model setting''' 54 | text_dim = 768 55 | video_dim = 512 56 | audio_dim = 128 57 | text_lstm_dim = 128 58 | text_output_dim = 256 59 | video_output_dim = 2048 60 | audio_output_dim = 256 61 | n_embeddings = 400 62 | embedding_dim = 256 63 | start_epoch = -1 64 | model_resume = False 65 | total_step = 0 66 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 67 | Text_ar_lstm = nn.LSTM(text_dim, text_lstm_dim, num_layers=2, batch_first=True, bidirectional=True) 68 | 69 | # AT 70 | # Encoder = AT_VQVAE_Encoder(text_lstm_dim*2, audio_dim, text_output_dim, audio_output_dim, n_embeddings, embedding_dim) 71 | 72 | # AVT 73 | Encoder = AVT_VQVAE_Encoder(audio_dim, video_dim, text_lstm_dim*2, audio_output_dim, video_output_dim, text_output_dim, n_embeddings, embedding_dim) 74 | 75 | AT_10_5_Linear = nn.Linear(10, 5) 76 | 77 | Text_ar_lstm.double() 78 | Encoder.double() 79 | 80 | Text_ar_lstm.cuda() 81 | Encoder.cuda() 82 | AT_10_5_Linear.cuda() 83 | 84 | if model_resume is True: 85 | path_checkpoints = "..." 86 | checkpoints = torch.load(path_checkpoints) 87 | Text_ar_lstm.load_state_dict(checkpoints['Text_ar_lstm_parameters']) 88 | Encoder.load_state_dict(checkpoints['Encoder_parameters']) 89 | start_epoch = checkpoints['epoch'] 90 | print("Resume from number {}-th model.".format(start_epoch)) 91 | 92 | # Fix seed 93 | FixSeed = 123 94 | random.seed(FixSeed) 95 | np.random.seed(FixSeed) 96 | torch.manual_seed(FixSeed) 97 | torch.cuda.manual_seed(FixSeed) 98 | 99 | 100 | # Log directory 101 | if not os.path.exists(args.log_dir): 102 | os.makedirs(args.log_dir) 103 | # Logs 104 | prefix = args.session_name 105 | log_dir = os.path.join(args.log_dir, '{}'.format(time.strftime(prefix + '_%Y%m%d-%H%M%S'))) 106 | args.log_dir = log_dir 107 | 108 | # Save scripts 109 | script_path = os.path.join(log_dir, 'scripts') 110 | if not os.path.exists(script_path): 111 | os.makedirs(script_path, exist_ok=True) 112 | 113 | scripts_to_save = ['train.sh', 'train.py', 'test.sh', 'test.py', 'config.py', 'dataloader.py', './model/ResNet_AVSModel.py', './model/PVT_AVSModel.py', 'loss.py'] 114 | for script in scripts_to_save: 115 | dst_path = os.path.join(script_path, script) 116 | try: 117 | shutil.copy(script, dst_path) 118 | except IOError: 119 | os.makedirs(os.path.dirname(dst_path), exist_ok=True) 120 | shutil.copy(script, dst_path) 121 | 122 | # Set logger 123 | log_path = os.path.join(log_dir, 'log') 124 | if not os.path.exists(log_path): 125 | os.makedirs(log_path, exist_ok=True) 126 | 127 | setup_logging(filename=os.path.join(log_path, 'log.txt')) 128 | logger = logging.getLogger(__name__) 129 | logger.info('==> Config: {}'.format(cfg)) 130 | logger.info('==> Arguments: {}'.format(args)) 131 | logger.info('==> Experiment: {}'.format(args.session_name)) 132 | 133 | # Model 134 | model = AVSModel.Pred_endecoder(channel=256, \ 135 | config=cfg, \ 136 | tpavi_stages=args.tpavi_stages, \ 137 | 
tpavi_vv_flag=args.tpavi_vv_flag, \ 138 | tpavi_va_flag=args.tpavi_va_flag) 139 | model.load_state_dict(torch.load(args.weights)) 140 | model = model.cuda() 141 | logger.info('=> Load trained model %s'%args.weights) 142 | 143 | # audio backbone 144 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 145 | 146 | # Test data 147 | split = 'test' 148 | test_dataset = S4Dataset(split) 149 | test_dataloader = torch.utils.data.DataLoader(test_dataset, 150 | batch_size=args.test_batch_size, 151 | shuffle=False, 152 | num_workers=args.num_workers, 153 | pin_memory=True, 154 | collate_fn=test_dataset.collate_func) 155 | 156 | avg_meter_miou = pyutils.AverageMeter('miou') 157 | avg_meter_F = pyutils.AverageMeter('F_score') 158 | 159 | # Test 160 | model.eval() 161 | with torch.no_grad(): 162 | for n_iter, batch_data in enumerate(test_dataloader): 163 | query, imgs, audio_feature, mask = batch_data['query'],batch_data['imgs_tensor'],batch_data['audio_fea'],batch_data['masks_tensor'] 164 | category_list, video_name_list = batch_data['category'],batch_data['video_name'] 165 | query = query.double().cuda() 166 | imgs = imgs.cuda() 167 | audio_feature = audio_feature.cuda() 168 | mask = mask.cuda() 169 | B, frame, C, H, W = imgs.shape 170 | imgs = imgs.view(B*frame, C, H, W) 171 | mask = mask.view(B*frame, H, W) 172 | 173 | batch_dim = query.size()[0] 174 | hidden_dim = 128 175 | num_layers = 2 176 | text_hidden = (torch.zeros(2*num_layers, batch_dim, hidden_dim).double().cuda(), 177 | torch.zeros(2*num_layers, batch_dim, hidden_dim).double().cuda()) 178 | text_feature, text_hidden = Text_ar_lstm(query, text_hidden) 179 | text_feature = text_feature.cuda() 180 | 181 | text_feature = text_feature.transpose(2, 1).contiguous() # [batch, text_dim, length:10] 182 | text_feature = AT_10_5_Linear(text_feature.to(torch.float32)) 183 | text_feature = text_feature.transpose(2, 1).contiguous().to(torch.float64) # [batch, length:3, text_dim] 184 | 185 | text_vq = Encoder.Text_VQ_Encoder(text_feature)# [B, T, 256] 186 | text_vq = text_vq.reshape(-1, text_vq.shape[-1]) 187 | 188 | output, visual_map_list, a_fea_list = model(imgs, text_vq.to(torch.float32)) # [bs*5, 1, 224, 224] 189 | 190 | 191 | if args.save_pred_mask: 192 | mask_save_path = os.path.join(log_dir, 'pred_masks') 193 | save_mask(output.squeeze(1), mask_save_path, category_list, video_name_list) 194 | 195 | miou = mask_iou(output.squeeze(1), mask) 196 | avg_meter_miou.add({'miou': miou}) 197 | F_score = Eval_Fmeasure(output.squeeze(1), mask, log_dir) 198 | avg_meter_F.add({'F_score': F_score}) 199 | print('n_iter: {}, iou: {}, F_score: {}'.format(n_iter, miou, F_score)) 200 | 201 | 202 | miou = (avg_meter_miou.pop('miou')) 203 | F_score = (avg_meter_F.pop('F_score')) 204 | print('test miou:', miou.item()) 205 | print('test F_score:', F_score) 206 | logger.info('test miou: {}, F_score: {}'.format(miou.item(), F_score)) -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/test_ta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import shutil 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import argparse 9 | import logging 10 | 11 | from config import cfg 12 | from dataloader import S4Dataset 13 | # from torchvggish import vggish 14 | 15 | from utils import pyutils 16 | from utils.utility import logger, mask_iou, Eval_Fmeasure, save_mask 17 | from 
utils.system import setup_logging 18 | import pdb 19 | 20 | from model.main_model_2 import AT_VQVAE_Encoder,AVT_VQVAE_Encoder 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--session_name", default="S4", type=str, help="the S4 setting") 26 | parser.add_argument("--visual_backbone", default="resnet", type=str, help="use resnet50 or pvt-v2 as the visual backbone") 27 | 28 | parser.add_argument("--test_batch_size", default=1, type=int) 29 | parser.add_argument("--max_epoches", default=15, type=int) 30 | parser.add_argument("--lr", default=0.0001, type=float) 31 | parser.add_argument("--num_workers", default=8, type=int) 32 | parser.add_argument("--wt_dec", default=5e-4, type=float) 33 | 34 | parser.add_argument("--tpavi_stages", default=[], nargs='+', type=int, help='add tpavi block in which stages: [0, 1, 2, 3') 35 | parser.add_argument("--tpavi_vv_flag", action='store_true', default=False, help='visual-visual self-attention') 36 | parser.add_argument("--tpavi_va_flag", action='store_true', default=False, help='visual-audio cross-attention') 37 | 38 | parser.add_argument("--weights",type=str) 39 | parser.add_argument("--save_pred_mask", action='store_true', default=False, help="save predited masks or not") 40 | parser.add_argument('--log_dir', default='./test_logs', type=str) 41 | 42 | args = parser.parse_args() 43 | 44 | if (args.visual_backbone).lower() == "resnet": 45 | from model import ResNet_AVSModel as AVSModel 46 | print('==> Use ResNet50 as the visual backbone...') 47 | elif (args.visual_backbone).lower() == "pvt": 48 | from model import PVT_AVSModel as AVSModel 49 | print('==> Use pvt-v2 as the visual backbone...') 50 | else: 51 | raise NotImplementedError("only support the resnet50 and pvt-v2") 52 | 53 | 54 | '''upstream_model setting''' 55 | text_dim = 768 56 | video_dim = 512 57 | audio_dim = 128 58 | text_lstm_dim = 128 59 | text_output_dim = 256 60 | video_output_dim = 2048 61 | audio_output_dim = 256 62 | n_embeddings = 400 63 | embedding_dim = 256 64 | start_epoch = -1 65 | model_resume = False 66 | total_step = 0 67 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 68 | Text_ar_lstm = nn.LSTM(text_dim, text_lstm_dim, num_layers=2, batch_first=True, bidirectional=True) 69 | 70 | # AT 71 | # Encoder = AT_VQVAE_Encoder(text_lstm_dim*2, audio_dim, text_output_dim, audio_output_dim, n_embeddings, embedding_dim) 72 | 73 | # AVT 74 | Encoder = AVT_VQVAE_Encoder(audio_dim, video_dim, text_lstm_dim*2, audio_output_dim, video_output_dim, text_output_dim, n_embeddings, embedding_dim) 75 | 76 | AT_10_5_Linear = nn.Linear(10, 5) 77 | 78 | Text_ar_lstm.double() 79 | Encoder.double() 80 | 81 | Text_ar_lstm.cuda() 82 | Encoder.cuda() 83 | AT_10_5_Linear.cuda() 84 | 85 | if model_resume is True: 86 | path_checkpoints = "..." 
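        # The "..." above is a placeholder: point path_checkpoints at your pretrained upstream
        # checkpoint. It is expected to be a dict with the keys 'Text_ar_lstm_parameters',
        # 'Encoder_parameters' and 'epoch', as the load_state_dict calls below assume.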
87 | checkpoints = torch.load(path_checkpoints) 88 | Text_ar_lstm.load_state_dict(checkpoints['Text_ar_lstm_parameters']) 89 | Encoder.load_state_dict(checkpoints['Encoder_parameters']) 90 | start_epoch = checkpoints['epoch'] 91 | print("Resume from number {}-th model.".format(start_epoch)) 92 | 93 | # Fix seed 94 | FixSeed = 123 95 | random.seed(FixSeed) 96 | np.random.seed(FixSeed) 97 | torch.manual_seed(FixSeed) 98 | torch.cuda.manual_seed(FixSeed) 99 | 100 | 101 | 102 | # Log directory 103 | if not os.path.exists(args.log_dir): 104 | os.makedirs(args.log_dir) 105 | # Logs 106 | prefix = args.session_name 107 | log_dir = os.path.join(args.log_dir, '{}'.format(time.strftime(prefix + '_%Y%m%d-%H%M%S'))) 108 | args.log_dir = log_dir 109 | 110 | # Save scripts 111 | script_path = os.path.join(log_dir, 'scripts') 112 | if not os.path.exists(script_path): 113 | os.makedirs(script_path, exist_ok=True) 114 | 115 | scripts_to_save = ['train.sh', 'train.py', 'test.sh', 'test.py', 'config.py', 'dataloader.py', './model/ResNet_AVSModel.py', './model/PVT_AVSModel.py', 'loss.py'] 116 | for script in scripts_to_save: 117 | dst_path = os.path.join(script_path, script) 118 | try: 119 | shutil.copy(script, dst_path) 120 | except IOError: 121 | os.makedirs(os.path.dirname(dst_path), exist_ok=True) 122 | shutil.copy(script, dst_path) 123 | 124 | # Set logger 125 | log_path = os.path.join(log_dir, 'log') 126 | if not os.path.exists(log_path): 127 | os.makedirs(log_path, exist_ok=True) 128 | 129 | setup_logging(filename=os.path.join(log_path, 'log.txt')) 130 | logger = logging.getLogger(__name__) 131 | logger.info('==> Config: {}'.format(cfg)) 132 | logger.info('==> Arguments: {}'.format(args)) 133 | logger.info('==> Experiment: {}'.format(args.session_name)) 134 | 135 | # Model 136 | model = AVSModel.Pred_endecoder(channel=256, \ 137 | config=cfg, \ 138 | tpavi_stages=args.tpavi_stages, \ 139 | tpavi_vv_flag=args.tpavi_vv_flag, \ 140 | tpavi_va_flag=args.tpavi_va_flag) 141 | model.load_state_dict(torch.load(args.weights)) 142 | model = model.cuda() 143 | # model = torch.nn.DataParallel(model).cuda() 144 | logger.info('=> Load trained model %s'%args.weights) 145 | 146 | # audio backbone 147 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 148 | 149 | # Test data 150 | split = 'test' 151 | test_dataset = S4Dataset(split) 152 | test_dataloader = torch.utils.data.DataLoader(test_dataset, 153 | batch_size=args.test_batch_size, 154 | shuffle=False, 155 | num_workers=args.num_workers, 156 | pin_memory=True, 157 | collate_fn=test_dataset.collate_func) 158 | 159 | avg_meter_miou = pyutils.AverageMeter('miou') 160 | avg_meter_F = pyutils.AverageMeter('F_score') 161 | 162 | # Test 163 | model.eval() 164 | with torch.no_grad(): 165 | for n_iter, batch_data in enumerate(test_dataloader): 166 | query, imgs, audio_feature, mask = batch_data['query'],batch_data['imgs_tensor'],batch_data['audio_fea'],batch_data['masks_tensor'] 167 | category_list, video_name_list = batch_data['category'],batch_data['video_name'] 168 | query = query.double().cuda() 169 | imgs = imgs.cuda() 170 | audio_feature = audio_feature.cuda() 171 | mask = mask.cuda() 172 | B, frame, C, H, W = imgs.shape 173 | imgs = imgs.view(B*frame, C, H, W) 174 | mask = mask.view(B*frame, H, W) 175 | 176 | audio_feature = audio_feature.to(torch.float64) 177 | audio_feature = audio_feature.repeat(1, 2, 1)# [B, 5, audio_dim] -> [B, 10, audio_dim] 178 | 179 | audio_feature = audio_feature.transpose(2, 1).contiguous() # [batch, audio_dim, 
length:10] 180 | audio_feature = AT_10_5_Linear(audio_feature.to(torch.float32)) 181 | audio_feature = audio_feature.transpose(2, 1).contiguous().to(torch.float64) # [batch, length:3, audio_dim] 182 | 183 | audio_vq = Encoder.Audio_VQ_Encoder(audio_feature)# [B, T, 256] 184 | 185 | audio_vq = audio_vq.reshape(-1, audio_vq.shape[-1]) 186 | output, visual_map_list, a_fea_list = model(imgs, audio_vq.to(torch.float32)) # [bs*5, 1, 224, 224] 187 | 188 | 189 | if args.save_pred_mask: 190 | mask_save_path = os.path.join(log_dir, 'pred_masks') 191 | save_mask(output.squeeze(1), mask_save_path, category_list, video_name_list) 192 | 193 | miou = mask_iou(output.squeeze(1), mask) 194 | avg_meter_miou.add({'miou': miou}) 195 | F_score = Eval_Fmeasure(output.squeeze(1), mask, log_dir) 196 | avg_meter_F.add({'F_score': F_score}) 197 | print('n_iter: {}, iou: {}, F_score: {}'.format(n_iter, miou, F_score)) 198 | 199 | 200 | miou = (avg_meter_miou.pop('miou')) 201 | F_score = (avg_meter_F.pop('F_score')) 202 | print('test miou:', miou.item()) 203 | print('test F_score:', F_score) 204 | logger.info('test miou: {}, F_score: {}'.format(miou.item(), F_score)) -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/mel_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/mel_features.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/vggish.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/vggish.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/vggish_input.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/vggish_input.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/vggish_params.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/__pycache__/vggish_params.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/mel_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines routines to compute mel spectrogram features from audio waveform.""" 17 | 18 | import numpy as np 19 | 20 | 21 | def frame(data, window_length, hop_length): 22 | """Convert array into a sequence of successive possibly overlapping frames. 23 | 24 | An n-dimensional array of shape (num_samples, ...) is converted into an 25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame 26 | starts hop_length points after the preceding one. 27 | 28 | This is accomplished using stride_tricks, so the original data is not 29 | copied. However, there is no zero-padding, so any incomplete frames at the 30 | end are not included. 31 | 32 | Args: 33 | data: np.array of dimension N >= 1. 34 | window_length: Number of samples in each frame. 35 | hop_length: Advance (in samples) between each window. 36 | 37 | Returns: 38 | (N+1)-D np.array with as many rows as there are complete frames that can be 39 | extracted. 40 | """ 41 | num_samples = data.shape[0] 42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) 43 | shape = (num_frames, window_length) + data.shape[1:] 44 | strides = (data.strides[0] * hop_length,) + data.strides 45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 46 | 47 | 48 | def periodic_hann(window_length): 49 | """Calculate a "periodic" Hann window. 50 | 51 | The classic Hann window is defined as a raised cosine that starts and 52 | ends on zero, and where every value appears twice, except the middle 53 | point for an odd-length window. Matlab calls this a "symmetric" window 54 | and np.hanning() returns it. However, for Fourier analysis, this 55 | actually represents just over one cycle of a period N-1 cosine, and 56 | thus is not compactly expressed on a length-N Fourier basis. Instead, 57 | it's better to use a raised cosine that ends just before the final 58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab 59 | calls this a "periodic" window. This routine calculates it. 60 | 61 | Args: 62 | window_length: The number of points in the returned window. 63 | 64 | Returns: 65 | A 1D np.array containing the periodic hann window. 66 | """ 67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * 68 | np.arange(window_length))) 69 | 70 | 71 | def stft_magnitude(signal, fft_length, 72 | hop_length=None, 73 | window_length=None): 74 | """Calculate the short-time Fourier transform magnitude. 75 | 76 | Args: 77 | signal: 1D np.array of the input time-domain signal. 78 | fft_length: Size of the FFT to apply. 79 | hop_length: Advance (in samples) between each frame passed to FFT. 80 | window_length: Length of each block of samples to pass to FFT. 81 | 82 | Returns: 83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1 84 | unique values of the FFT for the corresponding frame of input samples. 85 | """ 86 | frames = frame(signal, window_length, hop_length) 87 | # Apply frame window to each frame. 
We use a periodic Hann (cosine of period 88 | # window_length) instead of the symmetric Hann of np.hanning (period 89 | # window_length-1). 90 | window = periodic_hann(window_length) 91 | windowed_frames = frames * window 92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) 93 | 94 | 95 | # Mel spectrum constants and functions. 96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 97 | _MEL_HIGH_FREQUENCY_Q = 1127.0 98 | 99 | 100 | def hertz_to_mel(frequencies_hertz): 101 | """Convert frequencies to mel scale using HTK formula. 102 | 103 | Args: 104 | frequencies_hertz: Scalar or np.array of frequencies in hertz. 105 | 106 | Returns: 107 | Object of same size as frequencies_hertz containing corresponding values 108 | on the mel scale. 109 | """ 110 | return _MEL_HIGH_FREQUENCY_Q * np.log( 111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) 112 | 113 | 114 | def spectrogram_to_mel_matrix(num_mel_bins=20, 115 | num_spectrogram_bins=129, 116 | audio_sample_rate=8000, 117 | lower_edge_hertz=125.0, 118 | upper_edge_hertz=3800.0): 119 | """Return a matrix that can post-multiply spectrogram rows to make mel. 120 | 121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of 122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a 123 | "mel spectrogram" M of frames x num_mel_bins. M = S A. 124 | 125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands 126 | to multiply each FFT bin by only one mel weight, then add it, with positive 127 | and negative signs, to the two adjacent mel bands to which that bin 128 | contributes. Here, by expressing this operation as a matrix multiply, we go 129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around 130 | num_fft^2 multiplies and adds. However, because these are all presumably 131 | accomplished in a single call to np.dot(), it's not clear which approach is 132 | faster in Python. The matrix multiplication has the attraction of being more 133 | general and flexible, and much easier to read. 134 | 135 | Args: 136 | num_mel_bins: How many bands in the resulting mel spectrum. This is 137 | the number of columns in the output matrix. 138 | num_spectrogram_bins: How many bins there are in the source spectrogram 139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram 140 | only contains the nonredundant FFT bins. 141 | audio_sample_rate: Samples per second of the audio at the input to the 142 | spectrogram. We need this to figure out the actual frequencies for 143 | each spectrogram bin, which dictates how they are mapped into mel. 144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel 145 | spectrum. This corresponds to the lower edge of the lowest triangular 146 | band. 147 | upper_edge_hertz: The desired top edge of the highest frequency band. 148 | 149 | Returns: 150 | An np.array with shape (num_spectrogram_bins, num_mel_bins). 151 | 152 | Raises: 153 | ValueError: if frequency edges are incorrectly ordered or out of range. 154 | """ 155 | nyquist_hertz = audio_sample_rate / 2. 
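  # With the VGGish configuration in vggish_params.py (SAMPLE_RATE=16000, NUM_MEL_BINS=64,
  # MEL_MIN_HZ=125, MEL_MAX_HZ=7500), nyquist_hertz is 8000.0 and the range checks below pass.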
156 | if lower_edge_hertz < 0.0: 157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) 158 | if lower_edge_hertz >= upper_edge_hertz: 159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % 160 | (lower_edge_hertz, upper_edge_hertz)) 161 | if upper_edge_hertz > nyquist_hertz: 162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % 163 | (upper_edge_hertz, nyquist_hertz)) 164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) 165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) 166 | # The i'th mel band (starting from i=1) has center frequency 167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge 168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in 169 | # the band_edges_mel arrays. 170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), 171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) 172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins 173 | # of spectrogram values. 174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) 175 | for i in range(num_mel_bins): 176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] 177 | # Calculate lower and upper slopes for every spectrogram bin. 178 | # Line segments are linear in the *mel* domain, not hertz. 179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / 180 | (center_mel - lower_edge_mel)) 181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / 182 | (upper_edge_mel - center_mel)) 183 | # .. then intersect them with each other and zero. 184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, 185 | upper_slope)) 186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero 187 | # coefficient. 188 | mel_weights_matrix[0, :] = 0.0 189 | return mel_weights_matrix 190 | 191 | 192 | def log_mel_spectrogram(data, 193 | audio_sample_rate=8000, 194 | log_offset=0.0, 195 | window_length_secs=0.025, 196 | hop_length_secs=0.010, 197 | **kwargs): 198 | """Convert waveform to a log magnitude mel-frequency spectrogram. 199 | 200 | Args: 201 | data: 1D np.array of waveform data. 202 | audio_sample_rate: The sampling rate of data. 203 | log_offset: Add this to values when taking log to avoid -Infs. 204 | window_length_secs: Duration of each window to analyze. 205 | hop_length_secs: Advance between successive analysis windows. 206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. 207 | 208 | Returns: 209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank 210 | magnitudes for successive frames. 
211 | """ 212 | window_length_samples = int(round(audio_sample_rate * window_length_secs)) 213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) 214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) 215 | spectrogram = stft_magnitude( 216 | data, 217 | fft_length=fft_length, 218 | hop_length=hop_length_samples, 219 | window_length=window_length_samples) 220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( 221 | num_spectrogram_bins=spectrogram.shape[1], 222 | audio_sample_rate=audio_sample_rate, **kwargs)) 223 | return np.log(mel_spectrogram + log_offset) 224 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/vggish.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch import hub 5 | 6 | from . import vggish_input, vggish_params 7 | 8 | 9 | class VGG(nn.Module): 10 | def __init__(self, features): 11 | super(VGG, self).__init__() 12 | self.features = features 13 | self.embeddings = nn.Sequential( 14 | nn.Linear(512 * 4 * 6, 4096), 15 | nn.ReLU(True), 16 | nn.Linear(4096, 4096), 17 | nn.ReLU(True), 18 | nn.Linear(4096, 128), 19 | nn.ReLU(True)) 20 | 21 | def forward(self, x): 22 | x = self.features(x) 23 | 24 | # Transpose the output from features to 25 | # remain compatible with vggish embeddings 26 | x = torch.transpose(x, 1, 3) 27 | x = torch.transpose(x, 1, 2) 28 | x = x.contiguous() 29 | x = x.view(x.size(0), -1) 30 | 31 | return self.embeddings(x) 32 | 33 | 34 | class Postprocessor(nn.Module): 35 | """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a 36 | numpy array in order to preserve the gradient. 37 | 38 | "The initial release of AudioSet included 128-D VGGish embeddings for each 39 | segment of AudioSet. These released embeddings were produced by applying 40 | a PCA transformation (technically, a whitening transform is included as well) 41 | and 8-bit quantization to the raw embedding output from VGGish, in order to 42 | stay compatible with the YouTube-8M project which provides visual embeddings 43 | in the same format for a large set of YouTube videos. This class implements 44 | the same PCA (with whitening) and quantization transformations." 45 | """ 46 | 47 | def __init__(self): 48 | """Constructs a postprocessor.""" 49 | super(Postprocessor, self).__init__() 50 | # Create empty matrix, for user's state_dict to load 51 | self.pca_eigen_vectors = torch.empty( 52 | (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), 53 | dtype=torch.float, 54 | ) 55 | self.pca_means = torch.empty( 56 | (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float 57 | ) 58 | 59 | self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) 60 | self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) 61 | 62 | def postprocess(self, embeddings_batch): 63 | """Applies tensor postprocessing to a batch of embeddings. 64 | 65 | Args: 66 | embeddings_batch: An tensor of shape [batch_size, embedding_size] 67 | containing output from the embedding layer of VGGish. 68 | 69 | Returns: 70 | A tensor of the same shape as the input, containing the PCA-transformed, 71 | quantized, and clipped version of the input. 
72 | """ 73 | assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( 74 | embeddings_batch.shape, 75 | ) 76 | assert ( 77 | embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE 78 | ), "Bad batch shape: %r" % (embeddings_batch.shape,) 79 | 80 | # Apply PCA. 81 | # - Embeddings come in as [batch_size, embedding_size]. 82 | # - Transpose to [embedding_size, batch_size]. 83 | # - Subtract pca_means column vector from each column. 84 | # - Premultiply by PCA matrix of shape [output_dims, input_dims] 85 | # where both are are equal to embedding_size in our case. 86 | # - Transpose result back to [batch_size, embedding_size]. 87 | pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() 88 | 89 | # Quantize by: 90 | # - clipping to [min, max] range 91 | clipped_embeddings = torch.clamp( 92 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL 93 | ) 94 | # - convert to 8-bit in range [0.0, 255.0] 95 | quantized_embeddings = torch.round( 96 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) 97 | * ( 98 | 255.0 99 | / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) 100 | ) 101 | ) 102 | return torch.squeeze(quantized_embeddings) 103 | 104 | def forward(self, x): 105 | return self.postprocess(x) 106 | 107 | 108 | def make_layers(): 109 | layers = [] 110 | in_channels = 1 111 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: 112 | if v == "M": 113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 114 | else: 115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 116 | layers += [conv2d, nn.ReLU(inplace=True)] 117 | in_channels = v 118 | return nn.Sequential(*layers) 119 | 120 | 121 | def _vgg(): 122 | return VGG(make_layers()) 123 | 124 | 125 | # def _spectrogram(): 126 | # config = dict( 127 | # sr=16000, 128 | # n_fft=400, 129 | # n_mels=64, 130 | # hop_length=160, 131 | # window="hann", 132 | # center=False, 133 | # pad_mode="reflect", 134 | # htk=True, 135 | # fmin=125, 136 | # fmax=7500, 137 | # output_format='Magnitude', 138 | # # device=device, 139 | # ) 140 | # return Spectrogram.MelSpectrogram(**config) 141 | 142 | 143 | class VGGish(VGG): 144 | def __init__(self, cfg, device=None): 145 | super().__init__(make_layers()) 146 | if cfg.TRAIN.FREEZE_AUDIO_EXTRACTOR: 147 | state_dict = torch.load(cfg.TRAIN.PRETRAINED_VGGISH_MODEL_PATH) 148 | super().load_state_dict(state_dict) 149 | print(f'==> Load pretrained VGGish parameters from {cfg.TRAIN.PRETRAINED_VGGISH_MODEL_PATH}') 150 | 151 | if device is None: 152 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 153 | print("device: ", device) 154 | self.device = device 155 | 156 | self.preprocess = cfg.TRAIN.PREPROCESS_AUDIO_TO_LOG_MEL 157 | self.postprocess = cfg.TRAIN.POSTPROCESS_LOG_MEL_WITH_PCA 158 | if self.postprocess: 159 | self.pproc = Postprocessor() 160 | if cfg.TRAIN.FREEZE_AUDIO_EXTRACTOR : 161 | state_dict = torch.load(cfg.TRAIN.PRETRAINED_PCA_PARAMS_PATH) 162 | # TODO: Convert the state_dict to torch 163 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( 164 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float 165 | ) 166 | state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( 167 | state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float 168 | ) 169 | self.pproc.load_state_dict(state_dict) 170 | self.to(self.device) 171 | 172 | def forward(self, x): 173 | if self.preprocess: 174 | print(">>> pre processing...") 175 | x = 
self._preprocess(x) 176 | x = x.to(self.device) 177 | x = VGG.forward(self, x) 178 | if self.postprocess: 179 | print(">>> post processing...") 180 | x = self._postprocess(x) 181 | return x 182 | 183 | def _preprocess(self, x): 184 | # if isinstance(x, np.ndarray): 185 | # x = vggish_input.waveform_to_examples(x, fs) 186 | if isinstance(x, str): 187 | x = vggish_input.wavfile_to_examples(x) 188 | else: 189 | raise AttributeError 190 | return x 191 | 192 | def _postprocess(self, x): 193 | return self.pproc(x) 194 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/vggish_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute input examples for VGGish from audio waveform.""" 17 | 18 | # Modification: Return torch tensors rather than numpy arrays 19 | import torch 20 | 21 | import numpy as np 22 | import resampy 23 | 24 | from . import mel_features 25 | from . import vggish_params 26 | 27 | import soundfile as sf 28 | 29 | 30 | def waveform_to_examples(data, sample_rate, return_tensor=True): 31 | """Converts audio waveform into an array of examples for VGGish. 32 | 33 | Args: 34 | data: np.array of either one dimension (mono) or two dimensions 35 | (multi-channel, with the outer dimension representing channels). 36 | Each sample is generally expected to lie in the range [-1.0, +1.0], 37 | although this is not required. 38 | sample_rate: Sample rate of data. 39 | return_tensor: Return data as a Pytorch tensor ready for VGGish 40 | 41 | Returns: 42 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents 43 | a sequence of examples, each of which contains a patch of log mel 44 | spectrogram, covering num_frames frames of audio and num_bands mel frequency 45 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. 46 | 47 | """ 48 | # Convert to mono. 49 | if len(data.shape) > 1: 50 | data = np.mean(data, axis=1) 51 | # Resample to the rate assumed by VGGish. 52 | if sample_rate != vggish_params.SAMPLE_RATE: 53 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) 54 | 55 | # Compute log mel spectrogram features. 56 | log_mel = mel_features.log_mel_spectrogram( 57 | data, 58 | audio_sample_rate=vggish_params.SAMPLE_RATE, 59 | log_offset=vggish_params.LOG_OFFSET, 60 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, 61 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, 62 | num_mel_bins=vggish_params.NUM_MEL_BINS, 63 | lower_edge_hertz=vggish_params.MEL_MIN_HZ, 64 | upper_edge_hertz=vggish_params.MEL_MAX_HZ) 65 | 66 | # Frame features into examples. 
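  # With the defaults in vggish_params.py, STFT_HOP_LENGTH_SECONDS = 0.010 gives
  # features_sample_rate = 100 frames per second, and EXAMPLE_WINDOW_SECONDS =
  # EXAMPLE_HOP_SECONDS = 0.96 gives int(round(0.96 * 100)) = 96, i.e. non-overlapping
  # examples of 96 frames (~0.96 s), matching NUM_FRAMES = 96 in vggish_params.py.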
67 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS 68 | example_window_length = int(round( 69 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) 70 | example_hop_length = int(round( 71 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) 72 | log_mel_examples = mel_features.frame( 73 | log_mel, 74 | window_length=example_window_length, 75 | hop_length=example_hop_length) 76 | 77 | if return_tensor: 78 | log_mel_examples = torch.tensor( 79 | log_mel_examples, requires_grad=True)[:, None, :, :].float() 80 | 81 | return log_mel_examples 82 | 83 | 84 | def wavfile_to_examples(wav_file, return_tensor=True): 85 | """Convenience wrapper around waveform_to_examples() for a common WAV format. 86 | 87 | Args: 88 | wav_file: String path to a file, or a file-like object. The file 89 | is assumed to contain WAV audio data with signed 16-bit PCM samples. 90 | torch: Return data as a Pytorch tensor ready for VGGish 91 | 92 | Returns: 93 | See waveform_to_examples. 94 | """ 95 | wav_data, sr = sf.read(wav_file, dtype='int16') 96 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype 97 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] 98 | return waveform_to_examples(samples, sr, return_tensor) 99 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/torchvggish/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | # Architectural constants. 22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. 23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 24 | EMBEDDING_SIZE = 128 # Size of embedding layer. 25 | 26 | # Hyperparameters used in feature and example generation. 27 | SAMPLE_RATE = 16000 28 | STFT_WINDOW_LENGTH_SECONDS = 0.025 29 | STFT_HOP_LENGTH_SECONDS = 0.010 30 | NUM_MEL_BINS = NUM_BANDS 31 | MEL_MIN_HZ = 125 32 | MEL_MAX_HZ = 7500 33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames 35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 36 | 37 | # Parameters used for embedding postprocessing. 38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 39 | PCA_MEANS_NAME = 'pca_means' 40 | QUANTIZE_MIN_VAL = -2.0 41 | QUANTIZE_MAX_VAL = +2.0 42 | 43 | # Hyperparameters used in training. 44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 47 | 48 | # Names of ops, tensors, and features. 
49 | INPUT_OP_NAME = 'vggish/input_features' 50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 51 | OUTPUT_OP_NAME = 'vggish/embedding' 52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 54 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/train.sh: -------------------------------------------------------------------------------- 1 | setting='S4' 2 | visual_backbone="pvt" # "resnet" or "pvt" 3 | # at:audio for train, text for test 4 | # ta:text for train, audio for test 5 | python train_ta.py \ 6 | --session_name ${setting}_${visual_backbone} \ 7 | --visual_backbone ${visual_backbone} \ 8 | --train_batch_size 4 \ 9 | --lr 0.0001 \ 10 | --tpavi_stages 0 1 2 3 \ 11 | --tpavi_va_flag 12 | 13 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/__pycache__/pyutils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/__pycache__/pyutils.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/__pycache__/utility.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/__pycache__/utility.cpython-38.pyc -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "w") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | self.__data[k][0] += v 29 | self.__data[k][1] += 1 30 | 31 | def get(self, *keys): 32 | if len(keys) == 1: 33 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 34 | else: 35 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 36 | return tuple(v_list) 37 | 38 | def pop(self, key=None): 39 | if key is None: 40 | for k in self.__data.keys(): 41 | self.__data[k] = [0.0, 0] 42 | else: 43 | v = self.get(key) 44 | self.__data[key] = [0.0, 0] 45 | return v 46 | 47 | 48 | class Timer: 49 | def __init__(self, starting_msg = None): 50 | self.start = time.time() 51 | self.stage_start = self.start 52 | 53 | if starting_msg is not None: 54 
| print(starting_msg, time.ctime(time.time())) 55 | 56 | 57 | def update_progress(self, progress): 58 | self.elapsed = time.time() - self.start 59 | self.est_total = self.elapsed / progress 60 | self.est_remaining = self.est_total - self.elapsed 61 | self.est_finish = int(self.start + self.est_total) 62 | 63 | 64 | def str_est_finish(self): 65 | return str(time.ctime(self.est_finish)) 66 | 67 | def get_stage_elapsed(self): 68 | return time.time() - self.stage_start 69 | 70 | def reset_stage(self): 71 | self.stage_start = time.time() 72 | 73 | 74 | from multiprocessing.pool import ThreadPool 75 | 76 | class BatchThreader: 77 | 78 | def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12): 79 | self.batch_size = batch_size 80 | self.prefetch_size = prefetch_size 81 | 82 | self.pool = ThreadPool(processes=processes) 83 | self.async_result = [] 84 | 85 | self.func = func 86 | self.left_args_list = args_list 87 | self.n_tasks = len(args_list) 88 | 89 | # initial work 90 | self.__start_works(self.__get_n_pending_works()) 91 | 92 | 93 | def __start_works(self, times): 94 | for _ in range(times): 95 | args = self.left_args_list.pop(0) 96 | self.async_result.append( 97 | self.pool.apply_async(self.func, args)) 98 | 99 | 100 | def __get_n_pending_works(self): 101 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result) 102 | , len(self.left_args_list)) 103 | 104 | 105 | 106 | def pop_results(self): 107 | 108 | n_inwork = len(self.async_result) 109 | 110 | n_fetch = min(n_inwork, self.batch_size) 111 | rtn = [self.async_result.pop(0).get() 112 | for _ in range(n_fetch)] 113 | 114 | to_fill = self.__get_n_pending_works() 115 | if to_fill == 0: 116 | self.pool.close() 117 | else: 118 | self.__start_works(to_fill) 119 | 120 | return rtn 121 | 122 | 123 | 124 | 125 | def get_indices_of_pairs(radius, size): 126 | 127 | search_dist = [] 128 | 129 | for x in range(1, radius): 130 | search_dist.append((0, x)) 131 | 132 | for y in range(1, radius): 133 | for x in range(-radius + 1, radius): 134 | if x * x + y * y < radius * radius: 135 | search_dist.append((y, x)) 136 | 137 | radius_floor = radius - 1 138 | 139 | full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64), 140 | (size[0], size[1])) 141 | 142 | cropped_height = size[0] - radius_floor 143 | cropped_width = size[1] - 2 * radius_floor 144 | 145 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], 146 | [-1]) 147 | 148 | indices_to_list = [] 149 | 150 | for dy, dx in search_dist: 151 | indices_to = full_indices[dy:dy + cropped_height, 152 | radius_floor + dx:radius_floor + dx + cropped_width] 153 | indices_to = np.reshape(indices_to, [-1]) 154 | 155 | indices_to_list.append(indices_to) 156 | 157 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 158 | 159 | return indices_from, concat_indices_to 160 | 161 | -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/system.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import shutil 4 | import numpy as np 5 | from PIL import Image 6 | import logging 7 | import random 8 | 9 | import torch 10 | 11 | 12 | def setup_logging(filename, resume=False): 13 | root_logger = logging.getLogger() 14 | 15 | ch = logging.StreamHandler() 16 | fh = logging.FileHandler(filename=filename, mode='a' if resume else 'w') 17 | 18 | root_logger.setLevel(logging.INFO) 19 | 
ch.setLevel(logging.INFO) 20 | fh.setLevel(logging.INFO) 21 | 22 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s") 23 | ch.setFormatter(formatter) 24 | fh.setFormatter(formatter) 25 | 26 | root_logger.addHandler(ch) 27 | root_logger.addHandler(fh) 28 | 29 | 30 | def setup_seed(seed): 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | np.random.seed(seed) 35 | random.seed(seed) 36 | torch.backends.cudnn.benchmark = False 37 | torch.backends.cudnn.deterministic = True 38 | os.environ['PYTHONHASHSEED'] = str(seed) 39 | 40 | 41 | class AverageMeter(object): 42 | 43 | def __init__(self, window=-1): 44 | self.window = window 45 | self.reset() 46 | 47 | def reset(self): 48 | self.val = 0 49 | self.avg = 0 50 | self.sum = 0 51 | self.cnt = 0 52 | self.max = -np.Inf 53 | 54 | if self.window > 0: 55 | self.val_arr = np.zeros(self.window) 56 | self.arr_idx = 0 57 | 58 | def update(self, val, n=1): 59 | self.val = val 60 | self.cnt += n 61 | self.max = max(self.max, val) 62 | 63 | if self.window > 0: 64 | self.val_arr[self.arr_idx] = val 65 | self.arr_idx = (self.arr_idx + 1) % self.window 66 | self.avg = self.val_arr.mean() 67 | else: 68 | self.sum += val * n 69 | self.avg = self.sum / self.cnt 70 | 71 | 72 | class FrameSecondMeter(object): 73 | 74 | def __init__(self): 75 | self.st = time.time() 76 | self.fps = None 77 | self.ed = None 78 | self.frame_n = 0 79 | 80 | def add_frame_n(self, frame_n): 81 | self.frame_n += frame_n 82 | 83 | def end(self): 84 | self.ed = time.time() 85 | self.fps = self.frame_n / (self.ed - self.st) 86 | 87 | 88 | def gct(f='l'): 89 | ''' 90 | get current time 91 | :param f: 'l' for log, 'f' for file name 92 | :return: formatted time 93 | ''' 94 | if f == 'l': 95 | return time.strftime('%m/%d %H:%M:%S', time.localtime(time.time())) 96 | elif f == 'f': 97 | return time.strftime('%m_%d_%H_%M', time.localtime(time.time())) 98 | 99 | 100 | def save_scripts(path, scripts_to_save=None): 101 | if not os.path.exists(os.path.join(path, 'scripts')): 102 | os.makedirs(os.path.join(path, 'scripts')) 103 | 104 | if scripts_to_save is not None: 105 | for script in scripts_to_save: 106 | dst_path = os.path.join(path, 'scripts', script) 107 | try: 108 | shutil.copy(script, dst_path) 109 | except IOError: 110 | os.makedirs(os.path.dirname(dst_path)) 111 | shutil.copy(script, dst_path) 112 | 113 | 114 | def count_model_size(model): 115 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters()) / 1e6 116 | 117 | 118 | def load_image_in_PIL(path, mode='RGB'): 119 | img = Image.open(path) 120 | img.load() # Very important for loading large image 121 | return img.convert(mode) 122 | 123 | 124 | def print_mem(info=None): 125 | if info: 126 | print(info, end=' ') 127 | mem_allocated = round(torch.cuda.memory_allocated() / 1048576) 128 | mem_cached = round(torch.cuda.memory_cached() / 1048576) 129 | print(f'Mem allocated: {mem_allocated}MB, Mem cached: {mem_cached}MB') 130 | 131 | 132 | def set_bn_eval(m): 133 | classname = m.__class__.__name__ 134 | if classname.find('BatchNorm') != -1: 135 | m.eval() 136 | 137 | 138 | def match_name_keywords(n, name_keywords): 139 | out = False 140 | for b in name_keywords: 141 | if b in n: 142 | out = True 143 | break 144 | return out -------------------------------------------------------------------------------- /code/AVSBench_dowmstream/avs_scripts/avs_s4/utils/utility.py: -------------------------------------------------------------------------------- 1 
| import torch 2 | from torch.nn import functional as F 3 | 4 | import os 5 | import shutil 6 | import logging 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import sys 12 | import time 13 | import pandas as pd 14 | import pdb 15 | from torchvision import transforms 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | def save_checkpoint(state, epoch, is_best, checkpoint_dir='./models', filename='checkpoint', thres=100): 20 | """ 21 | - state 22 | - epoch 23 | - is_best 24 | - checkpoint_dir: default, ./models 25 | - filename: default, checkpoint 26 | - freq: default, 10 27 | - thres: default, 100 28 | """ 29 | if not os.path.isdir(checkpoint_dir): 30 | os.makedirs(checkpoint_dir) 31 | 32 | if epoch >= thres: 33 | file_path = os.path.join(checkpoint_dir, filename + '_{}'.format(str(epoch)) + '.pth.tar') 34 | else: 35 | file_path = os.path.join(checkpoint_dir, filename + '.pth.tar') 36 | torch.save(state, file_path) 37 | logger.info('==> save model at {}'.format(file_path)) 38 | 39 | if is_best: 40 | cpy_file = os.path.join(checkpoint_dir, filename+'_model_best.pth.tar') 41 | shutil.copyfile(file_path, cpy_file) 42 | logger.info('==> save best model at {}'.format(cpy_file)) 43 | 44 | 45 | def mask_iou(pred, target, eps=1e-7, size_average=True): 46 | r""" 47 | param: 48 | pred: size [N x H x W] 49 | target: size [N x H x W] 50 | output: 51 | iou: size [1] (size_average=True) or [N] (size_average=False) 52 | """ 53 | assert len(pred.shape) == 3 and pred.shape == target.shape 54 | 55 | N = pred.size(0) 56 | num_pixels = pred.size(-1) * pred.size(-2) 57 | no_obj_flag = (target.sum(2).sum(1) == 0) 58 | 59 | temp_pred = torch.sigmoid(pred) 60 | pred = (temp_pred > 0.5).int() 61 | inter = (pred * target).sum(2).sum(1) 62 | union = torch.max(pred, target).sum(2).sum(1) 63 | 64 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) 65 | inter[no_obj_flag] = inter_no_obj[no_obj_flag] 66 | union[no_obj_flag] = num_pixels 67 | 68 | iou = torch.sum(inter / (union+eps)) / N 69 | 70 | return iou 71 | 72 | 73 | 74 | def _eval_pr(y_pred, y, num, cuda_flag=True): 75 | if cuda_flag: 76 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 77 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 78 | else: 79 | prec, recall = torch.zeros(num), torch.zeros(num) 80 | thlist = torch.linspace(0, 1 - 1e-10, num) 81 | for i in range(num): 82 | y_temp = (y_pred >= thlist[i]).float() 83 | tp = (y_temp * y).sum() 84 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 85 | 86 | return prec, recall 87 | 88 | def Eval_Fmeasure(pred, gt, measure_path, pr_num=255): 89 | r""" 90 | param: 91 | pred: size [N x H x W] 92 | gt: size [N x H x W] 93 | output: 94 | iou: size [1] (size_average=True) or [N] (size_average=False) 95 | """ 96 | print('=> eval [FMeasure]..') 97 | pred = torch.sigmoid(pred) # =======================================[important] 98 | N = pred.size(0) 99 | beta2 = 0.3 100 | avg_f, img_num = 0.0, 0 101 | score = torch.zeros(pr_num) 102 | fLog = open(os.path.join(measure_path, 'FMeasure.txt'), 'w') 103 | print("{} videos in this batch".format(N)) 104 | 105 | for img_id in range(N): 106 | # examples with totally black GTs are out of consideration 107 | if torch.mean(gt[img_id]) == 0.0: 108 | continue 109 | prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num) 110 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) 111 | f_score[f_score != f_score] = 0 # for Nan 112 | avg_f += f_score 113 | img_num += 1 114 | score = avg_f / 
img_num 115 | # print('score: ', score) 116 | fLog.close() 117 | 118 | return score.max().item() 119 | 120 | 121 | 122 | def save_mask(pred_masks, save_base_path, category_list, video_name_list): 123 | # pred_mask: [bs*5, 1, 224, 224] 124 | # print(f"=> {len(video_name_list)} videos in this batch") 125 | 126 | if not os.path.exists(save_base_path): 127 | os.makedirs(save_base_path, exist_ok=True) 128 | 129 | pred_masks = pred_masks.squeeze(2) 130 | pred_masks = (torch.sigmoid(pred_masks) > 0.5).int() 131 | 132 | pred_masks = pred_masks.view(-1, 5, pred_masks.shape[-2], pred_masks.shape[-1]) 133 | pred_masks = pred_masks.cpu().data.numpy().astype(np.uint8) 134 | pred_masks *= 255 135 | bs = pred_masks.shape[0] 136 | 137 | for idx in range(bs): 138 | category, video_name = category_list[idx], video_name_list[idx] 139 | mask_save_path = os.path.join(save_base_path, category, video_name) 140 | if not os.path.exists(mask_save_path): 141 | os.makedirs(mask_save_path, exist_ok=True) 142 | one_video_masks = pred_masks[idx] # [5, 1, 224, 224] 143 | for video_id in range(len(one_video_masks)): 144 | one_mask = one_video_masks[video_id] 145 | output_name = "%s_%d.png"%(video_name, video_id) 146 | im = Image.fromarray(one_mask).convert('P') 147 | im.save(os.path.join(mask_save_path, output_name), format='PNG') 148 | 149 | 150 | def save_raw_img_mask(anno_file_path, raw_img_base_path, mask_base_path, split='test', r=0.5): 151 | df = pd.read_csv(anno_file_path, sep=',') 152 | df_test = df[df['split'] == split] 153 | count = 0 154 | for video_id in range(len(df_test)): 155 | video_name, category = df_test.iloc[video_id][0], df_test.iloc[video_id][2] 156 | raw_img_path = os.path.join(raw_img_base_path, split, category, video_name) 157 | for img_id in range(5): 158 | img_name = "%s_%d.png"%(video_name, img_id + 1) 159 | raw_img = cv2.imread(os.path.join(raw_img_path, img_name)) 160 | mask = cv2.imread(os.path.join(mask_base_path, 'pred_masks', category, video_name, "%s_%d.png"%(video_name, img_id))) 161 | # pdb.set_trace() 162 | raw_img_mask = cv2.addWeighted(raw_img, 1, mask, r, 0) 163 | save_img_path = os.path.join(mask_base_path, 'img_add_masks', category, video_name) 164 | if not os.path.exists(save_img_path): 165 | os.makedirs(save_img_path, exist_ok=True) 166 | cv2.imwrite(os.path.join(save_img_path, img_name), raw_img_mask) 167 | count += 1 168 | print(f'count: {count} videos') 169 | 170 | 171 | 172 | 173 | if __name__ == "__main__": 174 | pred1 = torch.ones(4, 5, 5) 175 | target1 = torch.ones(4, 5, 5) 176 | iou1 = mask_iou(pred1, target1) 177 | 178 | pred2 = torch.zeros(4, 5, 5) 179 | target2 = torch.zeros(4, 5, 5) 180 | iou2 = mask_iou(pred2, target2) 181 | 182 | pred3 = torch.zeros(4, 5, 5) 183 | target3 = torch.ones(4, 5, 5) 184 | iou3 = mask_iou(pred3, target3) 185 | 186 | pred4 = torch.ones(4, 5, 5) 187 | target4 = torch.zeros(4, 5, 5) 188 | iou4 = mask_iou(pred4, target4) 189 | 190 | pred5 = torch.ones(4, 5, 5) 191 | target5 = torch.ones(4, 5, 5) 192 | target5[:2] = torch.zeros(5, 5) 193 | iou5 = mask_iou(pred5, target5) 194 | 195 | pred6 = torch.zeros(4, 5, 5) 196 | pred6[:2] = torch.ones(5, 5) 197 | target6 = torch.ones(4, 5, 5) 198 | iou6 = mask_iou(pred6, target6) 199 | 200 | one_mask = torch.randn(224, 224) 201 | one_mask = (torch.sigmoid(one_mask) > 0.5).numpy().astype(np.uint8) 202 | one_real_mask = one_mask * 255 203 | 204 | pdb.set_trace() 205 | 206 | 207 | -------------------------------------------------------------------------------- /code/src/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/src/.DS_Store -------------------------------------------------------------------------------- /code/src/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | .vscode/ 4 | exp/ 5 | Exps/ 6 | data/ -------------------------------------------------------------------------------- /code/src/ave.sh: -------------------------------------------------------------------------------- 1 | # ave_va or ave_av 2 | # va:video for train, audio for test 3 | # av:audio for train, video for test 4 | python ave.py \ 5 | --gpu 0 \ 6 | --lr 0.0004 \ 7 | --clip_gradient 0.5 \ 8 | --snapshot_pref "./Exps/ave/" \ 9 | --n_epoch 30 \ 10 | --b 80 \ 11 | --test_batch_size 128 \ 12 | --dataset_name "ave_av" \ 13 | --print_freq 1 \ 14 | --eval_freq 1 -------------------------------------------------------------------------------- /code/src/ave_avvp.sh: -------------------------------------------------------------------------------- 1 | python ave_avvp.py \ 2 | --gpu 0 \ 3 | --lr 0.0004 \ 4 | --clip_gradient 0.5 \ 5 | --snapshot_pref "./Exps/ave_avvp/" \ 6 | --n_epoch 50 \ 7 | --b 80 \ 8 | --test_batch_size 64 \ 9 | --dataset_name "ave_avvp_av" \ 10 | --print_freq 1 \ 11 | --eval_freq 1 -------------------------------------------------------------------------------- /code/src/avvp.sh: -------------------------------------------------------------------------------- 1 | # avvp_av or avvp_va 2 | python avvp.py \ 3 | --gpu 0 \ 4 | --lr 0.0004 \ 5 | --clip_gradient 0.5 \ 6 | --snapshot_pref "./Exps/avvp/" \ 7 | --n_epoch 50 \ 8 | --b 80 \ 9 | --test_batch_size 64 \ 10 | --dataset_name "avvp_av" \ 11 | --print_freq 1 \ 12 | --eval_freq 1 -------------------------------------------------------------------------------- /code/src/configs/default_config.yaml: -------------------------------------------------------------------------------- 1 | # ======================================= Learning configs ================================== 2 | n_epoch: 100 3 | batch_size: 64 4 | test_batch_size: 16 5 | lr: 0.001 6 | mi_lr: 0.001 7 | loss_weights: 0.5 8 | clip_gradient: 0.8 9 | start_epoch: 0 10 | weight_decay: 5e-4 11 | 12 | # ======================================= Print and snapshot configs ================================== 13 | snapshot_pref: ./exp/debug/123 14 | print_freq: 20 15 | eval_freq: 1 16 | 17 | 18 | -------------------------------------------------------------------------------- /code/src/configs/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description="A project implemented in pyTorch") 4 | 5 | # =========================== Learning Configs ============================ 6 | parser.add_argument('--dataset_name', type=str) 7 | parser.add_argument('--n_epoch', type=int) 8 | parser.add_argument('-b', '--batch_size', type=int) 9 | parser.add_argument('--test_batch_size', type=int) 10 | parser.add_argument('--lr', type=float) 11 | parser.add_argument('--gpu', type=str) 12 | parser.add_argument('--snapshot_pref', type=str) 13 | parser.add_argument('--resume', type=str, default="") 14 | parser.add_argument('--evaluate', action='store_true') 15 | parser.add_argument('--clip_gradient', type=float) 16 | parser.add_argument('--loss_weights', type=float) 17 | 
parser.add_argument('--start_epoch', type=int) 18 | parser.add_argument('--model_save_path', default='../../../checkpoint') 19 | # parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 20 | # help='momentum') 21 | parser.add_argument('--weight_decay', '--wd', type=float, 22 | metavar='W', help='weight decay (default: 5e-4)') 23 | 24 | # =========================== Display Configs ============================ 25 | parser.add_argument('--print_freq', type=int) 26 | parser.add_argument('--save_freq', type=int) 27 | parser.add_argument('--eval_freq', type=int) -------------------------------------------------------------------------------- /code/src/current_configs.yaml: -------------------------------------------------------------------------------- 1 | # ======================================= Learning configs ================================== 2 | n_epoch: 6 3 | batch_size: 80 4 | test_batch_size: 64 5 | lr: 0.0004 6 | mi_lr: 0.001 7 | loss_weights: 0.5 8 | clip_gradient: 0.5 9 | start_epoch: 0 10 | weight_decay: 5e-4 11 | snapshot_pref: ./Exps/pretrain/ 12 | print_freq: 1 13 | eval_freq: 1 14 | -------------------------------------------------------------------------------- /code/src/dataset/AVE_AVVP_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import h5py 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | import pickle 8 | import zipfile 9 | from io import BytesIO 10 | 11 | class AVEDataset(Dataset): 12 | def __init__(self, data_root, split='train'): 13 | super(AVEDataset, self).__init__() 14 | self.split = split 15 | feature_root = 'AVE-ECCV18-master/data' 16 | self.visual_feature_path = os.path.join(feature_root, 'visual_feature.h5') 17 | self.audio_feature_path = os.path.join(feature_root, 'audio_feature.h5') 18 | # Now for the supervised task 19 | self.labels_path = os.path.join(data_root, f'{split}_labels.h5') 20 | self.sample_order_path = os.path.join(data_root, f'{split}_order.h5') 21 | self.h5_isOpen = False 22 | 23 | def __getitem__(self, index): 24 | if not self.h5_isOpen: 25 | self.visual_feature = h5py.File(self.visual_feature_path, 'r')['avadataset'] 26 | self.audio_feature = h5py.File(self.audio_feature_path, 'r')['avadataset'] 27 | self.labels = h5py.File(self.labels_path, 'r')['avadataset'] 28 | self.sample_order = h5py.File(self.sample_order_path, 'r')['order'] 29 | self.h5_isOpen = True 30 | sample_index = self.sample_order[index] 31 | visual_feat = self.visual_feature[sample_index] 32 | audio_feat = self.audio_feature[sample_index] 33 | label = self.labels[sample_index] 34 | return visual_feat, audio_feat, label 35 | 36 | def __len__(self): 37 | f = h5py.File(self.sample_order_path, 'r') 38 | sample_num = len(f['order']) 39 | f.close() 40 | return sample_num 41 | 42 | def generate_category_list(): 43 | file_path = 'AVE_AVVP/AVE_AVVP_Categories.txt' 44 | category_list = [] 45 | with open(file_path, 'r') as fr: 46 | for line in fr.readlines(): 47 | category_list.append(line.strip()) 48 | return category_list 49 | 50 | class AVVPDataset(Dataset): 51 | # for AVEL task 52 | def __init__(self, meta_csv_path, fea_base_path, split='train', modality='video'): 53 | super(AVVPDataset, self).__init__() 54 | self.modality = modality 55 | self.fea_base_path = fea_base_path 56 | self.split_df = pd.read_csv(meta_csv_path,sep='\t') 57 | self.all_categories = generate_category_list() 58 | print(f'total {len(self.all_categories)} 
positive classes in AVVP, 1 negative classes in AVVP') 59 | print(f'{len(self.split_df)} samples are used for {split}') 60 | 61 | def __getitem__(self, index): 62 | one_video_df = self.split_df.iloc[index] 63 | categorys, video_id = one_video_df['event_labels'].split(','), one_video_df['filename'] 64 | onsets, offsets = one_video_df['onset'].split(','), one_video_df['offset'].split(',') 65 | onsets = list(map(int, onsets)) 66 | offsets = list(map(int, offsets)) 67 | 68 | fea = self._load_fea(self.fea_base_path, video_id[:11]) 69 | 70 | if(self.modality=='audio'): 71 | if fea.shape[0] < 10: 72 | cur_t = fea.shape[0] 73 | add_arr = np.tile(fea[-1, :], (10-cur_t, 1)) 74 | fea = np.concatenate([fea, add_arr], axis=0) 75 | elif fea.shape[0] > 10: 76 | fea = fea[:10, :] 77 | 78 | avel_label = self._obtain_avel_label(onsets, offsets, categorys) # [10,26] 79 | 80 | return torch.from_numpy(fea), \ 81 | torch.from_numpy(avel_label), \ 82 | video_id 83 | 84 | def _load_fea(self, fea_base_path, video_id): 85 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 86 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 87 | for name in zfile.namelist(): 88 | if '.pkl' not in name: 89 | continue 90 | with zfile.open(name, mode='r') as fea_file: 91 | content = BytesIO(fea_file.read()) 92 | fea = pickle.load(content) 93 | return fea 94 | 95 | def _obtain_avel_label(self, onsets, offsets, categorys): 96 | T, category_num = 10, len(self.all_categories) 97 | label = np.zeros((T, category_num + 1)) # add 'background' category [10, 25+1] 98 | label[:, -1] = np.ones(T) 99 | iter_num = len(categorys) 100 | for i in range(iter_num): 101 | avc_label = np.zeros(T) 102 | avc_label[onsets[i]:offsets[i]] = 1 103 | class_id = self.all_categories.index(categorys[i]) 104 | bg_flag = 1 - avc_label 105 | 106 | 107 | """ 108 | The "&" operation on lists is used to find the common elements between two lists, 109 | rather than performing a bitwise "and" operation on each element, 110 | so it needs to be implemented using a loop. 111 | The reason for using "|" here is that if it is not "|", 112 | but a simple assignment, it will cause the previous part of the same label to be overwritten. 113 | 114 | IgN7v8nWmx8_30_40 0,5,0,6,9 1,9,5,8,10 Speech,Speech,Violin_fiddle,Violin_fiddle,Violin_fiddle 115 | For example, in the example given above, 116 | the second "Speech" will overwrite the first "Speech", so "|" operation needs to be used. 117 | """ 118 | for j in range(10): 119 | label[j, class_id] = int(label[j, class_id]) | int(avc_label[j]) 120 | 121 | """ 122 | The "&" operation on lists is used to find the common elements between two lists, 123 | rather than performing a bitwise "and" operation on each element, 124 | so it needs to be implemented using a loop. 
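For example, if a 'Speech' segment has onset 0 and offset 2, then
avc_label = [1, 1, 0, ..., 0] and bg_flag = [0, 0, 1, ..., 1]; AND-ing
bg_flag into the last column frame by frame leaves 'background' set
only on frames where no labelled event is active.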
125 | """ 126 | for j in range(10): 127 | label[j, -1] = int(label[j, -1]) & int(bg_flag[j]) 128 | return label 129 | 130 | def __len__(self,): 131 | return len(self.split_df) -------------------------------------------------------------------------------- /code/src/dataset/AVE_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | class AVEDataset(Dataset): 7 | def __init__(self, data_root, split='train'): 8 | super(AVEDataset, self).__init__() 9 | self.split = split 10 | self.visual_feature_path = os.path.join(data_root, 'visual_feature.h5') 11 | self.audio_feature_path = os.path.join(data_root, 'audio_feature.h5') 12 | # Now for the supervised task 13 | self.labels_path = os.path.join(data_root, 'labels.h5') 14 | self.sample_order_path = os.path.join(data_root, f'{split}_order.h5') 15 | self.h5_isOpen = False 16 | 17 | def __getitem__(self, index): 18 | if not self.h5_isOpen: 19 | self.visual_feature = h5py.File(self.visual_feature_path, 'r')['avadataset'] 20 | self.audio_feature = h5py.File(self.audio_feature_path, 'r')['avadataset'] 21 | self.labels = h5py.File(self.labels_path, 'r')['avadataset'] 22 | self.sample_order = h5py.File(self.sample_order_path, 'r')['order'] 23 | self.h5_isOpen = True 24 | sample_index = self.sample_order[index] 25 | visual_feat = self.visual_feature[sample_index] 26 | audio_feat = self.audio_feature[sample_index] 27 | label = self.labels[sample_index] 28 | 29 | return visual_feat, audio_feat, label 30 | 31 | 32 | def __len__(self): 33 | f = h5py.File(self.sample_order_path, 'r') 34 | sample_num = len(f['order']) 35 | f.close() 36 | return sample_num 37 | 38 | -------------------------------------------------------------------------------- /code/src/dataset/AVVP_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from torch.utils.data import Dataset 6 | import pickle 7 | import zipfile 8 | from io import BytesIO 9 | 10 | def generate_category_list(): 11 | file_path = 'AVVP/data/AVVP_Categories.txt' 12 | category_list = [] 13 | with open(file_path, 'r') as fr: 14 | for line in fr.readlines(): 15 | category_list.append(line.strip()) 16 | return category_list 17 | 18 | 19 | class AVVPDataset(Dataset): 20 | # for AVEL task 21 | def __init__(self, meta_csv_path, fea_base_path, split='train', modality='video'): 22 | super(AVVPDataset, self).__init__() 23 | self.modality = modality 24 | self.fea_base_path = fea_base_path 25 | self.split_df = pd.read_csv(meta_csv_path,sep='\t') 26 | self.all_categories = generate_category_list() 27 | print(f'total {len(self.all_categories)} positive classes in AVVP, 1 negative classes in AVVP') 28 | print(f'{len(self.split_df)} samples are used for {split}') 29 | 30 | def __getitem__(self, index): 31 | one_video_df = self.split_df.iloc[index] 32 | categorys, video_id = one_video_df['event_labels'].split(','), one_video_df['filename'] 33 | onsets, offsets = one_video_df['onset'].split(','), one_video_df['offset'].split(',') 34 | onsets = list(map(int, onsets)) 35 | offsets = list(map(int, offsets)) 36 | fea = self._load_fea(self.fea_base_path, video_id[:11]) 37 | 38 | if(self.modality=='audio'): 39 | if fea.shape[0] < 10: 40 | cur_t = fea.shape[0] 41 | add_arr = np.tile(fea[-1, :], (10-cur_t, 1)) 42 | fea = np.concatenate([fea, add_arr], axis=0) 43 | elif fea.shape[0] 
> 10: 44 | fea = fea[:10, :] 45 | 46 | avel_label = self._obtain_avel_label(onsets, offsets, categorys) # [10,26] 47 | 48 | return torch.from_numpy(fea), \ 49 | torch.from_numpy(avel_label), \ 50 | video_id 51 | 52 | def _load_fea(self, fea_base_path, video_id): 53 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 54 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 55 | for name in zfile.namelist(): 56 | if '.pkl' not in name: 57 | continue 58 | with zfile.open(name, mode='r') as fea_file: 59 | content = BytesIO(fea_file.read()) 60 | fea = pickle.load(content) 61 | return fea 62 | 63 | def _obtain_avel_label(self, onsets, offsets, categorys):# avc_label: [1, 10] 64 | T, category_num = 10, len(self.all_categories) 65 | label = np.zeros((T, category_num + 1)) # add 'background' category [10, 25+1] 66 | label[:, -1] = np.ones(T) 67 | iter_num = len(categorys) 68 | for i in range(iter_num): 69 | avc_label = np.zeros(T) 70 | avc_label[onsets[i]:offsets[i]] = 1 71 | class_id = self.all_categories.index(categorys[i]) 72 | bg_flag = 1 - avc_label 73 | for j in range(10): 74 | label[j, class_id] = int(label[j, class_id]) | int(avc_label[j]) 75 | for j in range(10): 76 | label[j, -1] = int(label[j, -1]) & int(bg_flag[j]) 77 | return label 78 | 79 | def __len__(self,): 80 | return len(self.split_df) 81 | 82 | 83 | class AVVPDatasetTrain(Dataset): 84 | # for AVEL task 85 | def __init__(self, meta_csv_path, fea_base_path, split='train', modality='video'): 86 | super(AVVPDatasetTrain, self).__init__() 87 | self.modality = modality 88 | self.fea_base_path = fea_base_path 89 | self.split_df = pd.read_csv(meta_csv_path, sep='\t') 90 | self.all_categories = generate_category_list() 91 | print(f'total {len(self.all_categories)} classes in AVVPTrain') 92 | print(f'{len(self.split_df)} samples are used for Train') 93 | 94 | def __getitem__(self, index): 95 | one_video_df = self.split_df.iloc[index] 96 | categorys, video_id = one_video_df['event_labels'].split(','), one_video_df['filename'] 97 | fea = self._load_fea(self.fea_base_path, video_id[:11]) 98 | if(self.modality=='audio'): 99 | if fea.shape[0] < 10: 100 | cur_t = fea.shape[0] 101 | add_arr = np.tile(fea[-1, :], (10-cur_t, 1)) 102 | fea = np.concatenate([fea, add_arr], axis=0) 103 | elif fea.shape[0] > 10: 104 | fea = fea[:10, :] 105 | 106 | avc_label = np.ones(10) # [10,1] 107 | avel_label = self._obtain_avel_label(avc_label, categorys) # [10,26] 108 | 109 | return torch.from_numpy(fea), \ 110 | torch.from_numpy(avel_label) 111 | 112 | def _load_fea(self, fea_base_path, video_id): 113 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 114 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 115 | for name in zfile.namelist(): 116 | if '.pkl' not in name: 117 | continue 118 | with zfile.open(name, mode='r') as fea_file: 119 | content = BytesIO(fea_file.read()) 120 | fea = pickle.load(content) 121 | return fea 122 | 123 | 124 | def _obtain_avel_label(self, avc_label, categorys):# avc_label: [1, 10] 125 | T, category_num = 10, len(self.all_categories) 126 | 127 | label = np.zeros((T, category_num + 1)) # add 'background' category [10, 25+1] 128 | for category in categorys: 129 | class_id = self.all_categories.index(category) 130 | bg_flag = 1 - avc_label 131 | label[:, class_id] = avc_label 132 | label[:, -1] = bg_flag 133 | 134 | return label 135 | 136 | def __len__(self,): 137 | return len(self.split_df) 138 | 139 | class AVVPDatasetEval(Dataset): 140 | # for AVEL task 141 | def __init__(self, meta_csv_path, fea_base_path, 
split='train', modality='video'): 142 | super(AVVPDatasetEval, self).__init__() 143 | self.modality = modality 144 | self.fea_base_path = fea_base_path 145 | self.split_df = pd.read_csv(meta_csv_path) 146 | self.all_categories = generate_category_list() 147 | print(f'total {len(self.all_categories)} classes in AVVPEval') 148 | print(f'{len(self.split_df)} samples are used for Eval') 149 | 150 | def __getitem__(self, index): 151 | one_video_df = self.split_df.iloc[index] 152 | category, video_id = one_video_df['event_labels'], one_video_df['filename'] 153 | onset, offset = one_video_df['onset'].astype(int), one_video_df['offset'].astype(int) 154 | 155 | fea = self._load_fea(self.fea_base_path, video_id[:11]) 156 | 157 | if(self.modality=='audio'): 158 | if fea.shape[0] < 10: 159 | cur_t = fea.shape[0] 160 | add_arr = np.tile(fea[-1, :], (10-cur_t, 1)) 161 | fea = np.concatenate([fea, add_arr], axis=0) 162 | elif fea.shape[0] > 10: 163 | fea = fea[:10, :] 164 | 165 | fea = fea[onset:offset, :] 166 | 167 | avc_label = np.ones(offset-onset) # [offset-onset,1] 168 | avel_label = self._obtain_avel_label(onset, offset, avc_label, category) # [offset-onset,26] 169 | sample = {'feature': torch.from_numpy(fea), 'label': torch.from_numpy(avel_label), 'length':offset-onset} 170 | 171 | return sample 172 | 173 | 174 | def _load_fea(self, fea_base_path, video_id): 175 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 176 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 177 | for name in zfile.namelist(): 178 | if '.pkl' not in name: 179 | continue 180 | with zfile.open(name, mode='r') as fea_file: 181 | content = BytesIO(fea_file.read()) 182 | fea = pickle.load(content) 183 | return fea 184 | 185 | 186 | def _obtain_avel_label(self, onset, offset, avc_label, category): 187 | # avc_label: [1, 10] 188 | class_id = self.all_categories.index(category) 189 | T, category_num = offset-onset, len(self.all_categories) 190 | label = np.zeros((T, category_num + 1)) 191 | bg_flag = 1 - avc_label 192 | 193 | label[:, class_id] = avc_label 194 | label[:, -1] = bg_flag 195 | 196 | return label 197 | 198 | def __len__(self,): 199 | return len(self.split_df) -------------------------------------------------------------------------------- /code/src/dataset/UCF_VGGSOUND_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import pandas as pd 5 | import pickle 6 | import zipfile 7 | from io import BytesIO 8 | import numpy as np 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | 12 | def generate_category_list_vgg2ucf(): 13 | file_path = 'VGGSoundsameUCF101.txt' 14 | category_list = [] 15 | with open(file_path, 'r') as fr: 16 | for line in fr.readlines(): 17 | category_list.append(line.strip()) 18 | return category_list 19 | 20 | def generate_category_list_ucf2vgg(): 21 | file_path = 'UCF101sameVGGSound.txt' 22 | category_list = [] 23 | with open(file_path, 'r') as fr: 24 | for line in fr.readlines(): 25 | category_list.append(line.strip()) 26 | return category_list 27 | 28 | class VGGSoundDataset(Dataset): 29 | def __init__(self, meta_csv_path, fea_base_path, split=None, modality=None): 30 | super(VGGSoundDataset, self).__init__() 31 | self.modality = modality 32 | self.fea_base_path = fea_base_path 33 | self.split_df = pd.read_csv(meta_csv_path,sep=',') 34 | self.all_categories = generate_category_list_vgg2ucf() 35 | 36 | def __getitem__(self, index): 37 | one_video_df = self.split_df.iloc[index] 38 | 
video_id = one_video_df['video_id'] 39 | category = one_video_df['category'] 40 | 41 | audio_fea = self._load_fea(self.fea_base_path, video_id) # [10, 128] 42 | avc_label = np.ones(10) 43 | avel_label = self._obtain_avel_label(avc_label, category) 44 | 45 | if self.modality=='audio': 46 | if audio_fea.shape[0] < 10: 47 | cur_t = audio_fea.shape[0] 48 | add_arr = np.tile(audio_fea[-1, :], (10-cur_t, 1)) 49 | audio_fea = np.concatenate([audio_fea, add_arr], axis=0) 50 | elif audio_fea.shape[0] > 10: 51 | audio_fea = audio_fea[:10, :] 52 | audio_fea = audio_fea.astype(np.float64) 53 | 54 | return torch.from_numpy(audio_fea), torch.from_numpy(avel_label) 55 | 56 | def _load_fea(self, fea_base_path, video_id): 57 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 58 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 59 | for name in zfile.namelist(): 60 | if '.pkl' not in name: 61 | continue 62 | with zfile.open(name, mode='r') as fea_file: 63 | content = BytesIO(fea_file.read()) 64 | fea = pickle.load(content) 65 | return fea 66 | 67 | def _obtain_avel_label(self, avc_label, category): 68 | # avc_label: [1, 10] 69 | class_id = self.all_categories.index(category) 70 | T, category_num = 10, len(self.all_categories) 71 | 72 | label = np.zeros((T, category_num + 1)) # add 'background' category [10, 141+1] 73 | bg_flag = 1 - avc_label 74 | label[:, class_id] = avc_label 75 | label[:, -1] = bg_flag 76 | return label 77 | 78 | def __len__(self,): 79 | return len(self.split_df) 80 | 81 | class UCFDataset(Dataset): 82 | def __init__(self, meta_csv_path, fea_base_path, split=None, modality=None): 83 | super(UCFDataset, self).__init__() 84 | self.modality = modality 85 | self.fea_base_path = fea_base_path 86 | self.split_df = pd.read_csv(meta_csv_path,sep=',') 87 | self.all_categories = generate_category_list_ucf2vgg() 88 | 89 | def __getitem__(self, index): 90 | one_video_df = self.split_df.iloc[index] 91 | video_id = one_video_df['video_id'] 92 | category = one_video_df['category'] 93 | 94 | video_fea = self._load_fea(self.fea_base_path, video_id) # [10, 7, 7, 512] 95 | avc_label = np.ones(10) 96 | avel_label = self._obtain_avel_label(avc_label, category) # [10,17] 97 | 98 | if video_fea.shape[0] < 10: 99 | cur_t = video_fea.shape[0] 100 | add_arr = np.tile(video_fea[-1, :], (10-cur_t,1,1,1)) 101 | video_fea = np.concatenate([video_fea, add_arr], axis=0) 102 | elif video_fea.shape[0] > 10: 103 | video_fea = video_fea[:10, :, :, :] 104 | 105 | video_fea = video_fea.astype(np.float64) 106 | 107 | return torch.from_numpy(video_fea), torch.from_numpy(avel_label) 108 | 109 | def _load_fea(self, fea_base_path, video_id): 110 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 111 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 112 | for name in zfile.namelist(): 113 | if '.pkl' not in name: 114 | continue 115 | with zfile.open(name, mode='r') as fea_file: 116 | content = BytesIO(fea_file.read()) 117 | fea = pickle.load(content) 118 | return fea 119 | 120 | def _obtain_avel_label(self, avc_label, category): 121 | # avc_label: [1, 10] 122 | class_id = self.all_categories.index(category) 123 | T, category_num = 10, len(self.all_categories) 124 | label = np.zeros((T, category_num + 1)) # add 'background' category [10, 141+1] 125 | bg_flag = 1 - avc_label 126 | label[:, class_id] = avc_label 127 | label[:, -1] = bg_flag 128 | return label 129 | 130 | def __len__(self,): 131 | return len(self.split_df) -------------------------------------------------------------------------------- 
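A minimal usage sketch for the dataset classes above (illustrative only: the CSV path, feature directory, and batch size are placeholders, and the import assumes the repo's `code/src` directory is on the Python path):

import torch
from torch.utils.data import DataLoader
from dataset.UCF_VGGSOUND_datasets import UCFDataset

# Hypothetical paths; the real ones are supplied by the UCF/VGGSound downstream scripts.
dataset = UCFDataset(meta_csv_path='ucf101_test.csv',
                     fea_base_path='/path/to/ucf101_video_features',
                     split='test', modality='video')
loader = DataLoader(dataset, batch_size=4, shuffle=False)

video_fea, avel_label = next(iter(loader))
# video_fea:  [4, 10, 7, 7, 512] double tensor of VGG feature maps
# avel_label: [4, 10, C + 1], one-hot over the C shared categories plus a
# trailing 'background' column (all zeros here, since avc_label is fixed to
# ones for every frame).
print(video_fea.shape, avel_label.shape)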
/code/src/dataset/VGGSOUND_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from torchvision import transforms, models 8 | from PIL import Image 9 | from tqdm import tqdm 10 | import pickle 11 | import zipfile 12 | from io import BytesIO 13 | import pdb 14 | import random 15 | SEED = 57 16 | random.seed(SEED) 17 | 18 | def generate_category_list(): 19 | file_path = 'VggsoundAVEL40kCategories.txt' 20 | category_list = [] 21 | with open(file_path, 'r') as fr: 22 | for line in fr.readlines(): 23 | category_list.append(line.strip()) 24 | return category_list 25 | 26 | # AV 27 | class VGGSoundDataset(Dataset): 28 | def __init__(self, meta_csv_path, audio_fea_base_path, video_fea_base_path, avc_label_base_path, split='train'): 29 | super(VGGSoundDataset, self).__init__() 30 | self.audio_fea_base_path = audio_fea_base_path 31 | self.video_fea_base_path = video_fea_base_path 32 | self.avc_label_base_path = avc_label_base_path 33 | all_df = pd.read_csv(meta_csv_path) 34 | # train = 24k 35 | # train + test + val = 40k 36 | df_train = all_df[all_df['split'] == 'train'] 37 | df_test = all_df[all_df['split'] == 'test'] 38 | df_val = all_df[all_df['split'] == 'val'] 39 | self.split_df = pd.concat([df_train,df_test,df_val]) 40 | # Output the proportion of train, test, and valid. 41 | print(f'{len(self.split_df)}/{len(all_df)} videos are used for {split}') 42 | self.all_categories = generate_category_list() 43 | print(f'total {len(self.all_categories)} classes in VggsoundAVEL40k') 44 | 45 | def __getitem__(self, index): 46 | one_video_df = self.split_df.iloc[index] 47 | category, video_id = one_video_df['category'], one_video_df['video_id'] 48 | 49 | audio_fea = self._load_fea(self.audio_fea_base_path, video_id) # [10, 128] 50 | video_fea = self._load_fea(self.video_fea_base_path, video_id) # [10, 7, 7, 512] 51 | avc_label = self._load_fea(self.avc_label_base_path, video_id) # [10,1] 52 | avel_label = self._obtain_avel_label(avc_label, category) # [10,142] 53 | 54 | if audio_fea.shape[0] < 10: 55 | cur_t = audio_fea.shape[0] 56 | add_arr = np.tile(audio_fea[-1, :], (10-cur_t, 1)) 57 | audio_fea = np.concatenate([audio_fea, add_arr], axis=0) 58 | elif audio_fea.shape[0] > 10: 59 | audio_fea = audio_fea[:10, :] 60 | 61 | return torch.from_numpy(video_fea), \ 62 | torch.from_numpy(audio_fea), \ 63 | torch.from_numpy(avel_label) 64 | 65 | 66 | def _load_fea(self, fea_base_path, video_id): 67 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 68 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 69 | for name in zfile.namelist(): 70 | if '.pkl' not in name: 71 | continue 72 | with zfile.open(name, mode='r') as fea_file: 73 | content = BytesIO(fea_file.read()) 74 | fea = pickle.load(content) 75 | return fea 76 | 77 | 78 | def _obtain_avel_label(self, avc_label, category): 79 | # avc_label: [1, 10] 80 | class_id = self.all_categories.index(category) 81 | T, category_num = 10, len(self.all_categories) 82 | 83 | label = np.zeros((T, category_num + 1)) # add 'background' category [10, 141+1] 84 | bg_flag = 1 - avc_label 85 | 86 | label[:, class_id] = avc_label 87 | label[:, -1] = bg_flag 88 | return label 89 | 90 | def __len__(self,): 91 | return len(self.split_df) 92 | 93 | # AT 94 | class VGGSoundDataset_AT(Dataset): 95 | 96 | def __init__(self, meta_csv_path, audio_fea_base_path, split='train'): 97 | 
super(VGGSoundDataset_AT, self).__init__() 98 | self.label2prompt = pd.read_csv('vggsoundCategories2Prompts.csv') 99 | self.audio_fea_base_path = audio_fea_base_path 100 | all_df = pd.read_csv(meta_csv_path) 101 | 102 | df_train = all_df[all_df['split'] == 'train'] 103 | df_test = all_df[all_df['split'] == 'test'] 104 | df_val = all_df[all_df['split'] == 'val'] 105 | self.split_df = pd.concat([df_train,df_test,df_val]) 106 | 107 | print(f'{len(self.split_df)}/{len(all_df)} audios are used for {split}') 108 | self.all_categories = generate_category_list() 109 | print(f'total {len(self.all_categories)} classes in Vggsound40K_AT') 110 | 111 | def __getitem__(self, index): 112 | one_video_df = self.split_df.iloc[index] 113 | category, audio_id = one_video_df['category'], one_video_df['video_id'] 114 | 115 | audio_fea = self._load_fea(self.audio_fea_base_path, audio_id) # [10, 128] 116 | text_fea = self.label2prompt.loc[self.label2prompt['label'] == category].values[0][1] 117 | 118 | if audio_fea.shape[0] < 10: 119 | cur_t = audio_fea.shape[0] 120 | add_arr = np.tile(audio_fea[-1, :], (10-cur_t, 1)) 121 | audio_fea = np.concatenate([audio_fea, add_arr], axis=0) 122 | elif audio_fea.shape[0] > 10: 123 | audio_fea = audio_fea[:10, :] 124 | 125 | sample = {'audio_fea': audio_fea, 'text_fea': text_fea} 126 | return sample 127 | 128 | def _load_fea(self, fea_base_path, audio_id): 129 | fea_path = os.path.join(fea_base_path, "%s.zip"%audio_id) 130 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 131 | for name in zfile.namelist(): 132 | if '.pkl' not in name: 133 | continue 134 | with zfile.open(name, mode='r') as fea_file: 135 | content = BytesIO(fea_file.read()) 136 | fea = pickle.load(content) 137 | return fea 138 | 139 | def __len__(self,): 140 | return len(self.split_df) 141 | 142 | 143 | #AVT 144 | class VGGSoundDataset_AVT(Dataset): 145 | def __init__(self, meta_csv_path, audio_fea_base_path, video_fea_base_path, split='train'): 146 | super(VGGSoundDataset_AVT, self).__init__() 147 | self.label2prompt = pd.read_csv('vggsoundCategories2Prompts.csv') 148 | self.audio_fea_base_path = audio_fea_base_path 149 | self.video_fea_base_path = video_fea_base_path 150 | all_df = pd.read_csv(meta_csv_path) 151 | 152 | df_train = all_df[all_df['split'] == 'train'] 153 | df_test = all_df[all_df['split'] == 'test'] 154 | df_val = all_df[all_df['split'] == 'val'] 155 | self.split_df = pd.concat([df_train,df_test,df_val]) 156 | 157 | print(f'{len(self.split_df)}/{len(all_df)} samples are used for {split}') 158 | self.all_categories = generate_category_list() 159 | print(f'total {len(self.all_categories)} classes in Vggsound40K_AVT') 160 | 161 | def __getitem__(self, index): 162 | one_video_df = self.split_df.iloc[index] 163 | category, audio_id = one_video_df['category'], one_video_df['video_id'] 164 | 165 | audio_fea = self._load_fea(self.audio_fea_base_path, audio_id) # [10, 128] 166 | video_fea = self._load_fea(self.video_fea_base_path, audio_id) # [10, 7, 7, 512] 167 | text_fea = self.label2prompt.loc[self.label2prompt['label'] == category].values[0][1] 168 | 169 | if audio_fea.shape[0] < 10: 170 | cur_t = audio_fea.shape[0] 171 | add_arr = np.tile(audio_fea[-1, :], (10-cur_t, 1)) 172 | audio_fea = np.concatenate([audio_fea, add_arr], axis=0) 173 | elif audio_fea.shape[0] > 10: 174 | audio_fea = audio_fea[:10, :] 175 | 176 | sample = {'video_fea': video_fea, 'audio_fea': audio_fea, 'text_fea': text_fea} 177 | return sample 178 | 179 | def _load_fea(self, fea_base_path, audio_id): 180 | fea_path = 
os.path.join(fea_base_path, "%s.zip"%audio_id) 181 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 182 | for name in zfile.namelist(): 183 | if '.pkl' not in name: 184 | continue 185 | with zfile.open(name, mode='r') as fea_file: 186 | content = BytesIO(fea_file.read()) 187 | fea = pickle.load(content) 188 | return fea 189 | 190 | def __len__(self,): 191 | return len(self.split_df) -------------------------------------------------------------------------------- /code/src/dataset/VGGSOUND_dataset179k.py: -------------------------------------------------------------------------------- 1 | """used for train 81k""" 2 | import os 3 | import h5py 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms, models 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import pickle 12 | import zipfile 13 | from io import BytesIO 14 | import pdb 15 | import csv 16 | 17 | # The number of categories is the same for 81k and 40k. 18 | def generate_category_list(): 19 | file_path = 'VggsoundAVEL40kCategories.txt' 20 | category_list = [] 21 | with open(file_path, 'r') as fr: 22 | for line in fr.readlines(): 23 | category_list.append(line.strip()) 24 | return category_list 25 | 26 | 27 | class VGGSoundDataset(Dataset): 28 | def __init__(self, meta_csv_path, audio_fea_base_path, video_fea_base_path, avc_label_base_path, split='train'): 29 | super(VGGSoundDataset, self).__init__() 30 | self.audio_fea_base_path = audio_fea_base_path 31 | self.video_fea_base_path = video_fea_base_path 32 | self.avc_label_base_path = avc_label_base_path 33 | all_df = pd.read_csv(meta_csv_path) 34 | self.split_df = all_df 35 | # Output the proportion of train, test, and valid. 36 | print(f'{len(self.split_df)}/{len(all_df)} videos are used for {split}') 37 | self.all_categories = generate_category_list() 38 | print(f'total {len(self.all_categories)} classes in VggsoundAVEL81k') 39 | 40 | 41 | def __getitem__(self, index): 42 | one_video_df = self.split_df.iloc[index] 43 | video_id = one_video_df['id'][:-4]# drop '.mp4' 44 | 45 | audio_fea = self._load_fea(self.audio_fea_base_path, video_id) # [10, 128] 46 | video_fea = self._load_fea(self.video_fea_base_path, video_id) # [10, 7, 7, 512] 47 | 48 | if audio_fea.shape[0] < 10: 49 | cur_t = audio_fea.shape[0] 50 | add_arr = np.tile(audio_fea[-1, :], (10-cur_t, 1)) 51 | audio_fea = np.concatenate([audio_fea, add_arr], axis=0) 52 | elif audio_fea.shape[0] > 10: 53 | audio_fea = audio_fea[:10, :] 54 | 55 | return torch.from_numpy(video_fea), \ 56 | torch.from_numpy(audio_fea) 57 | 58 | def _load_fea(self, fea_base_path, video_id): 59 | fea_path = os.path.join(fea_base_path, "%s.zip"%video_id) 60 | with zipfile.ZipFile(fea_path, mode='r') as zfile: 61 | for name in zfile.namelist(): 62 | if '.pkl' not in name: 63 | continue 64 | with zfile.open(name, mode='r') as fea_file: 65 | content = BytesIO(fea_file.read()) 66 | fea = pickle.load(content) 67 | return fea 68 | 69 | def __len__(self,): 70 | return len(self.split_df) -------------------------------------------------------------------------------- /code/src/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/src/dataset/__init__.py -------------------------------------------------------------------------------- /code/src/model/CLUB.py: 
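CLUB.py below implements a sampled contrastive log-ratio upper bound (CLUB) style estimator of mutual information via the class CLUBSample_group. A minimal sketch of how such an estimator is typically used (the feature dimensions, batch size, and two-step update are assumptions for illustration, not taken from this repository's training scripts):

import torch
from model.CLUB import CLUBSample_group  # assumes code/src is on the Python path

mi_net = CLUBSample_group(x_dim=256, y_dim=256, hidden_size=256)
mi_opt = torch.optim.Adam(mi_net.parameters(), lr=1e-3)

x = torch.randn(8, 256)   # e.g. features of one representation branch
y = torch.randn(8, 256)   # e.g. features whose dependence on x should shrink

# Step 1: fit the variational network q(y|x) by maximizing its log-likelihood.
mi_opt.zero_grad()
(-mi_net.loglikeli(x, y)).backward()
mi_opt.step()

# Step 2: query the fitted estimator; mi_est() returns an upper-bound style
# penalty on I(x; y) that can be added, with a weight, to the main loss.
mi_penalty = mi_net.mi_est(x, y)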
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CLUBSample_group(nn.Module): # Sampled version of the CLUB estimator 7 | def __init__(self, x_dim, y_dim, hidden_size): 8 | super(CLUBSample_group, self).__init__() 9 | self.x_dim = x_dim 10 | self.y_dim = y_dim 11 | self.hidden_size = hidden_size 12 | self.p_mu = nn.Sequential(nn.Linear(self.x_dim, self.hidden_size // 2), 13 | nn.ReLU(), 14 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 15 | nn.ReLU(), 16 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 17 | nn.ReLU(), 18 | nn.Linear(self.hidden_size // 2, self.y_dim)) 19 | 20 | self.p_logvar = nn.Sequential(nn.Linear(self.x_dim, self.hidden_size // 2), 21 | nn.ReLU(), 22 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 23 | nn.ReLU(), 24 | nn.Linear(self.hidden_size // 2, self.hidden_size // 2), 25 | nn.ReLU(), 26 | nn.Linear(self.hidden_size // 2, self.y_dim), 27 | nn.Tanh()) 28 | 29 | def get_mu_logvar(self, x_samples): 30 | mu = self.p_mu(x_samples) 31 | logvar = self.p_logvar(x_samples) 32 | return mu, logvar 33 | 34 | def loglikeli(self, x_samples, y_samples): # unnormalized loglikelihood 35 | mu, logvar = self.get_mu_logvar(x_samples) # mu/logvar: (bs, y_dim) 36 | # mu = mu.unsqueeze(1).expand(-1, y_samples.shape[1], -1).reshape(-1, mu.shape[ 37 | # -1]) # (bs, y_dim) -> (bs, 1, y_dim) -> (bs, T, y_dim) -> (bs*T, y_dim) 38 | mu = mu.reshape(-1, mu.shape[-1]) 39 | #logvar = logvar.unsqueeze(1).expand(-1, y_samples.shape[1], -1).reshape(-1, logvar.shape[-1]) 40 | logvar = logvar.reshape(-1, logvar.shape[-1]) 41 | y_samples = y_samples.reshape(-1, y_samples.shape[-1]) # (bs, T, y_dim) -> (bs*T, y_dim) 42 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) / 2 43 | 44 | def mi_est(self, x_samples, y_samples): # x_samples: (bs, x_dim); y_samples: (bs, T, y_dim) 45 | 46 | mu, logvar = self.get_mu_logvar(x_samples) 47 | 48 | sample_size = x_samples.shape[0] 49 | # random_index = torch.randint(sample_size, (sample_size,)).long() 50 | random_index = torch.randperm(sample_size).long() 51 | 52 | # log of conditional probability of positive sample pairs 53 | #mu_exp1 = mu.unsqueeze(1).expand(-1, y_samples.shape[1], -1) # (bs, y_dim) -> (bs, T, y_dim) 54 | mu_exp1 = mu 55 | 56 | # logvar_exp1 = logvar.unqueeze(1).expand(-1, y_samples.shape[1], -1).reshape(-1, logvar.shape[-1]) 57 | positive = - ((mu_exp1 - y_samples) ** 2).mean(dim=1) / logvar.mean(dim=1).exp() # mean along T 58 | negative = - ((mu_exp1 - y_samples[random_index]) ** 2).mean(dim=1) / logvar.mean(dim=1).exp() # mean along T 59 | 60 | return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() / 2 61 | -------------------------------------------------------------------------------- /code/src/model/Dual_lstm.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import copy 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.nn import Module 7 | from torch.nn import MultiheadAttention 8 | from torch.nn import ModuleList 9 | from torch.nn.init import xavier_uniform_ 10 | from torch.nn import Dropout 11 | from torch.nn import Linear 12 | from torch.nn import LayerNorm 13 | import math 14 | from torch.autograd import Variable 15 | 16 | 17 | 18 | class Dual_lstm_cell(nn.Module): 19 | def __init__(self, visual_input_dim, audio_input_dim, hidden_dim, alph=0.5, 
bias=True): 20 | super(Dual_lstm_cell, self).__init__() 21 | 22 | self.visual_input_dim = visual_input_dim 23 | self.audio_input_dim = audio_input_dim 24 | self.hidden_dim = hidden_dim 25 | self.alph = alph 26 | self.vs_linear = nn.Linear(self.visual_input_dim, 4 * self.hidden_dim, bias=bias) 27 | self.vh_linear = nn.Linear(self.hidden_dim, 4* self.hidden_dim, bias=bias) 28 | self.as_linear = nn.Linear(self.audio_input_dim, 4 * self.hidden_dim, bias=bias) 29 | self.ah_linear = nn.Linear(self.hidden_dim, 4 * self.hidden_dim, bias=bias) 30 | 31 | self.as_linear2 = nn.Linear(self.audio_input_dim, 4*self.hidden_dim, bias=bias) 32 | self.ah_linear2 = nn.Linear(self.hidden_dim, 4*self.hidden_dim, bias=bias) 33 | self.vs_linear2 = nn.Linear(self.visual_input_dim, 4*self.hidden_dim, bias=bias) 34 | self.vh_linear2 = nn.Linear(self.hidden_dim, 4*self.hidden_dim, bias=bias) 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | std = 1.0 / math.sqrt(self.hidden_dim) 39 | for w in self.parameters(): 40 | w.data.uniform_(-std, std) 41 | 42 | def forward(self, visual_state, visual_hidden, visual_cell, audio_state, audio_hidden, audio_cell): 43 | visual_gates = self.vs_linear(visual_state) + self.vh_linear(visual_hidden) 44 | #self.alph*self.as_linear(audio_state) + self.alph*self.ah_linear(audio_hidden) 45 | 46 | 47 | audio_gates = self.as_linear2(audio_state) + self.ah_linear2(audio_hidden) 48 | #self.alph*self.vs_linear2(visual_state) + self.alph*self.vh_linear2(visual_hidden) 49 | 50 | visual_i_gate, visual_f_gate, visual_c_gate, visual_o_gate = visual_gates.chunk(4,1) 51 | audio_i_gate, audio_f_gate, audio_c_gate, audio_o_gate = audio_gates.chunk(4,1) 52 | 53 | visual_i_gate = F.sigmoid(visual_i_gate) 54 | visual_f_gate = F.sigmoid(visual_f_gate) 55 | visual_c_gate = F.tanh(visual_c_gate) 56 | visual_o_gate = F.sigmoid(visual_o_gate) 57 | 58 | visual_cell = visual_f_gate * visual_cell + visual_i_gate * visual_c_gate 59 | visual_output = visual_o_gate * torch.tanh(visual_cell) 60 | 61 | audio_i_gate = F.sigmoid(audio_i_gate) 62 | audio_f_gate = F.sigmoid(audio_f_gate) 63 | audio_c_gate = F.tanh(audio_c_gate) 64 | audio_o_gate = F.sigmoid(audio_o_gate) 65 | 66 | audio_cell = audio_f_gate * audio_cell + audio_i_gate * audio_c_gate 67 | audio_output = audio_o_gate * torch.tanh(audio_cell) 68 | 69 | return visual_output, visual_cell, audio_output, audio_cell 70 | 71 | class Dual_lstm(nn.Module): 72 | def __init__(self): 73 | 74 | super(Dual_lstm, self).__init__() 75 | 76 | self.video_input_dim = 512 77 | self.video_fc_dim = 512 78 | self.d_model = 256 79 | self.v_fc = nn.Linear(self.video_input_dim, self.video_fc_dim) 80 | self.LSTM_cell = Dual_lstm_cell(visual_input_dim=512, audio_input_dim=128, hidden_dim=256) 81 | #self.LSTM_cell_r = Dual_lstm_cell(visual_input_dim=512, audio_input_dim=128, hidden_dim=256) 82 | 83 | 84 | self.relu = nn.ReLU() 85 | self.dropout = nn.Dropout(0.2) 86 | 87 | 88 | def forward(self, audio_feature, visual_feature): 89 | audio_rnn_input = audio_feature 90 | 91 | visual_rnn_input = visual_feature 92 | 93 | if torch.cuda.is_available(): 94 | visual_hidden = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 95 | visual_hidden_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 96 | else: 97 | visual_hidden = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 98 | visual_hidden_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 99 | 100 | if torch.cuda.is_available(): 101 | visual_cell = 
Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 102 | visual_cell_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model).cuda()) 103 | else: 104 | visual_cell = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 105 | visual_cell_r = Variable(torch.zeros(visual_rnn_input.size(0), self.d_model)) 106 | 107 | if torch.cuda.is_available(): 108 | audio_hidden = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 109 | audio_hidden_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 110 | else: 111 | audio_hidden = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 112 | audio_hidden_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 113 | 114 | if torch.cuda.is_available(): 115 | audio_cell = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 116 | audio_cell_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model).cuda()) 117 | else: 118 | audio_cell = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 119 | audio_cell_r = Variable(torch.zeros(audio_rnn_input.size(0), self.d_model)) 120 | 121 | visual_output = [] 122 | audio_output = [] 123 | visual_output_r = [] 124 | audio_output_r = [] 125 | length = visual_rnn_input.size(1) 126 | 127 | visual_hidden = visual_hidden.double() 128 | visual_cell = visual_cell.double() 129 | audio_hidden = audio_hidden.double() 130 | audio_cell = audio_cell.double() 131 | visual_hidden_r = visual_hidden_r.double() 132 | visual_cell_r = visual_cell_r.double() 133 | audio_hidden_r = audio_hidden_r.double() 134 | audio_cell_r = audio_cell_r.double() 135 | 136 | 137 | for i in range(length): 138 | visual_hidden, visual_cell, audio_hidden, audio_cell = self.LSTM_cell(visual_rnn_input[:,i,:], visual_hidden, visual_cell, 139 | audio_rnn_input[:,i,:], audio_hidden, audio_cell) 140 | visual_output.append(visual_hidden) 141 | audio_output.append(audio_hidden) 142 | 143 | visual_output = torch.stack(visual_output,dim=1) 144 | audio_output = torch.stack(audio_output, dim=1) 145 | 146 | 147 | # for i in range(length): 148 | # visual_hidden_r, visual_cell_r, audio_hidden_r, audio_cell_r = self.LSTM_cell_r(visual_rnn_input[:,length-1-i,:], visual_hidden_r, 149 | # visual_cell_r, audio_rnn_input[:,length-1-i,:], 150 | # audio_hidden_r, audio_cell_r) 151 | # visual_output_r.append(visual_hidden_r) 152 | # audio_output_r.append(audio_hidden_r) 153 | 154 | # visual_output_r = torch.stack(visual_output_r, dim=1) 155 | # visual_output_r = torch.flip(visual_output_r, dims=[1]) 156 | # audio_output_r = torch.stack(audio_output_r, dim=1) 157 | # audio_output_r = torch.flip(audio_output_r, dims=[1]) 158 | # visual_output = torch.cat((visual_output, visual_output_r), dim=2) 159 | # audio_output = torch.cat((audio_output, audio_output_r), dim=2) 160 | return audio_output, visual_output 161 | 162 | 163 | # model = Dual_lstm() 164 | # visual_feature = torch.randn(32, 10,512) 165 | # audio_feature = torch.randn(32, 10, 128) 166 | # model(audio_feature, visual_feature) 167 | # 168 | -------------------------------------------------------------------------------- /code/src/model/UniEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.transformer import TransformerEncoder 4 | 5 | class UNIEncoder(nn.Module): 6 | def __init__(self): 7 | super(UNIEncoder, self).__init__() 8 | # define transformer head 9 | self.tx = TransformerEncoder(d_model=256, 10 | d_kv=64, 11 | 
d_ff=4096, 12 | num_layers=24, 13 | num_heads=16, 14 | pre_norm=True, 15 | use_bias=True, 16 | activation="gelu", 17 | dropout_rate=0.1, 18 | layer_norm_epsilon=1e-6) 19 | 20 | # define post-tx projection head - it could be logits or embd space 21 | self.post_proj = nn.ModuleDict({# ReLU or GELU 22 | "video": nn.Sequential(nn.Linear(256, 256),nn.GELU()),#d_model=256 d_post_proj=256 23 | "audio": nn.Sequential(nn.Linear(256, 256),nn.GELU()) 24 | }) 25 | 26 | def _flatten_inputs(self, 27 | inputs): 28 | input_shape = inputs.shape 29 | bs = inputs.shape[0] 30 | d_embd = inputs.shape[-1] 31 | inputs = inputs.view(bs, -1, d_embd) 32 | 33 | return inputs, input_shape 34 | 35 | def _append_special_tokens(self, 36 | inputs, 37 | modality): 38 | batch_size = inputs.shape[0] 39 | agg_token = { 40 | "video": torch.nn.Parameter(torch.zeros(256)),#d_model; zero-initialised (was an uninitialised torch.Tensor); register in __init__ to make it learnable 41 | "audio": torch.nn.Parameter(torch.zeros(256)), 42 | } 43 | special_embd = agg_token[modality][None, None, :].to(inputs.device) 44 | special_embd = special_embd.repeat(batch_size, 1, 1) 45 | 46 | return torch.cat([special_embd, inputs], dim=1) 47 | 48 | def _extend_attn_mask(self, 49 | attention_mask): 50 | attn_mask_shape = attention_mask.shape 51 | if len(attn_mask_shape) > 2: 52 | raise NotImplementedError 53 | 54 | batch_size = attn_mask_shape[0] 55 | extention_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype) 56 | extended_attention_mask = torch.cat([extention_mask, attention_mask], dim=1) 57 | return extended_attention_mask 58 | 59 | def _modality_call(self, 60 | inputs, 61 | modality, 62 | training=False, 63 | attention_mask=None, 64 | input_shape=None): 65 | embeddings = inputs 66 | if input_shape is None: 67 | embeddings, input_shape = self._flatten_inputs(embeddings) 68 | 69 | # print("pool:",embeddings) 70 | # print(features) 71 | 72 | # append modalities special tokens: [vid, aud, txt] 73 | tx_inputs = self._append_special_tokens(embeddings, modality) 74 | # print("pool:",embeddings) 75 | 76 | # extend attention_mask accordingly 77 | if attention_mask is not None: 78 | attention_mask = self._extend_attn_mask(attention_mask) 79 | 80 | # call Transformer 81 | tx_outputs = self.tx(tx_inputs, attention_mask) 82 | 83 | # get last hidden states and perform final linear projection 84 | last_hidden_states = tx_outputs["hidden_states"][-1] 85 | modality_outputs = self.post_proj[modality](last_hidden_states) 86 | output_shape = list(input_shape[:-1]) + [modality_outputs.shape[-1]] 87 | # output_shape = list(256) + [modality_outputs.shape[-1]] 88 | 89 | features_pooled = modality_outputs[:, 0, :] 90 | features = modality_outputs[:, 1:, :].reshape(output_shape) 91 | 92 | # print("pool:",features_pooled) 93 | # print(features) 94 | 95 | # add token-level Transformer outputs 96 | outputs = {"features_pooled": features_pooled, 97 | "features": features} 98 | 99 | return outputs 100 | 101 | def forward(self, video_semantic_result, audio_semantic_result): 102 | 103 | """ 104 | outputs = {"features_pooled": features_pooled, 105 | "features": features} 106 | """ 107 | 108 | 109 | 110 | video_outputs = self._modality_call(inputs=video_semantic_result, 111 | modality='video', 112 | training=self.training, 113 | attention_mask=None) 114 | audio_outputs = self._modality_call(inputs=audio_semantic_result, 115 | modality='audio', 116 | training=self.training, 117 | attention_mask=None) 118 | 119 | # print("video_outputs:",video_outputs["features"].size(), video_outputs["features"].dtype) 120 | #
print("video_semantic_result:",video_semantic_result.size(), video_semantic_result.dtype) 121 | 122 | # print(video_semantic_result) 123 | 124 | return video_outputs["features"], audio_outputs["features"] 125 | -------------------------------------------------------------------------------- /code/src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/code/src/model/__init__.py -------------------------------------------------------------------------------- /code/src/model/mine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # class Mine(nn.Module): 7 | # def __init__(self): 8 | # super(Mine, self).__init__() 9 | # self.fc1_x = nn.Linear(2048, 512) 10 | # self.fc1_y = nn.Linear(2048, 512) 11 | # self.fc2 = nn.Linear(512,1) 12 | # def forward(self, x,y): 13 | # h1 = F.leaky_relu(self.fc1_x(x)+self.fc1_y(y)) 14 | # h2 = self.fc2(h1) 15 | # return h2 16 | # 17 | # Mine = Mine() 18 | # def mi_estimator(x, y, y_): 19 | # 20 | # joint, marginal = Mine(x, y), Mine(x, y_) 21 | # return torch.mean(joint) - torch.log(torch.mean(torch.exp(marginal))) 22 | 23 | # x = torch.randn(32, 10, 2048) 24 | # y = torch.randn(32, 10, 2048) 25 | # y_ = torch.randn(32, 10, 2048) 26 | # joint, marginal = Mine(x, y), Mine(x, y_) 27 | # loss = torch.mean(joint) - torch.log(torch.mean(torch.exp(marginal))) 28 | # print(loss) 29 | 30 | # class Mine2(nn.Module): 31 | # def __init__(self, x_dim, y_dim, hidden_dim): 32 | # super(Mine2, self).__init__() 33 | 34 | # 35 | # 36 | # class MINE(nn.Module): 37 | # def __init__(self, hidden_size=256): 38 | # super(MINE, self).__init__() 39 | # self.layers = nn.Sequential(nn.Linear(512, hidden_size), 40 | # nn.ReLU(), 41 | # nn.Linear(hidden_size, 1)) 42 | # 43 | # def forward(self, x, y): 44 | # batch_size = x.size(0) 45 | # tiled_x = torch.cat([x, x, ], dim=0) 46 | # print("tiled_x:",tiled_x.size()) 47 | # idx = torch.randperm(batch_size) 48 | # 49 | # shuffled_y = y[idx] 50 | # concat_y = torch.cat([y, shuffled_y], dim=0) 51 | # print("concat_y:", concat_y.size()) 52 | # 53 | # 54 | # inputs = torch.cat([tiled_x, concat_y], dim=1) 55 | # print("inputs:",inputs.size()) 56 | # logits = self.layers(inputs) 57 | # 58 | # pred_xy = logits[:batch_size] 59 | # pred_x_y = logits[batch_size:] 60 | # loss = -(torch.mean(pred_xy) 61 | # - torch.log(torch.mean(torch.exp(pred_x_y)))) 62 | # 63 | # return loss 64 | # # 65 | 66 | 67 | class MINE(nn.Module): 68 | def __init__(self, x_dim, y_dim, hidden_size): 69 | super(MINE, self).__init__() 70 | self.T_func = nn.Sequential(nn.Linear(x_dim + y_dim, hidden_size), 71 | nn.ReLU(), 72 | nn.Linear(hidden_size, 1)) 73 | 74 | def forward(self, x_samples, y_samples): # samples have shape [sample_size, dim] 75 | # shuffle and concatenate 76 | sample_size = y_samples.shape[0] 77 | random_index = torch.randint(sample_size, (sample_size,)).long() 78 | 79 | y_shuffle = y_samples[random_index] 80 | #print("y_shuffle", y_shuffle.size()) 81 | 82 | T0 = self.T_func(torch.cat([x_samples, y_samples], dim=-1).to(torch.float32)) 83 | #print("T0:",T0.size()) 84 | T1 = self.T_func(torch.cat([x_samples, y_shuffle], dim=-1).to(torch.float32)) 85 | #print("T1:", T1.size()) 86 | 87 | lower_bound = T0.mean() - torch.log(T1.exp().mean()) 88 | 89 | # compute the negative loss (maximise loss == minimise -loss) 
90 | return lower_bound 91 | 92 | def learning_loss(self, x_samples, y_samples): 93 | return -self.forward(x_samples, y_samples) 94 | 95 | 96 | class CLUBSample(nn.Module): # Sampled version of the CLUB estimator 97 | def __init__(self, x_dim, y_dim, hidden_size): 98 | super(CLUBSample, self).__init__() 99 | self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size // 2), 100 | nn.ReLU(), 101 | nn.Linear(hidden_size // 2, y_dim)) 102 | 103 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size // 2), 104 | nn.ReLU(), 105 | nn.Linear(hidden_size // 2, y_dim), 106 | nn.Tanh()) 107 | 108 | def get_mu_logvar(self, x_samples): 109 | mu = self.p_mu(x_samples) 110 | logvar = self.p_logvar(x_samples) 111 | return mu, logvar 112 | 113 | def loglikeli(self, x_samples, y_samples): 114 | mu, logvar = self.get_mu_logvar(x_samples) 115 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean() 116 | 117 | def forward(self, x_samples, y_samples): 118 | mu, logvar = self.get_mu_logvar(x_samples) 119 | 120 | sample_size = x_samples.shape[0] 121 | # random_index = torch.randint(sample_size, (sample_size,)).long() 122 | random_index = torch.randperm(sample_size).long() 123 | 124 | positive = - (mu - y_samples) ** 2 / logvar.exp() 125 | negative = - (mu - y_samples[random_index]) ** 2 / logvar.exp() 126 | upper_bound = (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 127 | return upper_bound / 2. 128 | 129 | def learning_loss(self, x_samples, y_samples): 130 | return - self.loglikeli(x_samples, y_samples) 131 | 132 | # x = torch.randn(32, 10, 512) 133 | # y = torch.randn(32, 10, 2048) 134 | # 135 | # model = MINE(x_dim=512, y_dim=2048, hidden_size=256) 136 | # loss = model.learning_loss(x, y) 137 | # print(loss) -------------------------------------------------------------------------------- /code/src/pretrain.sh: -------------------------------------------------------------------------------- 1 | python pretrain.py \ 2 | --gpu 0 \ 3 | --lr 0.0004 \ 4 | --clip_gradient 0.5 \ 5 | --snapshot_pref "./Exps/pretrain/" \ 6 | --n_epoch 6 \ 7 | --b 80 \ 8 | --test_batch_size 64 \ 9 | --dataset_name "vggsound_AVT" \ 10 | --print_freq 1 11 | -------------------------------------------------------------------------------- /code/src/ucf_vggsound.sh: -------------------------------------------------------------------------------- 1 | python ucf_vggsound.py \ 2 | --gpu 0 \ 3 | --lr 0.0004 \ 4 | --clip_gradient 0.5 \ 5 | --snapshot_pref "./Exps/ucf_vggsound/" \ 6 | --n_epoch 30 \ 7 | --b 80 \ 8 | --test_batch_size 64 \ 9 | --dataset_name "vgga_ucfv" \ 10 | --print_freq 1 \ 11 | --eval_freq 1 -------------------------------------------------------------------------------- /code/src/utils/Recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class Recorder(object): 6 | def __init__(self, snapshot_pref, ignore_folder): 7 | if not os.path.isdir(snapshot_pref): 8 | os.mkdir(snapshot_pref) 9 | self.save_path = snapshot_pref 10 | self.log_file = self.save_path + "log.txt" 11 | self.readme = self.save_path + "README.md" 12 | self.opt_file = self.save_path + "opt.log" 13 | self.code_path = os.path.join(self.save_path, "code/") 14 | # self.weight_folder = self.save_path + "weight/" 15 | # self.weight_fig_folder = self.save_path + "weight_fig/" 16 | # if os.path.isfile(self.log_file): 17 | # os.remove(self.log_file) 18 | if os.path.isfile(self.readme): 19 | os.remove(self.readme) 20 | if not os.path.isdir(self.code_path): 21 | 
os.mkdir(self.code_path) 22 | self.copy_code(dst=self.code_path, ignore_folder=ignore_folder) 23 | """if os.path.isdir(self.weight_folder): 24 | shutil.rmtree(self.weight_folder, ignore_errors=True) 25 | os.mkdir(self.weight_folder) 26 | if os.path.isdir(self.weight_fig_folder): 27 | shutil.rmtree(self.weight_fig_folder, ignore_errors=True) 28 | os.mkdir(self.weight_fig_folder)""" 29 | 30 | print ("\n======> Result will be saved at: ", self.save_path) 31 | 32 | def copy_code(self, src="./", dst="./code/", ignore_folder='Exps'): 33 | import uuid 34 | if os.path.isdir(dst): 35 | # dst = "/".join(dst.split("/")[:-1])+"_"+str(uuid.uuid4())+"/" 36 | dst = "/".join(dst.split("/")) + "code_" + str(uuid.uuid4()) + "/" 37 | file_abs_list = [] 38 | src_abs = os.path.abspath(src) 39 | for root, dirs, files in os.walk(src_abs): 40 | if ignore_folder not in root: 41 | for name in files: 42 | file_abs_list.append(root + "/" + name) 43 | 44 | for file_abs in file_abs_list: 45 | file_split = file_abs.split("/")[-1].split('.') 46 | # if len(file_split) >= 2 and file_split[1] == "py": 47 | if os.path.getsize(file_abs)/1024/1024 < 10 and not file_split[-1] == "pyc": 48 | src_file = file_abs 49 | dst_file = dst + file_abs.replace(src_abs, "") 50 | if not os.path.exists(os.path.dirname(dst_file)): 51 | os.makedirs(os.path.dirname(dst_file)) 52 | shutil.copyfile(src=src_file, dst=dst_file) 53 | try: 54 | shutil.copyfile(src=src_file, dst=dst_file) 55 | except: 56 | print("copy file error") 57 | 58 | def writeopt(self, opt): 59 | with open(self.opt_file, "w") as f: 60 | for k, v in opt.__dict__.items(): 61 | f.write(str(k)+": "+str(v)+"\n") 62 | 63 | def writelog(self, input_data): 64 | txt_file = open(self.log_file, 'a+') 65 | txt_file.write(str(input_data) + "\n") 66 | txt_file.close() 67 | 68 | def writereadme(self, input_data): 69 | txt_file = open(self.readme, 'a+') 70 | txt_file.write(str(input_data) + "\n") 71 | txt_file.close() 72 | 73 | 74 | def gennetwork(self, var): 75 | self.graph.draw(var=var) 76 | 77 | def savenetwork(self): 78 | self.graph.save(file_name=self.save_path+"network.svg") 79 | 80 | """def writeweights(self, input_data, block_id, layer_id, epoch_id): 81 | txt_path = self.weight_folder + "conv_weight_" + str(epoch_id) + ".log" 82 | txt_file = open(txt_path, 'a+') 83 | write_str = "%d\t%d\t%d\t" % (epoch_id, block_id, layer_id) 84 | for x in input_data: 85 | write_str += str(x) + "\t" 86 | txt_file.write(write_str+"\n") 87 | 88 | def drawhist(self): 89 | drawer = DrawHistogram(txt_folder=self.weight_folder, fig_folder=self.weight_fig_folder) 90 | drawer.draw()""" 91 | 92 | -------------------------------------------------------------------------------- /code/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * -------------------------------------------------------------------------------- /code/src/utils/container.py: -------------------------------------------------------------------------------- 1 | from gensim.models import KeyedVectors 2 | 3 | 4 | class ModelContainer(object): 5 | def __init__(self): 6 | super(ModelContainer, self).__init__() 7 | self.model = {} 8 | self.data = {} 9 | 10 | def save_model(self, model, name, data, criterion, greater_better=True): 11 | assert criterion in data, "Incompatible criterion name!" 
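        # Cache the model the first time `name` is seen; afterwards replace it only when `criterion`
        # moves in the preferred direction (non-decreasing if greater_better=True, strictly decreasing otherwise).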
12 | if (name not in self.model) or ((self.data[name][criterion] <= data[criterion]) == greater_better): 13 | self.model.update({name: model}) 14 | self.data.update({name: data}) 15 | return self.data[name] 16 | 17 | def fetch_model(self, name): 18 | assert name in self.model, "Invalid model name!" 19 | return self.model[name], self.data[name] 20 | 21 | 22 | def merge_dicts(a, b): 23 | for key, value in b.items(): 24 | a[key] = value + a[key] if key in a else value 25 | return a 26 | 27 | 28 | class ValueContainer(object): 29 | def __init__(self): 30 | super().__init__() 31 | self.data = {} 32 | 33 | def reset(self, model_name): 34 | self.data[model_name] = {"count": 0} 35 | 36 | def update(self, model_name, metrics, step=1): 37 | if model_name not in self.data: 38 | self.reset(model_name) 39 | if not isinstance(metrics, dict): 40 | metrics = {"__default__": metrics} 41 | self.data[model_name] = merge_dicts(self.data[model_name], metrics) 42 | self.data[model_name]["count"] += step 43 | 44 | def calculate_average(self, model_name, reset=True): 45 | result = {} 46 | cum_result = self.data[model_name] 47 | for key in cum_result: 48 | if key != "count": 49 | result[key] = cum_result[key] / cum_result["count"] 50 | if reset: 51 | self.reset(model_name) 52 | if "__default__" in result: 53 | return result["__default__"] 54 | return result 55 | 56 | 57 | class ResourceContainer: 58 | def __init__(self): 59 | super().__init__() 60 | self.vocab = None 61 | self.resource = {} 62 | 63 | def save_resource(self, name, resource): 64 | self.resource[name] = resource 65 | 66 | def fetch_resource(self, name): 67 | if name in self.resource: 68 | return self.resource[name] 69 | return None 70 | 71 | def fetch_vocab(self, vocab_path): 72 | if self.vocab is None: 73 | self.vocab = KeyedVectors.load_word2vec_format(vocab_path, binary=True) 74 | return self.vocab 75 | 76 | 77 | 78 | modelContainer = ModelContainer() 79 | metricsContainer = ValueContainer() 80 | resourceContainer = ResourceContainer() 81 | -------------------------------------------------------------------------------- /code/src/utils/draw.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Draw_Heatmap(nn.Module): 7 | def __init__(self): 8 | super(Draw_Heatmap, self).__init__() 9 | 10 | def forward(self, m1, m2, m3, video_id, modality,epoch): 11 | matrix1 = m1.cpu().numpy() 12 | matrix2 = m2.cpu().numpy() 13 | matrix3 = m3.cpu().numpy() 14 | 15 | matrices = [matrix1, matrix2, matrix3] 16 | 17 | fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(15, 13)) 18 | 19 | for i, matrix in enumerate(matrices): 20 | im = axs[i].imshow(matrix, cmap='viridis_r', interpolation='nearest') 21 | for j in range(matrix.shape[0]): 22 | for k in range(matrix.shape[1]): 23 | text = axs[i].text(k, j, np.round(float(matrix[j, k]), 2), ha="center", va="center", color="w",fontsize=8,weight=3) 24 | 25 | axs[0].set_title('pred', y=-0.16) 26 | axs[1].set_title('label', y=-0.16) 27 | axs[2].set_title('result', y=-0.16) 28 | plt.text(0,-25,s=video_id,fontsize=20,weight=6,color="k") 29 | 30 | fig.colorbar(im, ax=axs) 31 | 32 | if(modality=="va"): 33 | plt.savefig('../../heatmap/va/epoch_' f'{epoch}_' f'{video_id}''.png',dpi=200,bbox_inches = 'tight') 34 | elif(modality=="av"): 35 | plt.savefig('../../heatmap/av/epoch_' f'{epoch}_' f'{video_id}''.png',dpi=200,bbox_inches = 'tight') 36 | 37 | 
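`Draw_Heatmap` above renders three equally shaped 2-D tensors (prediction, label, and a comparison of the two) as annotated heatmaps in one figure and saves it to `../../heatmap/<modality>/epoch_<epoch>_<video_id>.png`. A minimal usage sketch follows; the tensor shapes, the way `result` is derived, and the import path are illustrative assumptions, and the output directory must already exist.

```python
import torch
from utils.draw import Draw_Heatmap  # assumes code/src is the working directory

drawer = Draw_Heatmap()
pred = torch.rand(10, 10)                      # e.g. per-segment prediction scores
label = torch.randint(0, 2, (10, 10)).float()  # binary event labels
result = (pred.round() == label).float()       # illustrative "result" panel
# Saves ../../heatmap/va/epoch_0_sample_0001.png (the directory must exist beforehand).
drawer(pred, label, result, 'sample_0001', 'va', 0)
```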
-------------------------------------------------------------------------------- /code/src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from ruamel import yaml 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def Prepare_logger(args, eval=False): 25 | logger = logging.getLogger(__name__) 26 | logger.propagate = False 27 | logger.setLevel(logging.INFO) 28 | handler = logging.StreamHandler() 29 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 30 | handler.setFormatter(formatter) 31 | handler.setLevel(0) 32 | logger.addHandler(handler) 33 | 34 | date = time.strftime('%Y%m%d%H%M', time.localtime(time.time())) 35 | logfile = args.snapshot_pref+date+'.log' if not eval else args.snapshot_pref + f'/{date}-Eval.log' 36 | file_handler = logging.FileHandler(logfile, mode='w') 37 | file_handler.setLevel(logging.INFO) 38 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 39 | file_handler.setFormatter(formatter) 40 | logger.addHandler(file_handler) 41 | 42 | return logger 43 | 44 | 45 | def get_configs(dataset): 46 | data = yaml.load(open('./configs/dataset_cfg.yaml')) 47 | return data[dataset] 48 | 49 | def get_and_save_args(parser): 50 | args = parser.parse_args() 51 | # dataset = args.dataset 52 | 53 | default_config = yaml.load(open('./configs/default_config.yaml', 'r'), Loader=yaml.RoundTripLoader) 54 | current_config = vars(args) 55 | for k, v in current_config.items(): 56 | if k in default_config: 57 | if (v != default_config[k]) and (v is not None): 58 | print(f"Updating: {k}: {default_config[k]} (default) ----> {v}") 59 | default_config[k] = v 60 | yaml.dump(default_config, open('./current_configs.yaml', 'w'), indent=4, Dumper=yaml.RoundTripDumper) 61 | return default_config -------------------------------------------------------------------------------- /figs/MM_EMA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/figs/MM_EMA.pdf -------------------------------------------------------------------------------- /figs/illustration.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/figs/illustration.pdf -------------------------------------------------------------------------------- /figs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihuangcode/CMG/fc12eab63aaf818271ac56fe4059d28824f4f92f/figs/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | antlr4-python3-runtime==4.8 3 | anyio==3.3.1 4 | argon2-cffi==21.1.0 5 | attrs==21.2.0 6 | Babel==2.9.1 7 | backcall==0.2.0 8 | bert-embedding==1.0.1 9 | bitarray==2.6.0 10 | bleach==4.1.0 11 | blis==0.7.8 12 | brotlipy==0.7.0 13 | cachetools==4.2.2 14 | catalogue==2.0.8 15 | 
certifi==2021.5.30 16 | click==8.1.3 17 | colorama==0.4.5 18 | cycler==0.10.0 19 | cymem==2.0.6 20 | Cython==0.29.32 21 | dataclasses==0.6 22 | debugpy==1.4.3 23 | decorator==5.1.0 24 | defusedxml==0.7.1 25 | entrypoints==0.3 26 | ftfy==6.1.1 27 | gensim==4.0.0 28 | gluonnlp==0.6.0 29 | google-auth==1.35.0 30 | google-auth-oauthlib==0.4.6 31 | graphviz==0.8.4 32 | grpcio==1.40.0 33 | h5py==3.4.0 34 | hydra-core==1.0.7 35 | imageio==2.21.1 36 | importlib-resources==5.9.0 37 | info-nce-pytorch==0.1.4 38 | ipykernel==6.4.1 39 | ipython==7.27.0 40 | ipython-genutils==0.2.0 41 | ipywidgets==7.6.5 42 | jedi==0.18.0 43 | Jinja2==3.0.1 44 | joblib==1.1.0 45 | json5==0.9.6 46 | jsonschema==3.2.0 47 | jupyter-client==7.0.3 48 | jupyter-core==4.8.1 49 | jupyter-server==1.11.0 50 | jupyterlab==3.1.12 51 | jupyterlab-pygments==0.1.2 52 | jupyterlab-server==2.8.1 53 | jupyterlab-widgets==1.0.2 54 | kiwisolver==1.3.2 55 | langcodes==3.2.1 56 | lxml==4.8.0 57 | Markdown==3.3.4 58 | MarkupSafe==2.0.1 59 | matplotlib==3.4.3 60 | matplotlib-inline==0.1.3 61 | mistune==0.8.4 62 | murmurhash==1.0.7 63 | nbclassic==0.3.2 64 | nbclient==0.5.4 65 | nbconvert==6.1.0 66 | nbformat==5.1.3 67 | nest-asyncio==1.5.1 68 | networkx==2.6.3 69 | nltk==3.7 70 | notebook==6.4.4 71 | oauthlib==3.1.1 72 | omegaconf==2.0.6 73 | opencv-python==4.6.0.66 74 | packaging==21.0 75 | pandas==1.3.5 76 | pandocfilters==1.5.0 77 | parso==0.8.2 78 | pathy==0.6.2 79 | pexpect==4.8.0 80 | pickleshare==0.7.5 81 | Pillow==8.3.2 82 | portalocker==2.5.1 83 | preshed==3.0.6 84 | prometheus-client==0.11.0 85 | prompt-toolkit==3.0.20 86 | protobuf==4.24.4 87 | ptyprocess==0.7.0 88 | pyasn1==0.4.8 89 | pyasn1-modules==0.2.8 90 | pycosat==0.6.3 91 | pydantic==1.9.2 92 | Pygments==2.10.0 93 | pyparsing==2.4.7 94 | pyrsistent==0.18.0 95 | python-dateutil==2.8.2 96 | pytz==2021.1 97 | PyWavelets==1.3.0 98 | PyYAML==6.0 99 | pyzmq==22.3.0 100 | regex==2022.7.25 101 | requests-oauthlib==1.3.0 102 | requests-unixsocket==0.2.0 103 | rsa==4.7.2 104 | ruamel.yaml==0.17.35 105 | ruamel.yaml.clib==0.2.8 106 | sacrebleu==2.2.0 107 | scikit-image==0.16.2 108 | scipy==1.7.3 109 | Send2Trash==1.8.0 110 | smart-open==5.2.1 111 | sniffio==1.2.0 112 | spacy==3.4.1 113 | spacy-legacy==3.0.9 114 | spacy-loggers==1.0.3 115 | srsly==2.4.4 116 | supervisor==4.2.2 117 | tabulate==0.8.10 118 | tensorboard==2.6.0 119 | tensorboard-data-server==0.6.1 120 | tensorboard-plugin-wit==1.8.0 121 | tensorboardX==2.6.2.2 122 | terminado==0.12.1 123 | testpath==0.5.0 124 | thinc==8.1.0 125 | torch==1.13.0 126 | torch-scatter==2.0.9 127 | torchaudio==0.9.0 128 | tornado==6.1 129 | traitlets==5.1.0 130 | typer==0.4.2 131 | typing==3.6.6 132 | typing-extensions==3.10.0.2 133 | tzdata==2023.3 134 | wasabi==0.10.1 135 | wcwidth==0.2.5 136 | webencodings==0.5.1 137 | websocket-client==1.2.1 138 | Werkzeug==2.0.1 139 | widgetsnbextension==3.5.1 140 | zipp 141 | torchvision==0.14.0 142 | --------------------------------------------------------------------------------