├── AVSBench ├── AVSBench_Dataset.py ├── README.md ├── eval_utils.py └── metadata │ ├── avs1_ms3_test.csv │ └── avs1_s4_test.csv ├── Distributed_Experiment.sh ├── Eval.py ├── Flickr ├── Flickr_Dataset.py ├── README.md ├── eval_utils.py ├── extend_eval_utils.py └── metadata │ ├── flickr_10k.csv │ ├── flickr_144k.csv │ ├── flickr_test.csv │ └── flickr_test_plus_silent.csv ├── README.md ├── SingleGPU_Experiment.sh ├── Test_PTModels.py ├── Test_PTModels.sh ├── Train_ACL.py ├── VGGSS ├── README.md ├── VGGSS_Dataset.py ├── eval_utils.py ├── extend_eval_utils.py └── metadata │ ├── vggss.json │ ├── vggss_10k.csv │ ├── vggss_144k.csv │ ├── vggss_heard.csv │ ├── vggss_heard_test.csv │ ├── vggss_test.csv │ ├── vggss_test_plus_silent.csv │ └── vggss_unheard_test.csv ├── asset └── summary_wacv.png ├── config ├── model │ └── ACL_ViT16.yaml └── train │ └── Exp_ACL_v1.yaml ├── loss_utils.py ├── modules ├── AudioToken │ ├── AudioToken.py │ └── embedder.py ├── BEATs │ ├── BEATs.py │ ├── Tokenizers.py │ ├── backbone.py │ ├── modules.py │ └── quantizer.py ├── CLIPSeg │ └── clipseg_for_audio.py ├── FGA │ ├── atten.py │ └── fga_model.py ├── arg_utils.py ├── mask_utils.py └── models.py ├── pretrain └── README.md ├── util.py └── viz_utils.py /AVSBench/AVSBench_Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchaudio 4 | from torchvision import transforms as vt 5 | from PIL import Image 6 | import os 7 | import csv 8 | from typing import Dict, Optional, Union, List 9 | 10 | 11 | class AVSBenchDataset(Dataset): 12 | def __init__(self, data_path: str, split: str, is_train: bool = True, set_length: int = 10, 13 | input_resolution: int = 224) -> None: 14 | """ 15 | Initialize AVSBench Dataset. 16 | 17 | Args: 18 | data_path (str): Path to the dataset. 19 | split (str): Dataset split (Use csv file name in metadata directory). 20 | is_train (bool, optional): Whether it's a training set. Default is True. 21 | set_length (int, optional): Duration of input audio. Default is 10. 22 | input_resolution (int, optional): Resolution of input images. Default is 224. 
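Example (illustrative, not part of the original docstring): split='avs1_s4_test' points csv_dir at AVSBench/metadata/avs1_s4_test.csv and resolves the setting as split.split('_')[1] == 's4', so files are read from <data_path>/s4/audio_wav, <data_path>/s4/visual_frames and, at test time, <data_path>/s4/gt_masks.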
23 | """ 24 | super(AVSBenchDataset, self).__init__() 25 | 26 | self.SAMPLE_RATE = 16000 27 | self.split = split 28 | self.set_length = set_length 29 | self.csv_dir = 'AVSBench/metadata/' + split + '.csv' 30 | self.setting = split.split('_')[1] 31 | 32 | ''' Audio files ''' 33 | self.audio_path = os.path.join(data_path, self.setting, 'audio_wav') 34 | audio_files = set([fn.split('.wav')[0] for fn in os.listdir(self.audio_path) if fn.endswith('.wav')]) 35 | 36 | ''' Image files ''' 37 | self.image_path = os.path.join(data_path, self.setting, 'visual_frames') 38 | image_files = set([fn.split('.png')[0] for fn in os.listdir(self.image_path) if fn.endswith('.png')]) 39 | 40 | ''' Ground truth (Bounding box) ''' 41 | if is_train: 42 | self.gt_path = None 43 | else: 44 | self.gt_path = os.path.join(data_path, self.setting, 'gt_masks') 45 | 46 | ''' Ground truth (Text label) ''' 47 | self.label_dict = {item[0]: item[1] for item in csv.reader(open(self.csv_dir))} 48 | 49 | ''' Available files''' 50 | subset = set([item[0] for item in csv.reader(open(self.csv_dir))]) 51 | self.file_list = sorted(list(image_files.intersection(subset))) 52 | 53 | ''' Transform ''' 54 | if is_train: 55 | self.image_transform = vt.Compose([ 56 | vt.Resize((int(input_resolution * 1.1), int(input_resolution * 1.1)), vt.InterpolationMode.BICUBIC), 57 | vt.ToTensor(), 58 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 59 | vt.RandomCrop((input_resolution, input_resolution)), 60 | vt.RandomHorizontalFlip(), 61 | ]) 62 | else: 63 | self.image_transform = vt.Compose([ 64 | vt.Resize((input_resolution, input_resolution), vt.InterpolationMode.BICUBIC), 65 | vt.ToTensor(), 66 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 67 | ]) 68 | 69 | def __len__(self): 70 | """ 71 | Return the number of items in the dataset. 72 | """ 73 | return len(self.file_list) 74 | 75 | def get_audio(self, item: int) -> torch.Tensor: 76 | """ 77 | Get audio data for a given item. 78 | 79 | Args: 80 | item (int): Index of the item. 81 | 82 | Returns: 83 | torch.Tensor: Audio data. 84 | """ 85 | audio_file, _ = torchaudio.load(os.path.join(self.audio_path, self.file_list[item][:-2] + '.wav')) 86 | audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0) # Stereo 5 sec -> 10 sec 87 | audio_file = audio_file.squeeze(0) 88 | 89 | # slicing or padding based on set_length 90 | # slicing 91 | if audio_file.shape[0] > (self.SAMPLE_RATE * self.set_length): 92 | audio_file = audio_file[:self.SAMPLE_RATE * self.set_length] 93 | # zero padding 94 | if audio_file.shape[0] < (self.SAMPLE_RATE * self.set_length): 95 | pad_len = (self.SAMPLE_RATE * self.set_length) - audio_file.shape[0] 96 | pad_val = torch.zeros(pad_len) 97 | audio_file = torch.cat((audio_file, pad_val), dim=0) 98 | 99 | return audio_file 100 | 101 | def get_image(self, item: int) -> Image.Image: 102 | """ 103 | Get image data for a given item. 104 | 105 | Args: 106 | item (int): Index of the item. 107 | 108 | Returns: 109 | Image.Image: Image data. 110 | """ 111 | image_file = Image.open(os.path.join(self.image_path, self.file_list[item] + '.png')) 112 | return image_file 113 | 114 | def get_gt(self, item: int) -> Optional[torch.Tensor]: 115 | """ 116 | Get ground truth data for a given item. 117 | 118 | Args: 119 | item (int): Index of the item. 120 | 121 | Returns: 122 | Optional[torch.Tensor]: Ground truth data. 
123 | """ 124 | # Ground truth 125 | if self.gt_path is None: 126 | return None 127 | else: 128 | gt = vt.ToTensor()( 129 | Image.open(os.path.join(self.gt_path, self.file_list[item] + '.png')).convert('1')).float() 130 | return gt 131 | 132 | def __getitem__(self, item: int) -> Dict[str, Union[torch.Tensor, torch.Tensor, Optional[torch.Tensor], str, str]]: 133 | """ 134 | Get item from the dataset. 135 | 136 | Args: 137 | item (int): Index of the item. 138 | 139 | Returns: 140 | Dict[str, Union[torch.Tensor, torch.Tensor, Optinal[torch.Tensor], str, str]]: Data example 141 | """ 142 | file_id = self.file_list[item] 143 | 144 | ''' Load data ''' 145 | audio_file = self.get_audio(item) 146 | image_file = self.get_image(item) 147 | label = self.label_dict[self.file_list[item]].replace('_', ' ') 148 | gts = self.get_gt(item) 149 | 150 | ''' Transform ''' 151 | audio = audio_file 152 | image = self.image_transform(image_file) 153 | 154 | out = {'images': image, 'audios': audio, 'gts': gts, 'labels': label, 'ids': file_id} 155 | out = {key: value for key, value in out.items() if value is not None} 156 | return out 157 | -------------------------------------------------------------------------------- /AVSBench/README.md: -------------------------------------------------------------------------------- 1 | # Directory guide for AVSBench 2 | ```commandline 3 | ├── AVS1/ 4 | | ├── s4 5 | | | └── audio_wav 6 | | | | └── ... 7 | | | | └── *** .wav 8 | | | └── gt_masks 9 | | | | └── ... 10 | | | | └── *** png 11 | | | └── visual_frames 12 | | | | └── ... 13 | | | | └── *** .png 14 | ``` 15 | All .wav files sampled 16k 16 | 17 | ## Important 18 | Fix bug in official test code (Issue: F-Score results vary depending on the batch number) 19 | 20 | Considering the notable impact of this issue on the performance of self-supervised learning models, we suggest utilizing our updated test code. 21 | 22 | We already discussed this issue with the author who released the official code. 23 | -------------------------------------------------------------------------------- /AVSBench/eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Tuple, Dict 4 | 5 | 6 | class Evaluator(object): 7 | def __init__(self) -> None: 8 | """ 9 | Initialize the AVSBench evaluator. 10 | 11 | Attributes: 12 | miou (List[float]): Buffer of mIoU values. 13 | F (List[float]): Buffer of F-measure values. 14 | N (int): Counter for the number of evaluations. 15 | metrics (List[str]): List of metric names. 16 | """ 17 | super(Evaluator, self).__init__() 18 | self.miou = [] 19 | self.F = [] 20 | self.N = 0 21 | self.metrics = ['mIoU', 'Fmeasure'] 22 | 23 | def evaluate_batch(self, pred: torch.Tensor, target: torch.Tensor, thr: List[float] = None) -> None: 24 | """ 25 | Evaluate a batch of predictions against ground truth. 26 | 27 | Args: 28 | pred (torch.Tensor): Model predictions. 29 | target (torch.Tensor): Ground truth. 30 | thr (List[float], optional): List of thresholds. If None, calculate threshold as median. Default is None. 
31 | 32 | Notes: 33 | Updates metric buffers (self.mask_iou, self.Eval_Fmeasusre) 34 | """ 35 | thrs = [] 36 | 37 | for j in range(pred.size(0)): 38 | infer = pred[j] 39 | if thr is None: 40 | thrs.append(np.sort(infer.detach().cpu().numpy().flatten())[int(infer.shape[1] * infer.shape[2] / 2)]) 41 | else: 42 | thrs.append(thr) 43 | 44 | infers, gts = pred.squeeze(1), target.squeeze(1) 45 | self.mask_iou(infers, gts, thrs) 46 | self.Eval_Fmeasure(infers, gts) 47 | 48 | def mask_iou(self, preds: torch.Tensor, targets: torch.Tensor, thrs: List[float], eps: float = 1e-7) -> float: 49 | """ 50 | Calculate mask IoU. 51 | 52 | Args: 53 | preds (torch.Tensor): Model predictions. 54 | targets (torch.Tensor): Ground truth. 55 | thrs (List[float]): List of thresholds. 56 | eps (float, optional): Small epsilon to avoid division by zero. Default is 1e-7. 57 | 58 | Returns: 59 | float: mIoU value. 60 | """ 61 | assert len(preds.shape) == 3 and preds.shape == targets.shape 62 | self.N += 1 63 | 64 | N = preds.size(0) 65 | miou = 0.0 66 | for i in range(N): 67 | pred = preds[i].unsqueeze(0) 68 | target = targets[i].unsqueeze(0) 69 | 70 | num_pixels = pred.size(-1) * pred.size(-2) 71 | no_obj_flag = (target.sum(2).sum(1) == 0) 72 | 73 | pred = (pred > thrs[i]).int() 74 | inter = (pred * target).sum(2).sum(1) 75 | union = torch.max(pred, target).sum(2).sum(1) 76 | 77 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) 78 | inter[no_obj_flag] = inter_no_obj[no_obj_flag] 79 | union[no_obj_flag] = num_pixels 80 | miou += (torch.sum(inter / (union + eps))).squeeze() 81 | miou = miou / N 82 | self.miou.append(miou.detach().cpu()) 83 | 84 | return miou 85 | 86 | @staticmethod 87 | def _eval_pr(y_pred: torch.Tensor, y: torch.Tensor, num: int, cuda_flag: bool = True) \ 88 | -> Tuple[torch.Tensor, torch.Tensor]: 89 | """ 90 | Calculate precision and recall. 91 | 92 | Args: 93 | y_pred (torch.Tensor): Model predictions. 94 | y (torch.Tensor): Ground truth. 95 | num (int): Number of threshold values. 96 | cuda_flag (bool, optional): Whether to use CUDA. Default is True. 97 | 98 | Returns: 99 | Tuple[torch.Tensor, torch.Tensor]: Precision and recall values. 100 | """ 101 | if cuda_flag: 102 | prec, recall = torch.zeros(num).to(y_pred.device), torch.zeros(num).to(y_pred.device) 103 | thlist = torch.linspace(0, 1 - 1e-10, num).to(y_pred.device) 104 | else: 105 | prec, recall = torch.zeros(num), torch.zeros(num) 106 | thlist = torch.linspace(0, 1 - 1e-10, num) 107 | for i in range(num): 108 | y_temp = (y_pred >= thlist[i]).float() 109 | tp = (y_temp * y).sum() 110 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 111 | 112 | return prec, recall 113 | 114 | def Eval_Fmeasure(self, pred: torch.Tensor, gt: torch.Tensor, pr_num: int = 255) -> float: 115 | """ 116 | Evaluate F-measure. 117 | 118 | Args: 119 | pred (torch.Tensor): Model predictions. 120 | gt (torch.Tensor): Ground truth. 121 | pr_num (int, optional): Number of precision-recall values. Default is 255. 122 | 123 | Returns: 124 | float: F-measure value. 125 | 126 | Notes: 127 | Fix bug in official test code (Issue: Results vary depending on the batch number) 128 | The official code had an issue because it optimized precision-recall thresholds for each mini-batch. 
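Here the F-score curves over all pr_num thresholds are only buffered in self.F; finalize_Fmeasure() averages the buffered curves and takes the max over thresholds once, after the whole test set has been accumulated, so the result no longer depends on how the data is batched.
Illustrative usage (a minimal sketch; `loader` is a hypothetical iterable of (pred, gt) pairs shaped (B, 1, H, W)):
    >>> evaluator = Evaluator()
    >>> for pred, gt in loader:
    ...     evaluator.evaluate_batch(pred, gt)
    >>> names, scores = evaluator.finalize()  # e.g. (['mIoU', 'Fmeasure'], {'mIoU': ..., 'Fmeasure': ...})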
129 | """ 130 | N = pred.size(0) 131 | beta2 = 0.3 132 | avg_f, img_num = 0.0, 0 133 | score = torch.zeros(pr_num).to(pred.device) 134 | 135 | for img_id in range(N): 136 | # examples with totally black GTs are out of consideration 137 | if torch.sum(gt[img_id]) == 0.0: 138 | continue 139 | prec, recall = self._eval_pr(pred[img_id], gt[img_id], pr_num) 140 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) 141 | f_score[f_score != f_score] = 0 # for Nan 142 | avg_f += f_score 143 | img_num += 1 144 | score = avg_f / img_num 145 | self.F.append(f_score.detach().cpu().numpy()) 146 | # print('score: ', score) 147 | 148 | return score.max().item() 149 | 150 | def finalize_mIoU(self) -> float: 151 | """ 152 | Calculate the final mIoU value. 153 | 154 | Returns: 155 | float: Final mIoU value. 156 | """ 157 | miou = np.sum(np.array(self.miou)) / self.N 158 | return miou 159 | 160 | def finalize_Fmeasure(self) -> float: 161 | """ 162 | Calculate the final F-measure value. 163 | 164 | Returns: 165 | float: Final F-measure value. 166 | 167 | Notes: 168 | Fix bug in official test code (Issue: Results vary depending on the batch number) 169 | The official code had an issue because it optimized precision-recall thresholds for each mini-batch 170 | """ 171 | # F = np.sum(np.array(self.F)) / self.N 172 | F = np.max(np.mean(self.F, axis=0)) 173 | 174 | return F 175 | 176 | def finalize(self) -> Tuple[List[str], Dict[str, float]]: 177 | """ 178 | Finalize evaluation and return the results. 179 | 180 | Returns: 181 | Tuple[List[str], Dict[str, float]]: Tuple containing metric names and corresponding values. 182 | """ 183 | mIoU = self.finalize_mIoU() * 100 184 | F = self.finalize_Fmeasure() * 100 185 | return self.metrics, {self.metrics[0]: mIoU, self.metrics[1]: F} 186 | -------------------------------------------------------------------------------- /AVSBench/metadata/avs1_ms3_test.csv: -------------------------------------------------------------------------------- 1 | 0WxgIKuetYI_0_1,man 2 | 0WxgIKuetYI_0_2,man 3 | 0WxgIKuetYI_0_3,man 4 | 0WxgIKuetYI_0_4,man 5 | 0WxgIKuetYI_0_5,man 6 | 0bzkGQLy7b4_2_1,woman 7 | 0bzkGQLy7b4_2_2,woman 8 | 0bzkGQLy7b4_2_3,woman 9 | 0bzkGQLy7b4_2_4,woman 10 | 0bzkGQLy7b4_2_5,woman 11 | 0mqckt2Uca8_0_1,piano 12 | 0mqckt2Uca8_0_2,piano 13 | 0mqckt2Uca8_0_3,piano 14 | 0mqckt2Uca8_0_4,piano 15 | 0mqckt2Uca8_0_5,piano 16 | 0mqckt2Uca8_2_1,piano 17 | 0mqckt2Uca8_2_2,piano 18 | 0mqckt2Uca8_2_3,piano 19 | 0mqckt2Uca8_2_4,piano 20 | 0mqckt2Uca8_2_5,piano 21 | 26xaTuDMCIQ_1_1,"guitar, violin" 22 | 26xaTuDMCIQ_1_2,"guitar, violin" 23 | 26xaTuDMCIQ_1_3,"guitar, violin" 24 | 26xaTuDMCIQ_1_4,"guitar, violin" 25 | 26xaTuDMCIQ_1_5,"guitar, violin" 26 | 28XJmUz906g_0_1,"dog, mower" 27 | 28XJmUz906g_0_2,"dog, mower" 28 | 28XJmUz906g_0_3,"dog, mower" 29 | 28XJmUz906g_0_4,"dog, mower" 30 | 28XJmUz906g_0_5,"dog, mower" 31 | 3vheOi4y540_1_1,man 32 | 3vheOi4y540_1_2,man 33 | 3vheOi4y540_1_3,man 34 | 3vheOi4y540_1_4,man 35 | 3vheOi4y540_1_5,man 36 | 9xp46AwF9BY_3_1,"violin, piano" 37 | 9xp46AwF9BY_3_2,"violin, piano" 38 | 9xp46AwF9BY_3_3,"violin, piano" 39 | 9xp46AwF9BY_3_4,"violin, piano" 40 | 9xp46AwF9BY_3_5,"violin, piano" 41 | BZyTr2Pku2A_2_1,"woman, piano, guitar" 42 | BZyTr2Pku2A_2_2,"woman, piano, guitar" 43 | BZyTr2Pku2A_2_3,"woman, piano, guitar" 44 | BZyTr2Pku2A_2_4,"woman, piano, guitar" 45 | BZyTr2Pku2A_2_5,"woman, piano, guitar" 46 | BvtaXGFoeRA_0_1,"guitar, woman, tabla" 47 | BvtaXGFoeRA_0_2,"guitar, woman, tabla" 48 | BvtaXGFoeRA_0_3,"guitar, woman, tabla" 49 | 
BvtaXGFoeRA_0_4,"guitar, woman, tabla" 50 | BvtaXGFoeRA_0_5,"guitar, woman, tabla" 51 | CNbSq__q-BY_1,lion 52 | CNbSq__q-BY_2,lion 53 | CNbSq__q-BY_3,lion 54 | CNbSq__q-BY_4,lion 55 | CNbSq__q-BY_5,lion 56 | DGC9d3n7Tpw_0_1,piano 57 | DGC9d3n7Tpw_0_2,piano 58 | DGC9d3n7Tpw_0_3,piano 59 | DGC9d3n7Tpw_0_4,piano 60 | DGC9d3n7Tpw_0_5,piano 61 | ELZzyhWMdqA_0_1,piano 62 | ELZzyhWMdqA_0_2,piano 63 | ELZzyhWMdqA_0_3,piano 64 | ELZzyhWMdqA_0_4,piano 65 | ELZzyhWMdqA_0_5,piano 66 | ELZzyhWMdqA_1_1,guitar 67 | ELZzyhWMdqA_1_2,guitar 68 | ELZzyhWMdqA_1_3,guitar 69 | ELZzyhWMdqA_1_4,guitar 70 | ELZzyhWMdqA_1_5,guitar 71 | EnpqqJXs0aY_2_1,"baby, woman" 72 | EnpqqJXs0aY_2_2,"baby, woman" 73 | EnpqqJXs0aY_2_3,"baby, woman" 74 | EnpqqJXs0aY_2_4,"baby, woman" 75 | EnpqqJXs0aY_2_5,"baby, woman" 76 | FgNwZuMFQwY_0_1,piano 77 | FgNwZuMFQwY_0_2,piano 78 | FgNwZuMFQwY_0_3,piano 79 | FgNwZuMFQwY_0_4,piano 80 | FgNwZuMFQwY_0_5,piano 81 | FgNwZuMFQwY_1_1,"marimba, piano" 82 | FgNwZuMFQwY_1_2,"marimba, piano" 83 | FgNwZuMFQwY_1_3,"marimba, piano" 84 | FgNwZuMFQwY_1_4,"marimba, piano" 85 | FgNwZuMFQwY_1_5,"marimba, piano" 86 | FhF2q_P-vAA_0_1,"tabla, guitar" 87 | FhF2q_P-vAA_0_2,"tabla, guitar" 88 | FhF2q_P-vAA_0_3,"tabla, guitar" 89 | FhF2q_P-vAA_0_4,"tabla, guitar" 90 | FhF2q_P-vAA_0_5,"tabla, guitar" 91 | GJqSBf6p8qg_0_1,"violin, guitar" 92 | GJqSBf6p8qg_0_2,"violin, guitar" 93 | GJqSBf6p8qg_0_3,"violin, guitar" 94 | GJqSBf6p8qg_0_4,"violin, guitar" 95 | GJqSBf6p8qg_0_5,"violin, guitar" 96 | GsfG9ZC8rUU_0_1,piano 97 | GsfG9ZC8rUU_0_2,piano 98 | GsfG9ZC8rUU_0_3,piano 99 | GsfG9ZC8rUU_0_4,piano 100 | GsfG9ZC8rUU_0_5,piano 101 | I8K17mzV-QU_0_1,keyboard 102 | I8K17mzV-QU_0_2,keyboard 103 | I8K17mzV-QU_0_3,keyboard 104 | I8K17mzV-QU_0_4,keyboard 105 | I8K17mzV-QU_0_5,keyboard 106 | Kd42SbTQv2U_11_1,"man, dog" 107 | Kd42SbTQv2U_11_2,"man, dog" 108 | Kd42SbTQv2U_11_3,"man, dog" 109 | Kd42SbTQv2U_11_4,"man, dog" 110 | Kd42SbTQv2U_11_5,"man, dog" 111 | OPXHT-rt5BQ_1_1,"woman, ukulele" 112 | OPXHT-rt5BQ_1_2,"woman, ukulele" 113 | OPXHT-rt5BQ_1_3,"woman, ukulele" 114 | OPXHT-rt5BQ_1_4,"woman, ukulele" 115 | OPXHT-rt5BQ_1_5,"woman, ukulele" 116 | OfcjlybYwFs_2_1,"man, woman" 117 | OfcjlybYwFs_2_2,"man, woman" 118 | OfcjlybYwFs_2_3,"man, woman" 119 | OfcjlybYwFs_2_4,"man, woman" 120 | OfcjlybYwFs_2_5,"man, woman" 121 | Ox7OpIu5CLI_1_1,baby 122 | Ox7OpIu5CLI_1_2,baby 123 | Ox7OpIu5CLI_1_3,baby 124 | Ox7OpIu5CLI_1_4,baby 125 | Ox7OpIu5CLI_1_5,baby 126 | Ox7OpIu5CLI_6_1,baby 127 | Ox7OpIu5CLI_6_2,baby 128 | Ox7OpIu5CLI_6_3,baby 129 | Ox7OpIu5CLI_6_4,baby 130 | Ox7OpIu5CLI_6_5,baby 131 | Ox7OpIu5CLI_7_1,baby 132 | Ox7OpIu5CLI_7_2,baby 133 | Ox7OpIu5CLI_7_3,baby 134 | Ox7OpIu5CLI_7_4,baby 135 | Ox7OpIu5CLI_7_5,baby 136 | PDhRwjqN83U_0_1,violin 137 | PDhRwjqN83U_0_2,violin 138 | PDhRwjqN83U_0_3,violin 139 | PDhRwjqN83U_0_4,violin 140 | PDhRwjqN83U_0_5,violin 141 | QYvhdbgEJPM_0_1,"guitar, violin" 142 | QYvhdbgEJPM_0_2,"guitar, violin" 143 | QYvhdbgEJPM_0_3,"guitar, violin" 144 | QYvhdbgEJPM_0_4,"guitar, violin" 145 | QYvhdbgEJPM_0_5,"guitar, violin" 146 | SWLG_3suH7w_0_1,"violin, guitar" 147 | SWLG_3suH7w_0_2,"violin, guitar" 148 | SWLG_3suH7w_0_3,"violin, guitar" 149 | SWLG_3suH7w_0_4,"violin, guitar" 150 | SWLG_3suH7w_0_5,"violin, guitar" 151 | TeLISCmZ_w4_4_1,"piano, man" 152 | TeLISCmZ_w4_4_2,"piano, man" 153 | TeLISCmZ_w4_4_3,"piano, man" 154 | TeLISCmZ_w4_4_4,"piano, man" 155 | TeLISCmZ_w4_4_5,"piano, man" 156 | UsTf7brftGg_0_1,"piano, man" 157 | UsTf7brftGg_0_2,"piano, man" 158 | UsTf7brftGg_0_3,"piano, man" 159 | 
UsTf7brftGg_0_4,"piano, man" 160 | UsTf7brftGg_0_5,"piano, man" 161 | UtBQrBLsTQY_3_1,"man, bird" 162 | UtBQrBLsTQY_3_2,"man, bird" 163 | UtBQrBLsTQY_3_3,"man, bird" 164 | UtBQrBLsTQY_3_4,"man, bird" 165 | UtBQrBLsTQY_3_5,"man, bird" 166 | V9JdDs7RK3c_1_1,"woman, piano" 167 | V9JdDs7RK3c_1_2,"woman, piano" 168 | V9JdDs7RK3c_1_3,"woman, piano" 169 | V9JdDs7RK3c_1_4,"woman, piano" 170 | V9JdDs7RK3c_1_5,"woman, piano" 171 | Vwdib3HWRBI_0_1,"man, piano" 172 | Vwdib3HWRBI_0_2,"man, piano" 173 | Vwdib3HWRBI_0_3,"man, piano" 174 | Vwdib3HWRBI_0_4,"man, piano" 175 | Vwdib3HWRBI_0_5,"man, piano" 176 | WjyvwX_nZ6Y_0_1,"man, woman" 177 | WjyvwX_nZ6Y_0_2,"man, woman" 178 | WjyvwX_nZ6Y_0_3,"man, woman" 179 | WjyvwX_nZ6Y_0_4,"man, woman" 180 | WjyvwX_nZ6Y_0_5,"man, woman" 181 | YYbShzoZWRo_16_1,baby 182 | YYbShzoZWRo_16_2,baby 183 | YYbShzoZWRo_16_3,baby 184 | YYbShzoZWRo_16_4,baby 185 | YYbShzoZWRo_16_5,baby 186 | YYbShzoZWRo_25_1,"dog, baby" 187 | YYbShzoZWRo_25_2,"dog, baby" 188 | YYbShzoZWRo_25_3,"dog, baby" 189 | YYbShzoZWRo_25_4,"dog, baby" 190 | YYbShzoZWRo_25_5,"dog, baby" 191 | YYbShzoZWRo_26_1,baby 192 | YYbShzoZWRo_26_2,baby 193 | YYbShzoZWRo_26_3,baby 194 | YYbShzoZWRo_26_4,baby 195 | YYbShzoZWRo_26_5,baby 196 | YYbShzoZWRo_3_1,"baby, dog" 197 | YYbShzoZWRo_3_2,"baby, dog" 198 | YYbShzoZWRo_3_3,"baby, dog" 199 | YYbShzoZWRo_3_4,"baby, dog" 200 | YYbShzoZWRo_3_5,"baby, dog" 201 | ZedzGfGN1tk_1_1,"man, guitar" 202 | ZedzGfGN1tk_1_2,"man, guitar" 203 | ZedzGfGN1tk_1_3,"man, guitar" 204 | ZedzGfGN1tk_1_4,"man, guitar" 205 | ZedzGfGN1tk_1_5,"man, guitar" 206 | _fvpYLXtQHQ_1_1,"guitar, violin" 207 | _fvpYLXtQHQ_1_2,"guitar, violin" 208 | _fvpYLXtQHQ_1_3,"guitar, violin" 209 | _fvpYLXtQHQ_1_4,"guitar, violin" 210 | _fvpYLXtQHQ_1_5,"guitar, violin" 211 | bhDqxWQUIXg_1_1,tabla 212 | bhDqxWQUIXg_1_2,tabla 213 | bhDqxWQUIXg_1_3,tabla 214 | bhDqxWQUIXg_1_4,tabla 215 | bhDqxWQUIXg_1_5,tabla 216 | emLStHiqfyo_2_1,violin 217 | emLStHiqfyo_2_2,violin 218 | emLStHiqfyo_2_3,violin 219 | emLStHiqfyo_2_4,violin 220 | emLStHiqfyo_2_5,violin 221 | gmVP1bAB_Jk_3_1,baby 222 | gmVP1bAB_Jk_3_2,baby 223 | gmVP1bAB_Jk_3_3,baby 224 | gmVP1bAB_Jk_3_4,baby 225 | gmVP1bAB_Jk_3_5,baby 226 | gsAoNfN9Pvk_2_1,"violin, piano" 227 | gsAoNfN9Pvk_2_2,"violin, piano" 228 | gsAoNfN9Pvk_2_3,"violin, piano" 229 | gsAoNfN9Pvk_2_4,"violin, piano" 230 | gsAoNfN9Pvk_2_5,"violin, piano" 231 | kgw1OadfJWA_1,cat 232 | kgw1OadfJWA_2,cat 233 | kgw1OadfJWA_3,cat 234 | kgw1OadfJWA_4,cat 235 | kgw1OadfJWA_5,cat 236 | ljq0lCGKapY_0_1,marimba 237 | ljq0lCGKapY_0_2,marimba 238 | ljq0lCGKapY_0_3,marimba 239 | ljq0lCGKapY_0_4,marimba 240 | ljq0lCGKapY_0_5,marimba 241 | nnuN-Zt60TM_0_1,piano 242 | nnuN-Zt60TM_0_2,piano 243 | nnuN-Zt60TM_0_3,piano 244 | nnuN-Zt60TM_0_4,piano 245 | nnuN-Zt60TM_0_5,piano 246 | nnuN-Zt60TM_3_1,"violin, piano" 247 | nnuN-Zt60TM_3_2,"violin, piano" 248 | nnuN-Zt60TM_3_3,"violin, piano" 249 | nnuN-Zt60TM_3_4,"violin, piano" 250 | nnuN-Zt60TM_3_5,"violin, piano" 251 | o1bZB4fKv2U_1_1,"violin, man, piano" 252 | o1bZB4fKv2U_1_2,"violin, man, piano" 253 | o1bZB4fKv2U_1_3,"violin, man, piano" 254 | o1bZB4fKv2U_1_4,"violin, man, piano" 255 | o1bZB4fKv2U_1_5,"violin, man, piano" 256 | pFLeIwERa8o_1,dog 257 | pFLeIwERa8o_2,dog 258 | pFLeIwERa8o_3,dog 259 | pFLeIwERa8o_4,dog 260 | pFLeIwERa8o_5,dog 261 | pxa8kn8h5ew_0_1,tabla 262 | pxa8kn8h5ew_0_2,tabla 263 | pxa8kn8h5ew_0_3,tabla 264 | pxa8kn8h5ew_0_4,tabla 265 | pxa8kn8h5ew_0_5,tabla 266 | q6Vwbg3SOSc_0_1,gun 267 | q6Vwbg3SOSc_0_2,gun 268 | q6Vwbg3SOSc_0_3,gun 269 | 
q6Vwbg3SOSc_0_4,gun 270 | q6Vwbg3SOSc_0_5,gun 271 | qD-UEwVEDP0_0_1,lion 272 | qD-UEwVEDP0_0_2,lion 273 | qD-UEwVEDP0_0_3,lion 274 | qD-UEwVEDP0_0_4,lion 275 | qD-UEwVEDP0_0_5,lion 276 | rq5RDUm2Hnk_1_1,woman 277 | rq5RDUm2Hnk_1_2,woman 278 | rq5RDUm2Hnk_1_3,woman 279 | rq5RDUm2Hnk_1_4,woman 280 | rq5RDUm2Hnk_1_5,woman 281 | s6dj_CqT0Mk_0_1,"ukulele, guitar" 282 | s6dj_CqT0Mk_0_2,"ukulele, guitar" 283 | s6dj_CqT0Mk_0_3,"ukulele, guitar" 284 | s6dj_CqT0Mk_0_4,"ukulele, guitar" 285 | s6dj_CqT0Mk_0_5,"ukulele, guitar" 286 | s6dj_CqT0Mk_1_1,"ukulele, guitar" 287 | s6dj_CqT0Mk_1_2,"ukulele, guitar" 288 | s6dj_CqT0Mk_1_3,"ukulele, guitar" 289 | s6dj_CqT0Mk_1_4,"ukulele, guitar" 290 | s6dj_CqT0Mk_1_5,"ukulele, guitar" 291 | u4v75o5WxYw_7_1,"man, piano" 292 | u4v75o5WxYw_7_2,"man, piano" 293 | u4v75o5WxYw_7_3,"man, piano" 294 | u4v75o5WxYw_7_4,"man, piano" 295 | u4v75o5WxYw_7_5,"man, piano" 296 | uCK7Gok3u0U_5_1,piano 297 | uCK7Gok3u0U_5_2,piano 298 | uCK7Gok3u0U_5_3,piano 299 | uCK7Gok3u0U_5_4,piano 300 | uCK7Gok3u0U_5_5,piano 301 | xGeqjlPz4kw_4_1,man 302 | xGeqjlPz4kw_4_2,man 303 | xGeqjlPz4kw_4_3,man 304 | xGeqjlPz4kw_4_4,man 305 | xGeqjlPz4kw_4_5,man 306 | xLXcCv45AYg_22_1,dog 307 | xLXcCv45AYg_22_2,dog 308 | xLXcCv45AYg_22_3,dog 309 | xLXcCv45AYg_22_4,dog 310 | xLXcCv45AYg_22_5,dog 311 | x_w-93ctLN0_2_1,baby 312 | x_w-93ctLN0_2_2,baby 313 | x_w-93ctLN0_2_3,baby 314 | x_w-93ctLN0_2_4,baby 315 | x_w-93ctLN0_2_5,baby 316 | xv7eM6-UmkY_14_1,"baby, man" 317 | xv7eM6-UmkY_14_2,"baby, man" 318 | xv7eM6-UmkY_14_3,"baby, man" 319 | xv7eM6-UmkY_14_4,"baby, man" 320 | xv7eM6-UmkY_14_5,"baby, man" -------------------------------------------------------------------------------- /Distributed_Experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES="0,1" 4 | export OMP_NUM_THREADS="4" 5 | 6 | python -m torch.distributed.launch --nnodes=1 --nproc_per_node=2 --master_port 12345 \ 7 | Train_ACL.py \ 8 | --model_name ACL_ViT16 \ 9 | --train_config Exp_ACL_v1 \ 10 | --exp_name aclifa_2gpu \ 11 | --vggss_path {put dataset directory} \ 12 | --flickr_path {put dataset directory} \ 13 | --avs_path {put dataset directory} \ 14 | --save_path {put logging directory} 15 | 16 | -------------------------------------------------------------------------------- /Flickr/Flickr_Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | import numpy as np 5 | import torchaudio 6 | from torchvision import transforms as vt 7 | from PIL import Image 8 | import os 9 | import csv 10 | import xml.etree.ElementTree as ET 11 | from typing import Dict, List, Optional, Union 12 | 13 | 14 | def load_all_bboxes(annotation_dir: str) -> Dict[str, List[np.ndarray]]: 15 | """ 16 | Load all bounding boxes from XML annotations. 17 | 18 | Args: 19 | annotation_dir (str): Directory containing XML annotation files. 20 | 21 | Returns: 22 | Dict[str, List[np.ndarray]]: Dictionary containing bounding boxes for each file. 
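Note (added for clarity): each <bbox> coordinate is rescaled from the 256-px annotation frame to the 224-px input resolution via int(224 * coord / 256), e.g. a coordinate of 128 becomes 112; the first child of each <bbox> element (index 0) is skipped.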
23 | """ 24 | gt_bboxes = {} 25 | anno_files = os.listdir(annotation_dir) 26 | for filename in anno_files: 27 | file = filename.split('.')[0] 28 | gt = ET.parse(os.path.join(annotation_dir, filename)).getroot() 29 | bboxes = [] 30 | for child in gt: 31 | for childs in child: 32 | bbox = [] 33 | if childs.tag == 'bbox': 34 | for index, ch in enumerate(childs): 35 | if index == 0: 36 | continue 37 | bbox.append(int(224 * int(ch.text) / 256)) 38 | bboxes.append(np.array(bbox)) 39 | gt_bboxes[file] = bboxes 40 | 41 | return gt_bboxes 42 | 43 | 44 | def bbox2gtmap(bboxes: List[List[int]]) -> np.ndarray: 45 | """ 46 | Convert bounding boxes to numpy ground truth map. 47 | 48 | Args: 49 | bboxes (List[List[int]]): List of bounding boxes. 50 | 51 | Returns: 52 | np.ndarray: Ground truth map. 53 | """ 54 | gt_map = np.zeros([224, 224]) 55 | for xmin, ymin, xmax, ymax in bboxes: 56 | temp = np.zeros([224, 224]) 57 | temp[ymin:ymax, xmin:xmax] = 1 58 | gt_map += temp 59 | gt_map = gt_map / 2 60 | gt_map[gt_map > 1] = 1 61 | 62 | return gt_map 63 | 64 | 65 | class FlickrDataset(Dataset): 66 | def __init__(self, data_path: str, split: str, is_train: bool = True, set_length: int = 10, 67 | input_resolution: int = 224): 68 | """ 69 | Initialize Flickr SoundNet Dataset. 70 | 71 | Args: 72 | data_path (str): Path to the dataset. 73 | split (str): Dataset split (Use csv file name in metadata directory). 74 | is_train (bool, optional): Whether it's a training set. Default is True. 75 | set_length (int, optional): Duration of input audio. Default is 10. 76 | input_resolution (int, optional): Resolution of input images. Default is 224. 77 | """ 78 | super(FlickrDataset, self).__init__() 79 | 80 | self.SAMPLE_RATE = 16000 81 | self.split = split 82 | self.set_length = set_length 83 | self.csv_dir = 'Flickr/metadata/' + split + '.csv' 84 | self.data_path = data_path 85 | self.is_trainset = True 86 | 87 | if split.split('.')[0].split('_')[-1] == 'test': 88 | self.data_path = os.path.join(data_path, 'test') 89 | self.is_trainset = False 90 | 91 | ''' Audio files ''' 92 | self.audio_path = os.path.join(self.data_path, 'audio') 93 | audio_files = set([fn.split('.wav')[0] for fn in os.listdir(self.audio_path) if fn.endswith('.wav')]) 94 | 95 | ''' Image files ''' 96 | self.image_path = os.path.join(self.data_path, 'frames') 97 | image_files = set([fn.split('.jpg')[0] for fn in os.listdir(self.image_path)]) 98 | 99 | ''' Ground truth (Bounding box) ''' 100 | gt_path = os.path.join(self.data_path, 'Annotations') 101 | if is_train: 102 | self.bbox_dict = None 103 | else: 104 | self.bbox_dict = load_all_bboxes(gt_path) 105 | 106 | ''' Ground truth (Text label) ''' 107 | self.label_dict = {item[0]: item[1] for item in csv.reader(open(self.csv_dir))} 108 | 109 | ''' Available files''' 110 | subset = set([item[0] for item in csv.reader(open(self.csv_dir))]) 111 | self.file_list = list(audio_files.intersection(image_files).intersection(subset)) 112 | self.file_list = sorted(self.file_list) 113 | 114 | ''' Transform ''' 115 | if is_train: 116 | self.image_transform = vt.Compose([ 117 | vt.Resize((int(input_resolution * 1.1), int(input_resolution * 1.1)), vt.InterpolationMode.BICUBIC), 118 | vt.ToTensor(), 119 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 120 | vt.RandomCrop((input_resolution, input_resolution)), 121 | vt.RandomHorizontalFlip(), 122 | ]) 123 | else: 124 | self.image_transform = vt.Compose([ 125 | vt.Resize((input_resolution, input_resolution), 
vt.InterpolationMode.BICUBIC), 126 | vt.ToTensor(), 127 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 128 | ]) 129 | 130 | def __len__(self): 131 | """ 132 | Return the number of items in the dataset. 133 | """ 134 | return len(self.file_list) 135 | 136 | def get_audio(self, item: int) -> torch.Tensor: 137 | """ 138 | Get audio data for a given item. 139 | 140 | Args: 141 | item (int): Index of the item. 142 | 143 | Returns: 144 | torch.Tensor: Audio data. 145 | """ 146 | audio_file, _ = torchaudio.load(os.path.join(self.audio_path, self.file_list[item] + '.wav')) 147 | audio_file = audio_file.squeeze(0) 148 | 149 | # slicing or padding based on set_length 150 | # slicing 151 | if audio_file.shape[0] > (self.SAMPLE_RATE * self.set_length): 152 | audio_file = audio_file[:self.SAMPLE_RATE * self.set_length] 153 | # zero padding 154 | if audio_file.shape[0] < (self.SAMPLE_RATE * self.set_length): 155 | pad_len = (self.SAMPLE_RATE * self.set_length) - audio_file.shape[0] 156 | pad_val = torch.zeros(pad_len) 157 | audio_file = torch.cat((audio_file, pad_val), dim=0) 158 | 159 | return audio_file 160 | 161 | def get_image(self, item: int) -> Image.Image: 162 | """ 163 | Get image data for a given item. 164 | 165 | Args: 166 | item (int): Index of the item. 167 | 168 | Returns: 169 | Image.Image: Image data. 170 | """ 171 | if self.is_trainset: 172 | file = os.listdir(os.path.join(self.image_path, self.file_list[item]))[0] # Get first frame for train 173 | image_file = Image.open(os.path.join(self.image_path, self.file_list[item], file)) 174 | else: 175 | image_file = Image.open(os.path.join(self.image_path, self.file_list[item] + '.jpg')) 176 | return image_file 177 | 178 | def get_bbox(self, item: int) -> Optional[torch.Tensor]: 179 | """ 180 | Get ground truth data for a given item. 181 | 182 | Args: 183 | item (int): Index of the item. 184 | 185 | Returns: 186 | Optional[torch.Tensor]: Ground truth data. 187 | """ 188 | # Bounding Box 189 | if self.bbox_dict is None: 190 | return None 191 | else: 192 | bbox = self.bbox_dict[self.file_list[item]] 193 | return vt.ToTensor()(bbox2gtmap(bbox)) 194 | 195 | def __getitem__(self, item: int) -> Dict[str, Union[torch.Tensor, torch.Tensor, Optional[torch.Tensor], str, str]]: 196 | """ 197 | Get item from the dataset. 198 | 199 | Args: 200 | item (int): Index of the item. 201 | 202 | Returns: 203 | Dict[str, Union[torch.Tensor, torch.Tensor, Optinal[torch.Tensor], str, str]]: Data example 204 | """ 205 | file_id = self.file_list[item] 206 | 207 | ''' Load data ''' 208 | audio_file = self.get_audio(item) 209 | image_file = self.get_image(item) 210 | label = self.label_dict[self.file_list[item]].replace('_', ' ') 211 | bboxes = self.get_bbox(item) 212 | 213 | ''' Transform ''' 214 | audio = audio_file 215 | image = self.image_transform(image_file) 216 | 217 | out = {'images': image, 'audios': audio, 'bboxes': bboxes, 'labels': label, 'ids': file_id} 218 | out = {key: value for key, value in out.items() if value is not None} 219 | return out 220 | 221 | 222 | class ExtendFlickrDataset(Dataset): 223 | def __init__(self, data_path: str, set_length: int = 10, input_resolution: int = 224): 224 | """ 225 | Initialize Extended Flickr Dataset. 226 | 227 | Args: 228 | data_path (str): Path to the dataset. 229 | set_length (int, optional): Duration of input audio. Default is 10. 230 | input_resolution (int, optional): Resolution of input images. Default is 224. 
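Note (added for clarity): besides the flickr_test pairs, rows of flickr_test_plus_silent.csv labelled 'non-sounding' are appended as image/audio pairs with an empty bounding-box list, so get_bbox returns an all-zero 224x224 ground-truth map for them.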
231 | """ 232 | super(ExtendFlickrDataset, self).__init__() 233 | 234 | self.SAMPLE_RATE = 16000 235 | self.set_length = set_length 236 | self.data_path = os.path.join(data_path, 'test') 237 | self.csv_dir = 'Flickr/metadata/flickr_test.csv' 238 | self.extend_csv_dir = 'Flickr/metadata/flickr_test_plus_silent.csv' 239 | self.split = 'exflickr' 240 | 241 | ''' Audio files ''' 242 | self.audio_path = os.path.join(data_path, 'extend_audio') 243 | audio_files = set([fn.split('.wav')[0] for fn in os.listdir(self.audio_path) if fn.endswith('.wav')]) 244 | 245 | ''' Image files ''' 246 | self.image_path = os.path.join(data_path, 'extend_frames') 247 | image_files = set([fn.split('.jpg')[0] for fn in os.listdir(self.image_path) if fn.endswith('.jpg')]) 248 | 249 | ''' Ground truth (Bounding box) ''' 250 | gt_path = os.path.join(self.data_path, 'Annotations') 251 | self.bbox_dict = load_all_bboxes(gt_path) 252 | 253 | ''' Ground truth (Text label) ''' 254 | self.label_dict = {item[0]: item[1] for item in csv.reader(open(self.csv_dir))} 255 | 256 | ''' Available files''' 257 | subset = set([item[0] for item in csv.reader(open(self.csv_dir))]) 258 | file_list = sorted(list(audio_files.intersection(image_files).intersection(subset))) 259 | self.image_files = [dt + '.jpg' for dt in file_list] 260 | self.audio_files = [dt + '.wav' for dt in file_list] 261 | self.bboxes = [self.bbox_dict[dt] for dt in file_list] 262 | self.labels = [self.label_dict[dt] for dt in file_list] 263 | 264 | ''' Add non-sounding files ''' 265 | for item in csv.reader(open(self.extend_csv_dir)): 266 | if item[2] == 'non-sounding': 267 | self.image_files.append(f'{item[0]}.jpg') 268 | self.audio_files.append(f'{item[1]}.wav') 269 | self.bboxes.append([]) 270 | self.labels.append('non-sounding') 271 | 272 | ''' Transform ''' 273 | self.image_transform = vt.Compose([ 274 | vt.Resize((input_resolution, input_resolution), vt.InterpolationMode.BICUBIC), 275 | vt.ToTensor(), 276 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 277 | ]) 278 | 279 | def __len__(self): 280 | """ 281 | Return the number of items in the dataset. 282 | """ 283 | return len(self.image_files) 284 | 285 | def get_audio(self, item: int) -> torch.Tensor: 286 | """ 287 | Get audio data for a given item. 288 | 289 | Args: 290 | item (int): Index of the item. 291 | 292 | Returns: 293 | torch.Tensor: Audio data. 294 | """ 295 | audio_file, _ = torchaudio.load(os.path.join(self.audio_path, self.audio_files[item])) 296 | audio_file = audio_file.squeeze(0) 297 | 298 | # slicing or padding based on set_length 299 | # slicing 300 | if audio_file.shape[0] > (self.SAMPLE_RATE * self.set_length): 301 | audio_file = audio_file[:self.SAMPLE_RATE * self.set_length] 302 | # zero padding 303 | if audio_file.shape[0] < (self.SAMPLE_RATE * self.set_length): 304 | pad_len = (self.SAMPLE_RATE * self.set_length) - audio_file.shape[0] 305 | pad_val = torch.zeros(pad_len) 306 | audio_file = torch.cat((audio_file, pad_val), dim=0) 307 | 308 | return audio_file 309 | 310 | def get_image(self, item: int) -> Image.Image: 311 | """ 312 | Get image data for a given item. 313 | 314 | Args: 315 | item (int): Index of the item. 316 | 317 | Returns: 318 | Image.Image: Image data. 319 | """ 320 | image_file = Image.open(os.path.join(self.image_path, self.image_files[item])) 321 | return image_file 322 | 323 | def get_bbox(self, item: int) -> Optional[torch.Tensor]: 324 | """ 325 | Get ground truth data for a given item. 
326 | 327 | Args: 328 | item (int): Index of the item. 329 | 330 | Returns: 331 | Optional[torch.Tensor]: Ground truth data. 332 | """ 333 | # Bounding Box 334 | if len(self.bboxes[item]) == 0: 335 | return vt.ToTensor()(np.zeros([224, 224])) 336 | else: 337 | bbox = self.bboxes[item] 338 | return vt.ToTensor()(bbox2gtmap(bbox)) 339 | 340 | def __getitem__(self, item: int) -> Dict[str, Union[torch.Tensor, torch.Tensor, Optional[torch.Tensor], str, str]]: 341 | """ 342 | Get item from the dataset. 343 | 344 | Args: 345 | item (int): Index of the item. 346 | 347 | Returns: 348 | Dict[str, Union[torch.Tensor, torch.Tensor, Optional[torch.Tensor], str, str]]: Data example 349 | """ 350 | ''' Load data ''' 351 | audio_file = self.get_audio(item) 352 | image_file = self.get_image(item) 353 | bboxes = self.get_bbox(item) 354 | label = self.labels[item] 355 | file_id = self.image_files[item].split('.')[0] + '_' + self.audio_files[item].split('.')[0] 356 | 357 | ''' Transform ''' 358 | audio = audio_file 359 | image = self.image_transform(image_file) 360 | 361 | out = {'images': image, 'audios': audio, 'bboxes': bboxes, 'labels': label, 'ids': file_id} 362 | out = {key: value for key, value in out.items() if value is not None} 363 | return out 364 | -------------------------------------------------------------------------------- /Flickr/README.md: -------------------------------------------------------------------------------- 1 | # Directory guide for Flickr 2 | ```commandline 3 | ├── Flickr/ 4 | | ├── audio 5 | | | └── ... 6 | | | └── *** .wav 7 | | ├── extend_audio 8 | | | └── ... 9 | | | └── *** .wav 10 | | ├── frames 11 | | | └── ... 12 | | | └── *** .jpg 13 | | ├── extend_frames 14 | | | └── ... 15 | | | └── *** .jpg 16 | | ├── Test 17 | | | ├── Annotations 18 | | | | └── ... 19 | | | | └── *** .xml 20 | | | ├── audio 21 | | | | └── ... 22 | | | | └── *** .wav 23 | | | ├── frames 24 | | | | └── ... 25 | | | | └── *** .jpg 26 | | | ├── extend_audio 27 | | | | └── ... 28 | | | | └── *** .wav 29 | | | ├── extend_frames 30 | | | | └── ... 31 | | | | └── *** .jpg 32 | 33 | ``` 34 | All .wav files are sampled at 16 kHz. -------------------------------------------------------------------------------- /Flickr/eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn import metrics 4 | from typing import List, Optional, Tuple, Dict 5 | 6 | 7 | class Evaluator(object): 8 | def __init__(self) -> None: 9 | """ 10 | Initialize the Flickr Evaluator. 11 | 12 | Attributes: 13 | ciou (List[float]): Buffer of cIoU values. 14 | AUC (List[float]): Buffer of AUC values. 15 | N (int): Counter for the number of evaluations. 16 | metrics (List[str]): List of metric names. 17 | """ 18 | super(Evaluator, self).__init__() 19 | self.ciou = [] 20 | self.AUC = [] 21 | self.N = 0 22 | self.metrics = ['cIoU', 'AUC'] 23 | 24 | def evaluate_batch(self, pred: torch.Tensor, target: torch.Tensor, thr: Optional[float] = None) -> None: 25 | """ 26 | Evaluate a batch of predictions against ground truth. 27 | 28 | Args: 29 | pred (torch.Tensor): Model predictions. 30 | target (torch.Tensor): Ground truth maps. 31 | thr (Optional[float]): Threshold for binary classification. If None, dynamically determined. 
32 | 33 | Returns: 34 | None 35 | """ 36 | for j in range(pred.size(0)): 37 | infer = pred[j] 38 | gt = target[j] 39 | if thr is None: 40 | thr = np.sort(infer.detach().cpu().numpy().flatten())[int(infer.shape[1] * infer.shape[2] / 2)] 41 | self.cal_CIOU(infer, gt, thr) 42 | 43 | def cal_CIOU(self, infer: torch.Tensor, gtmap: torch.Tensor, thres: float = 0.01) -> List[float]: 44 | """ 45 | Calculate cIoU (consensus Intersection over Union). 46 | 47 | Args: 48 | infer (torch.Tensor): Model prediction. 49 | gtmap (torch.Tensor): Ground truth map. 50 | thres (float): Threshold for binary classification. 51 | 52 | Returns: 53 | List[float]: List of cIoU values for each instance in the batch. 54 | """ 55 | infer_map = torch.zeros_like(gtmap) 56 | infer_map[infer >= thres] = 1 57 | ciou = (infer_map * gtmap).sum(2).sum(1) / (gtmap.sum(2).sum(1) + (infer_map * (gtmap == 0)).sum(2).sum(1)) 58 | 59 | for i in range(gtmap.size(0)): 60 | self.ciou.append(ciou[i].detach().cpu()) 61 | return ciou 62 | 63 | def finalize_AUC(self) -> float: 64 | """ 65 | Calculate the Area Under the Curve (AUC). 66 | 67 | Returns: 68 | float: AUC value. 69 | """ 70 | cious = [np.sum(np.array(self.ciou) >= 0.05 * i) / len(self.ciou) 71 | for i in range(21)] 72 | thr = [0.05 * i for i in range(21)] 73 | auc = metrics.auc(thr, cious) 74 | return auc 75 | 76 | def finalize_AP50(self) -> float: 77 | """ 78 | Calculate Average Precision (cIoU@0.5). 79 | 80 | Returns: 81 | float: cIoU@0.5 value. 82 | """ 83 | ap50 = np.mean(np.array(self.ciou) >= 0.5) 84 | return ap50 85 | 86 | def finalize_cIoU(self) -> float: 87 | """ 88 | Calculate mean cIoU. 89 | 90 | Returns: 91 | float: Mean cIoU value. 92 | """ 93 | ciou = np.mean(np.array(self.ciou)) 94 | return ciou 95 | 96 | def finalize(self) -> Tuple[List[str], Dict[str, float]]: 97 | """ 98 | Finalize and return evaluation metrics. 99 | 100 | Returns: 101 | Tuple[List[str], Dict[str, float]]: List of metric names and corresponding values. 102 | """ 103 | ap50 = self.finalize_AP50() * 100 104 | auc = self.finalize_AUC() * 100 105 | return self.metrics, {self.metrics[0]: ap50, self.metrics[1]: auc} 106 | -------------------------------------------------------------------------------- /Flickr/extend_eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn import metrics 4 | 5 | 6 | class Evaluator(object): 7 | def __init__(self, iou_thrs=(0.5, ), default_conf_thr=0.5, pred_size=0.5, pred_thr=0.5, 8 | results_dir='./results'): 9 | """ 10 | Initialize the Extended Flickr evaluator. 11 | 12 | Notes: 13 | Taking computation speed into consideration, it is set to output only the 'all' subset. 
(AP, Max-F1) 14 | """ 15 | super(Evaluator, self).__init__() 16 | self.iou_thrs = iou_thrs 17 | self.default_conf_thr = default_conf_thr 18 | self.min_sizes = {'small': 0, 'medium': 32 ** 2, 'large': 96 ** 2, 'huge': 144 ** 2} 19 | self.max_sizes = {'small': 32 ** 2, 'medium': 96 ** 2, 'large': 144 ** 2, 'huge': 10000 ** 2} 20 | 21 | self.ciou_list = [] 22 | self.area_list = [] 23 | self.confidence_list = [] 24 | self.name_list = [] 25 | self.bb_list = [] 26 | # self.metrics = ['AP', 'Max-F1', 'LocAcc'] 27 | self.metrics = ['AP', 'Max-F1'] 28 | 29 | self.results_dir = results_dir 30 | self.viz_save_dir = f"{results_dir}/viz_conf" + str(default_conf_thr) + "_predsize" + str( 31 | pred_size) + "_predthr" + str(pred_thr) 32 | self.results_save_dir = f"{results_dir}/results_conf" + str(default_conf_thr) + "_predsize" + str( 33 | pred_size) + "_predthr" + str(pred_thr) 34 | 35 | @staticmethod 36 | def calc_precision_recall(bb_list, ciou_list, confidence_list, confidence_thr, ciou_thr=0.5): 37 | assert len(bb_list) == len(ciou_list) == len(confidence_list) 38 | true_pos, false_pos, false_neg = 0, 0, 0 39 | for bb, ciou, confidence in zip(bb_list, ciou_list, confidence_list): 40 | if bb == 0: 41 | # no sounding objects in frame 42 | if confidence >= confidence_thr: 43 | # sounding object detected 44 | false_pos += 1 45 | else: 46 | # sounding objects in frame 47 | if confidence >= confidence_thr: 48 | # sounding object detected... 49 | if ciou >= ciou_thr: # ...in correct place 50 | true_pos += 1 51 | else: # ...in wrong place 52 | false_pos += 1 53 | else: 54 | # no sounding objects detected 55 | false_neg += 1 56 | 57 | precision = 1. if true_pos + false_pos == 0 else true_pos / (true_pos + false_pos) 58 | recall = 1. if true_pos + false_neg == 0 else true_pos / (true_pos + false_neg) 59 | 60 | return precision, recall 61 | 62 | def calc_ap(self, bb_list_full, ciou_list_full, confidence_list_full, iou_thr=0.5): 63 | 64 | assert len(bb_list_full) == len(ciou_list_full) == len(confidence_list_full) 65 | 66 | # for visible objects 67 | # ss = [i for i, bb in enumerate(bb_list_full) if bb > 0] 68 | # bb_list = [bb_list_full[i] for i in ss] 69 | # ciou_list = [ciou_list_full[i] for i in ss] 70 | # confidence_list = [confidence_list_full[i] for i in ss] 71 | 72 | precision, recall, skip_thr = [], [], max(1, len(ciou_list_full) // 200) 73 | for thr in np.sort(np.array(confidence_list_full))[:-1][::-skip_thr]: 74 | p, r = self.calc_precision_recall(bb_list_full, ciou_list_full, confidence_list_full, thr, iou_thr) 75 | precision.append(p) 76 | recall.append(r) 77 | precision_max = [np.max(precision[i:]) for i in range(len(precision))] 78 | ap = sum([precision_max[i] * (recall[i + 1] - recall[i]) 79 | for i in range(len(precision_max) - 1)]) 80 | return ap 81 | 82 | def cal_auc(self, bb_list, ciou_list): 83 | ss = [i for i, bb in enumerate(bb_list) if bb > 0] 84 | ciou = [ciou_list[i] for i in ss] 85 | cious = [np.sum(np.array(ciou) >= 0.05 * i) / len(ciou) 86 | for i in range(21)] 87 | thr = [0.05 * i for i in range(21)] 88 | auc = metrics.auc(thr, cious) 89 | return auc 90 | 91 | def filter_subset(self, subset, name_list, area_list, bb_list, ciou_list, conf_list): 92 | if subset == 'visible': 93 | ss = [i for i, bb in enumerate(bb_list) if bb > 0] 94 | elif subset == 'non-visible/non-audible': 95 | ss = [i for i, bb in enumerate(bb_list) if bb == 0] 96 | elif subset == 'all': 97 | ss = [i for i, bb in enumerate(bb_list) if bb >= 0] 98 | else: 99 | ss = [i for i, sz in enumerate(area_list) 100 | if 
self.min_sizes[subset] <= sz < self.max_sizes[subset] and bb_list[i] > 0] 101 | 102 | if len(ss) == 0: 103 | return [], [], [], [], [] 104 | 105 | name = [name_list[i] for i in ss] 106 | area = [area_list[i] for i in ss] 107 | bbox = [bb_list[i] for i in ss] 108 | ciou = [ciou_list[i] for i in ss] 109 | conf = [conf_list[i] for i in ss] 110 | 111 | return name, area, bbox, ciou, conf 112 | 113 | def finalize_stats(self): 114 | name_full_list, area_full_list, bb_full_list, ciou_full_list, confidence_full_list = self.gather_results() 115 | 116 | metrics = {} 117 | for iou_thr in self.iou_thrs: 118 | # for subset in ['all', 'visible']: 119 | for subset in ['all']: 120 | _, _, bb_list, ciou_list, conf_list = self.filter_subset(subset, name_full_list, area_full_list, 121 | bb_full_list, ciou_full_list, 122 | confidence_full_list) 123 | subset_name = f'{subset}@{int(iou_thr * 100)}' if subset is not None else f'@{int(iou_thr * 100)}' 124 | if len(ciou_list) == 0: 125 | p, r, ap, f1, auc = np.nan, np.nan, np.nan, np.nan, np.nan 126 | else: 127 | p, r = self.calc_precision_recall(bb_list, ciou_list, conf_list, -1000, iou_thr) 128 | ap = self.calc_ap(bb_list, ciou_list, conf_list, iou_thr) 129 | auc = self.cal_auc(bb_list, ciou_list) 130 | 131 | conf_thr = list(sorted(conf_list))[::max(1, len(conf_list) // 10)] 132 | pr = [self.calc_precision_recall(bb_list, ciou_list, conf_list, thr, iou_thr) for thr in conf_thr] 133 | f1 = [2 * r * p / (r + p) if r + p > 0 else 0. for p, r in pr] 134 | if subset == 'all' and iou_thr == 0.5: 135 | ef1 = max(f1) 136 | eap = ap 137 | metrics['ef1'] = ef1 138 | metrics['eap'] = eap 139 | if subset == 'visible' and iou_thr == 0.5: 140 | eloc = self.precision_at_50() 141 | eauc = auc 142 | metrics['eloc'] = eloc 143 | metrics['eauc'] = eauc 144 | metrics[f'Precision-{subset_name}'] = p 145 | # metrics[f'Recall-{subset_name}'] = r 146 | if np.isnan(f1).any(): 147 | metrics[f'F1-{subset_name}'] = f1 148 | else: 149 | metrics[f'F1-{subset_name}'] = ' '.join([f'{f * 100:.1f}' for f in f1]) 150 | metrics[f'AP-{subset_name}'] = ap 151 | metrics[f'AUC-{subset_name}'] = auc 152 | 153 | return metrics 154 | 155 | def gather_results(self): 156 | # import torch.distributed as dist 157 | # if not dist.is_initialized(): 158 | return self.name_list, self.area_list, self.bb_list, self.ciou_list, self.confidence_list 159 | # world_size = dist.get_world_size() 160 | # 161 | # bb_list = [None for _ in range(world_size)] 162 | # dist.all_gather_object(bb_list, self.bb_list) 163 | # bb_list = [x for bb in bb_list for x in bb] 164 | # 165 | # area_list = [None for _ in range(world_size)] 166 | # dist.all_gather_object(area_list, self.area_list) 167 | # area_list = [x for area in area_list for x in area] 168 | # 169 | # ciou_list = [None for _ in range(world_size)] 170 | # dist.all_gather_object(ciou_list, self.ciou_list) 171 | # ciou_list = [x for ciou in ciou_list for x in ciou] 172 | # 173 | # confidence_list = [None for _ in range(world_size)] 174 | # dist.all_gather_object(confidence_list, self.confidence_list) 175 | # confidence_list = [x for conf in confidence_list for x in conf] 176 | # 177 | # name_list = [None for _ in range(world_size)] 178 | # dist.all_gather_object(name_list, self.name_list) 179 | # name_list = [x for name in name_list for x in name] 180 | 181 | # return name_list, area_list, bb_list, ciou_list, confidence_list 182 | 183 | def precision_at_50(self): 184 | ss = [i for i, bb in enumerate(self.bb_list) if bb > 0] 185 | return np.mean(np.array([self.ciou_list[i] 
for i in ss]) > 0.5) 186 | 187 | def precision_at_50_object(self): 188 | max_num_obj = max(self.bb_list) 189 | for num_obj in range(1, max_num_obj + 1): 190 | ss = [i for i, bb in enumerate(self.bb_list) if bb == num_obj] 191 | precision = np.mean(np.array([self.ciou_list[i] for i in ss]) > 0.5) 192 | print('\n' + f'num_obj:{num_obj}, precision:{precision}') 193 | 194 | def f1_at_50(self): 195 | # conf_thr = np.array(self.confidence_list).mean() 196 | p, r = self.calc_precision_recall(self.bb_list, self.ciou_list, self.confidence_list, self.default_conf_thr, 197 | 0.5) 198 | return 2 * p * r / (p + r) if (p + r) > 0 else 0. 199 | 200 | def ap_at_50(self): 201 | return self.calc_ap(self.bb_list, self.ciou_list, self.confidence_list, 0.5) 202 | 203 | def clear(self): 204 | self.ciou_list = [] 205 | self.area_list = [] 206 | self.confidence_list = [] 207 | self.name_list = [] 208 | self.bb_list = [] 209 | 210 | def update(self, bb, gt, conf, pred, pred_thr, name): 211 | if isinstance(conf, torch.Tensor): 212 | conf = conf.detach().cpu().numpy() 213 | if isinstance(pred, torch.Tensor): 214 | pred = pred.detach().cpu().numpy() 215 | if isinstance(gt, torch.Tensor): 216 | gt = gt.detach().cpu().numpy() 217 | 218 | # Compute binary prediction map 219 | infer = np.zeros((224, 224)) 220 | infer[pred >= pred_thr] = 1 221 | 222 | # Compute ciou between prediction and ground truth 223 | ciou = np.sum(infer * gt) / (np.sum(gt) + np.sum(infer * (gt == 0)) + 1e-12) 224 | 225 | # Compute ground truth size 226 | area = gt.sum() 227 | 228 | # Save 229 | self.confidence_list.append(conf) 230 | self.ciou_list.append(ciou) 231 | self.area_list.append(area) 232 | self.name_list.append(name) 233 | self.bb_list.append(bb) 234 | 235 | def evaluate_batch(self, pred, gt, label, conf, name, thr=None): 236 | for i in range(pred.shape[0]): 237 | infer = pred[i, 0].detach().cpu().numpy() 238 | if thr is None: 239 | thr = np.sort(infer.flatten())[int(infer.shape[0] * infer.shape[1] * 0.5)] 240 | 241 | bb = 1 if label[i] != 'non-sounding' else 0 242 | 243 | self.update(bb, gt[i, 0], conf[i], infer, thr, name[i]) 244 | 245 | def finalize(self): 246 | metric_extend = self.finalize_stats() 247 | eap = metric_extend['AP-all@50'] 248 | ef1 = metric_extend['F1-all@50'] 249 | # eloc = metric_extend['Precision-visible@50'] 250 | emaxf1 = max([float(num) for num in ef1.split(' ')]) 251 | return self.metrics, {self.metrics[0]: eap*100, self.metrics[1]: emaxf1} # , self.metrics[2]: eloc*100} 252 | -------------------------------------------------------------------------------- /Flickr/metadata/flickr_test.csv: -------------------------------------------------------------------------------- 1 | 10000130166,child singing 2 | 10007936344,playing tambourine 3 | 10008553263,playing djembe 4 | 10009662863,female singing 5 | 10013411946,skidding 6 | 10016382545,driving motorcycle 7 | 10031203703,shot football 8 | 10035917404,singing choir 9 | 10045181004,children shouting 10 | 10060697266,tap dancing 11 | 10061269855,baby crying 12 | 10067897816,basketball bounce 13 | 10069511285,people cheering 14 | 10079676496,shot football 15 | 10080652613,shot football 16 | 10106776154,playing volleyball 17 | 10109936483,baltimore oriole calling 18 | 10110769444,people crowd 19 | 10111706074,"male speech, man speaking" 20 | 10119607346,people marching 21 | 10119736773,people marching 22 | 10129791115,shot football 23 | 10129813456,rope skipping 24 | 10163352444,playing volleyball 25 | 10165890496,helicopter 26 | 10173808724,church bell ringing 27 | 
10202025415,"railroad car, train wagon" 28 | 10221582415,baby crying 29 | 10234680056,roller coaster running 30 | 10246013484,people crowd 31 | 10268129763,reversing beeps 32 | 10278084464,people shuffling 33 | 10283938426,shot football 34 | 10289643764,rope skipping 35 | 10304554474,basketball bounce 36 | 10308472565,skateboarding 37 | 10369815925,children shouting 38 | 10389014583,dinosaurs bellowing 39 | 10409146004,people booing 40 | 10409351364,sea waves 41 | 10413912845,fire truck siren 42 | 10437048305,child singing 43 | 10440265465,"playing marimba, xylophone" 44 | 10441843213,train horning 45 | 10446128125,baby laughter 46 | 10451898725,coyote howling 47 | 10468313544,train whistling 48 | 10476442285,goat bleating 49 | 10476446674,goat bleating 50 | 10477544895,playing harp 51 | 10481128583,bouncing on trampoline 52 | 10520382525,playing synthesizer 53 | 10545790805,"female speech, woman speaking" 54 | 10548273474,bouncing on trampoline 55 | 10549911663,playing guiro 56 | 10555704135,missile launch 57 | 10557361454,car passing by 58 | 10610974774,playing bagpipes 59 | 10624261424,female singing 60 | 10635080206,police radio chatter 61 | 10654057765,playing washboard 62 | 10659354545,people marching 63 | 10666396623,baby babbling 64 | 10667001125,alarm clock ringing 65 | 10701841844,people screaming 66 | 10743985374,tap dancing 67 | 10751849775,"race car, auto racing" 68 | 10761625676,train whistling 69 | 10771238614,tractor digging 70 | 10795023275,playing tennis 71 | 10796050285,splashing water 72 | 10805310706,playing tennis 73 | 10847897294,children shouting 74 | 10859255295,playing cymbal 75 | 10862172945,baby laughter 76 | 10939270325,tractor digging 77 | 10952523225,race car 78 | 10979853044,ocean burbling 79 | 10980225216,ocean burbling 80 | 11002352834,tractor digging 81 | 11036641226,reversing beeps 82 | 11050409385,rope skipping 83 | 11071986404,sliding door 84 | 11101358996,children shouting 85 | 11154049584,train horning 86 | 11314266424,sea waves 87 | 11650557765,vacuum cleaner cleaning floors 88 | 11764496003,driving motorcycle 89 | 11765015695,train whistling 90 | 11776645404,people shuffling 91 | 12015590114,eating with cutlery 92 | 12031178616,rope skipping 93 | 12048726756,playing hockey 94 | 12066557153,playing hockey 95 | 12103535156,helicopter 96 | 12158276143,helicopter 97 | 12328837165,people crowd 98 | 12373955905,cap gun shooting 99 | 12444360833,playing bugle 100 | 12512072175,playing tympani 101 | 12534178685,people marching 102 | 12598610153,child singing 103 | 12599965145,rope skipping 104 | 12729873904,door slamming 105 | 13153991894,basketball bounce 106 | 13234495505,"railroad car, train wagon" 107 | 13409856383,tractor digging 108 | 13447842855,ambulance siren 109 | 13579448834,people marching 110 | 13660204414,airplane flyby 111 | 13721325644,train whistling 112 | 13861270725,sloshing water 113 | 14062380939,playing bugle 114 | 14172932360,helicopter 115 | 14174434304,crow cawing 116 | 14203929291,people marching 117 | 14414061986,mosquito buzzing 118 | 14566072101,people marching 119 | 14648401332,people shuffling 120 | 14652810592,people marching 121 | 14678892718,swimming 122 | 14680986811,people crowd 123 | 14742398172,people screaming 124 | 14778198185,helicopter 125 | 14899015793,shot football 126 | 15135657108,firing muskets 127 | 15210828119,church bell ringing 128 | 15679023977,car passing by 129 | 15799251138,singing choir 130 | 15994498891,people eating noodle 131 | 16012165838,fire truck siren 132 | 16101081359,shot football 133 | 
16369370669,playing hockey 134 | 16419739420,people marching 135 | 16563757295,people marching 136 | 16659090834,playing bugle 137 | 16804898716,missile launch 138 | 16950760048,people running 139 | 17140316339,rope skipping 140 | 17537079145,singing choir 141 | 19306860381,playing bagpipes 142 | 20854461578,playing bagpipes 143 | 2401957951,rope skipping 144 | 2404965389,people booing 145 | 2432219254,playing table tennis 146 | 2462141223,shot football 147 | 2465634368,scuba diving 148 | 2499673064,volcano explosion 149 | 2509250774,playing hockey 150 | 2542659118,train wheels squealing 151 | 2630218696,playing hockey 152 | 2695985181,plastic bottle crushing 153 | 2698329253,firing cannon 154 | 2766580444,horse neighing 155 | 2778088729,tap dancing 156 | 2808068937,chainsawing trees 157 | 2811549163,hail 158 | 2819123278,driving motorcycle 159 | 2858344348,people cheering 160 | 2895192729,driving motorcycle 161 | 2897081671,tractor digging 162 | 2897653916,cap gun shooting 163 | 2907554899,people cheering 164 | 2935372113,penguins braying 165 | 2961466546,playing volleyball 166 | 3012435981,tractor digging 167 | 3052033339,shot football 168 | 3102685146,chainsawing trees 169 | 3109408703,reversing beeps 170 | 3202731977,train horning 171 | 3207674548,playing hockey 172 | 3230152596,people cheering 173 | 3403348616,playing hockey 174 | 3413119417,"playing marimba, xylophone" 175 | 3416359816,door slamming 176 | 3484198977,dog howling 177 | 3535748358,playing hockey 178 | 3568893693,driving motorcycle 179 | 3668283413,tractor digging 180 | 3709678770,people marching 181 | 3719297546,squishing water 182 | 3727937033,fire truck siren 183 | 3749326537,shot football 184 | 3790233127,sea waves 185 | 3801318146,dog barking 186 | 3828103578,driving motorcycle 187 | 3896768873,swimming 188 | 3908875350,cupboard opening or closing 189 | 3956341886,chainsawing trees 190 | 4007045926,ripping paper 191 | 4041216001,people booing 192 | 4180455681,people eating 193 | 4303617148,playing hockey 194 | 4351373013,firing cannon 195 | 4407899725,tractor digging 196 | 4479260803,"male speech, man speaking" 197 | 4543136011,driving motorcycle 198 | 4576746825,driving motorcycle 199 | 4646464908,swimming 200 | 4657718822,firing cannon 201 | 4717096777,shot football 202 | 4755507106,"engine accelerating, revving, vroom" 203 | 4758950312,car engine idling 204 | 4767923306,striking pool 205 | 4876943924,horse neighing 206 | 4938432980,ocean burbling 207 | 4957886467,roller coaster running 208 | 4965189170,"engine accelerating, revving, vroom" 209 | 5006362787,tractor digging 210 | 5130986168,shot football 211 | 5179649119,telephone bell ringing 212 | 5194793239,driving motorcycle 213 | 5220919991,shot football 214 | 5237296528,people booing 215 | 5303633386,playing hockey 216 | 5344317532,playing hockey 217 | 5480070309,driving motorcycle 218 | 5490101320,bouncing on trampoline 219 | 5601291878,parrot talking 220 | 5607310495,striking pool 221 | 5697766486,machine gun shooting 222 | 5801923729,fire truck siren 223 | 5873098388,reversing beeps 224 | 5906878321,"race car, auto racing" 225 | 5991393107,helicopter 226 | 6067281197,firing cannon 227 | 6098668510,cap gun shooting 228 | 6158154336,children shouting 229 | 6185482260,"female speech, woman speaking" 230 | 6289328021,"donkey, ass braying" 231 | 6458319057,train wheels squealing 232 | 6669466181,"railroad car, train wagon" 233 | 6755812421,tractor digging 234 | 6897499873,playing tennis 235 | 7178293884,swimming 236 | 7387939334,people marching 237 | 
7560517176,reversing beeps 238 | 7623513024,hedge trimmer running 239 | 7733838448,people booing 240 | 7740990330,baby babbling 241 | 7897462346,tractor digging 242 | 8250285374,people marching 243 | 8294094053,basketball bounce 244 | 8311669586,horse neighing 245 | 8396975855,door slamming 246 | 9060997789,airplane 247 | 9309351423,"railroad car, train wagon" 248 | 9456627660,swimming 249 | 9636000842,firing cannon 250 | 9761803312,playing hockey 251 | -------------------------------------------------------------------------------- /Flickr/metadata/flickr_test_plus_silent.csv: -------------------------------------------------------------------------------- 1 | video,audio,label 2 | 10000130166,10437048305,non-sounding 3 | 10007936344,10079676496,non-sounding 4 | 10008553263,10163352444,non-sounding 5 | 10009662863,15799251138,non-sounding 6 | 10013411946,3102685146,non-sounding 7 | 10016382545,11036641226,non-sounding 8 | 10031203703,10862172945,non-sounding 9 | 10035917404,12031178616,non-sounding 10 | 10045181004,19306860381,non-sounding 11 | 10060697266,14203929291,non-sounding 12 | 10061269855,13153991894,non-sounding 13 | 10067897816,2499673064,non-sounding 14 | 10069511285,3801318146,non-sounding 15 | 10079676496,14648401332,non-sounding 16 | 10080652613,10624261424,non-sounding 17 | 10106776154,10069511285,non-sounding 18 | 10109936483,10106776154,non-sounding 19 | 10110769444,12328837165,non-sounding 20 | 10111706074,10635080206,non-sounding 21 | 10119607346,10441843213,non-sounding 22 | 10119736773,10109936483,non-sounding 23 | 10129791115,4543136011,non-sounding 24 | 10129813456,12373955905,non-sounding 25 | 10163352444,6458319057,non-sounding 26 | 10165890496,6289328021,non-sounding 27 | 10173808724,9456627660,non-sounding 28 | 10202025415,5697766486,non-sounding 29 | 10221582415,3416359816,non-sounding 30 | 10234680056,10979853044,non-sounding 31 | 10246013484,3668283413,non-sounding 32 | 10268129763,10109936483,non-sounding 33 | 10278084464,5130986168,non-sounding 34 | 10283938426,3709678770,non-sounding 35 | 10289643764,10980225216,non-sounding 36 | 10304554474,7178293884,non-sounding 37 | 10308472565,10939270325,non-sounding 38 | 10369815925,14414061986,non-sounding 39 | 10389014583,3230152596,non-sounding 40 | 10409146004,12328837165,non-sounding 41 | 10409351364,10234680056,non-sounding 42 | 10413912845,5607310495,non-sounding 43 | 10437048305,5601291878,non-sounding 44 | 10440265465,2698329253,non-sounding 45 | 10441843213,3709678770,non-sounding 46 | 10446128125,12599965145,non-sounding 47 | 10451898725,10548273474,non-sounding 48 | 10468313544,11154049584,non-sounding 49 | 10476442285,4576746825,non-sounding 50 | 10476446674,16101081359,non-sounding 51 | 10477544895,16419739420,non-sounding 52 | 10481128583,3535748358,non-sounding 53 | 10520382525,6289328021,non-sounding 54 | 10545790805,10008553263,non-sounding 55 | 10548273474,10659354545,non-sounding 56 | 10549911663,10283938426,non-sounding 57 | 10555704135,10106776154,non-sounding 58 | 10557361454,10610974774,non-sounding 59 | 10610974774,10666396623,non-sounding 60 | 10624261424,14648401332,non-sounding 61 | 10635080206,3535748358,non-sounding 62 | 10654057765,10016382545,non-sounding 63 | 10659354545,14172932360,non-sounding 64 | 10666396623,2778088729,non-sounding 65 | 10667001125,2907554899,non-sounding 66 | 10701841844,16101081359,non-sounding 67 | 10743985374,14203929291,non-sounding 68 | 10751849775,10549911663,non-sounding 69 | 10761625676,3413119417,non-sounding 70 | 10771238614,3790233127,non-sounding 71 
| 10795023275,3956341886,non-sounding 72 | 10796050285,14203929291,non-sounding 73 | 10805310706,10795023275,non-sounding 74 | 10847897294,10111706074,non-sounding 75 | 10859255295,4657718822,non-sounding 76 | 10862172945,10980225216,non-sounding 77 | 10939270325,10468313544,non-sounding 78 | 10952523225,10111706074,non-sounding 79 | 10979853044,2766580444,non-sounding 80 | 10980225216,2766580444,non-sounding 81 | 11002352834,7560517176,non-sounding 82 | 11036641226,11002352834,non-sounding 83 | 11050409385,14172932360,non-sounding 84 | 11071986404,10476442285,non-sounding 85 | 11101358996,3719297546,non-sounding 86 | 11154049584,10557361454,non-sounding 87 | 11314266424,10980225216,non-sounding 88 | 11650557765,2935372113,non-sounding 89 | 11764496003,10069511285,non-sounding 90 | 11765015695,12328837165,non-sounding 91 | 11776645404,11650557765,non-sounding 92 | 12015590114,16101081359,non-sounding 93 | 12031178616,13579448834,non-sounding 94 | 12048726756,2778088729,non-sounding 95 | 12066557153,10635080206,non-sounding 96 | 12103535156,3102685146,non-sounding 97 | 12158276143,10451898725,non-sounding 98 | 12328837165,4543136011,non-sounding 99 | 12373955905,2465634368,non-sounding 100 | 12444360833,5344317532,non-sounding 101 | 12512072175,11071986404,non-sounding 102 | 12534178685,14652810592,non-sounding 103 | 12598610153,10481128583,non-sounding 104 | 12599965145,12031178616,non-sounding 105 | 12729873904,12444360833,non-sounding 106 | 13153991894,3413119417,non-sounding 107 | 13234495505,12048726756,non-sounding 108 | 13409856383,14680986811,non-sounding 109 | 13447842855,2698329253,non-sounding 110 | 13579448834,10060697266,non-sounding 111 | 13660204414,13409856383,non-sounding 112 | 13721325644,4758950312,non-sounding 113 | 13861270725,10548273474,non-sounding 114 | 14062380939,3828103578,non-sounding 115 | 14172932360,5179649119,non-sounding 116 | 14174434304,3230152596,non-sounding 117 | 14203929291,12328837165,non-sounding 118 | 14414061986,10667001125,non-sounding 119 | 14566072101,9060997789,non-sounding 120 | 14648401332,10481128583,non-sounding 121 | 14652810592,12598610153,non-sounding 122 | 14678892718,2630218696,non-sounding 123 | 14680986811,2895192729,non-sounding 124 | 14742398172,5801923729,non-sounding 125 | 14778198185,8396975855,non-sounding 126 | 14899015793,20854461578,non-sounding 127 | 15135657108,5220919991,non-sounding 128 | 15210828119,12015590114,non-sounding 129 | 15679023977,6897499873,non-sounding 130 | 15799251138,5237296528,non-sounding 131 | 15994498891,8311669586,non-sounding 132 | 16012165838,2509250774,non-sounding 133 | 16101081359,13579448834,non-sounding 134 | 16369370669,10795023275,non-sounding 135 | 16419739420,2961466546,non-sounding 136 | 16563757295,6755812421,non-sounding 137 | 16659090834,10862172945,non-sounding 138 | 16804898716,10129813456,non-sounding 139 | 16950760048,5480070309,non-sounding 140 | 17140316339,10129791115,non-sounding 141 | 17537079145,10520382525,non-sounding 142 | 19306860381,12373955905,non-sounding 143 | 20854461578,3896768873,non-sounding 144 | 2401957951,16563757295,non-sounding 145 | 2404965389,2897653916,non-sounding 146 | 2432219254,10045181004,non-sounding 147 | 2462141223,15799251138,non-sounding 148 | 2465634368,14652810592,non-sounding 149 | 2499673064,5906878321,non-sounding 150 | 2509250774,10119736773,non-sounding 151 | 2542659118,10119736773,non-sounding 152 | 2630218696,10468313544,non-sounding 153 | 2695985181,15210828119,non-sounding 154 | 2698329253,5480070309,non-sounding 155 | 
2766580444,10751849775,non-sounding 156 | 2778088729,10000130166,non-sounding 157 | 2808068937,15679023977,non-sounding 158 | 2811549163,16950760048,non-sounding 159 | 2819123278,2808068937,non-sounding 160 | 2858344348,14174434304,non-sounding 161 | 2895192729,12066557153,non-sounding 162 | 2897081671,4303617148,non-sounding 163 | 2897653916,10451898725,non-sounding 164 | 2907554899,11071986404,non-sounding 165 | 2935372113,10520382525,non-sounding 166 | 2961466546,3709678770,non-sounding 167 | 3012435981,3535748358,non-sounding 168 | 3052033339,10308472565,non-sounding 169 | 3102685146,10369815925,non-sounding 170 | 3109408703,3230152596,non-sounding 171 | 3202731977,10129791115,non-sounding 172 | 3207674548,5607310495,non-sounding 173 | 3230152596,9636000842,non-sounding 174 | 3403348616,5194793239,non-sounding 175 | 3413119417,10557361454,non-sounding 176 | 3416359816,9636000842,non-sounding 177 | 3484198977,10035917404,non-sounding 178 | 3535748358,8294094053,non-sounding 179 | 3568893693,4007045926,non-sounding 180 | 3668283413,10667001125,non-sounding 181 | 3709678770,2695985181,non-sounding 182 | 3719297546,10761625676,non-sounding 183 | 3727937033,10743985374,non-sounding 184 | 3749326537,10446128125,non-sounding 185 | 3790233127,14680986811,non-sounding 186 | 3801318146,7560517176,non-sounding 187 | 3828103578,7740990330,non-sounding 188 | 3896768873,10939270325,non-sounding 189 | 3908875350,2897653916,non-sounding 190 | 3956341886,3484198977,non-sounding 191 | 4007045926,7623513024,non-sounding 192 | 4041216001,4351373013,non-sounding 193 | 4180455681,5006362787,non-sounding 194 | 4303617148,10080652613,non-sounding 195 | 4351373013,2630218696,non-sounding 196 | 4407899725,6067281197,non-sounding 197 | 4479260803,4303617148,non-sounding 198 | 4543136011,10796050285,non-sounding 199 | 4576746825,8294094053,non-sounding 200 | 4646464908,5906878321,non-sounding 201 | 4657718822,7740990330,non-sounding 202 | 4717096777,17140316339,non-sounding 203 | 4755507106,11765015695,non-sounding 204 | 4758950312,6755812421,non-sounding 205 | 4767923306,7623513024,non-sounding 206 | 4876943924,10079676496,non-sounding 207 | 4938432980,16659090834,non-sounding 208 | 4957886467,3709678770,non-sounding 209 | 4965189170,5344317532,non-sounding 210 | 5006362787,4576746825,non-sounding 211 | 5130986168,2961466546,non-sounding 212 | 5179649119,16659090834,non-sounding 213 | 5194793239,10667001125,non-sounding 214 | 5220919991,10476442285,non-sounding 215 | 5237296528,10080652613,non-sounding 216 | 5303633386,11765015695,non-sounding 217 | 5344317532,10476442285,non-sounding 218 | 5480070309,12103535156,non-sounding 219 | 5490101320,5601291878,non-sounding 220 | 5601291878,10979853044,non-sounding 221 | 5607310495,4303617148,non-sounding 222 | 5697766486,3484198977,non-sounding 223 | 5801923729,10304554474,non-sounding 224 | 5873098388,3727937033,non-sounding 225 | 5906878321,10129791115,non-sounding 226 | 5991393107,4576746825,non-sounding 227 | 6067281197,10308472565,non-sounding 228 | 6098668510,3207674548,non-sounding 229 | 6158154336,2811549163,non-sounding 230 | 6185482260,10246013484,non-sounding 231 | 6289328021,10289643764,non-sounding 232 | 6458319057,9761803312,non-sounding 233 | 6669466181,4657718822,non-sounding 234 | 6755812421,2895192729,non-sounding 235 | 6897499873,10165890496,non-sounding 236 | 7178293884,10013411946,non-sounding 237 | 7387939334,10862172945,non-sounding 238 | 7560517176,12328837165,non-sounding 239 | 7623513024,9761803312,non-sounding 240 | 
7733838448,17140316339,non-sounding 241 | 7740990330,11002352834,non-sounding 242 | 7897462346,10061269855,non-sounding 243 | 8250285374,14172932360,non-sounding 244 | 8294094053,5801923729,non-sounding 245 | 8311669586,11314266424,non-sounding 246 | 8396975855,6755812421,non-sounding 247 | 9060997789,12599965145,non-sounding 248 | 9309351423,3230152596,non-sounding 249 | 9456627660,4965189170,non-sounding 250 | 9636000842,3403348616,non-sounding 251 | 9761803312,3109408703,non-sounding -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Audio-Grounded Contrastive Learning (WACV’24) 2 | 3 | Official PyTorch implementation of our paper: 4 | 5 | > **[Can CLIP Help Sound Source Localization?](https://arxiv.org/abs/2311.04066)** 6 | > 7 | > [Sooyoung Park*](https://sites.google.com/view/sooyoungpark), [Arda Senocak*](https://ardasnck.github.io/), [Joon Son Chung](https://mmai.io/joon/) (* Equal Contribution) 8 | > 9 | > WACV 2024 10 | 11 | 12 | ## Introduction 13 | 14 | ![ACL summary](asset/summary_wacv.png) 15 | 16 | This repo is the official PyTorch implementation of Audio-Grounded Contrastive Learning (ACL). The code is kept simple and easy to follow. 17 | 18 | Some of these codes are based on [AudioToken](https://github.com/guyyariv/AudioToken), [BEATs](https://github.com/microsoft/unilm/tree/master/beats), [TCL](https://github.com/kakaobrain/tcl). 19 | 20 | Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/swimmiing/ACL-SSL-zeroshot-demo) 21 | 22 | ## Required packages 23 | 24 | - Python = 3.10.8 25 | - Pytorch = 1.13.0 26 | - transformers = 4.25.1 27 | 28 | ### Installation 29 | 30 | ```bash 31 | $ conda install -c nvidia cudatoolkit=11.7 32 | $ conda install -c conda-forge cudnn 33 | $ conda install python=3.10 34 | $ pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 35 | $ pip install tensorboard 36 | $ pip install transformers==4.25.1 37 | $ pip install opencv-python 38 | $ pip install tqdm 39 | $ pip install scikit-learn 40 | 41 | ``` 42 | 43 | ## Data preparation 44 | 45 | **Important Note:** All audio samples must be converted to 16kHz; for detailed instructions, refer to the readme in each dataset-specific directory (a minimal resampling sketch is shown after the dataset list below). 46 | 47 | - **Dataset** 48 | - VGG-Sound: [[Link]](https://www.robots.ox.ac.uk/~vgg/data/vggsound/) 49 | - VGG-SS: [[Link]](https://www.robots.ox.ac.uk/~vgg/research/lvs/) 50 | - Flickr: [[Link]](https://github.com/ardasnck/learning_to_localize_sound_source) 51 | - AVSBench: [[Link]](http://www.avlbench.opennlplab.cn/dataset/avsbench) 52 | - Extended VGG-SS/Flickr: [[Link]](https://github.com/stoneMo/SLAVC) 53 | 
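The following is a minimal, non-authoritative resampling sketch for the 16kHz conversion required above, using `torchaudio` (already part of this repo's requirements). The `SRC_DIR`/`DST_DIR` paths and the flat `.wav` layout are placeholder assumptions; follow the dataset-specific readmes for the exact directory structure expected by the dataloaders.

```python
import os

import torchaudio
import torchaudio.functional as AF

SRC_DIR = '/path/to/downloaded_audio'  # placeholder: original .wav files
DST_DIR = '/path/to/audio'             # placeholder: 16kHz output directory
TARGET_SR = 16000

os.makedirs(DST_DIR, exist_ok=True)
for fn in os.listdir(SRC_DIR):
    if not fn.endswith('.wav'):
        continue
    wav, sr = torchaudio.load(os.path.join(SRC_DIR, fn))
    if sr != TARGET_SR:
        # Resample to the 16kHz rate assumed by the dataset classes
        wav = AF.resample(wav, orig_freq=sr, new_freq=TARGET_SR)
    torchaudio.save(os.path.join(DST_DIR, fn), wav, TARGET_SR)
```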
54 | ## Model preparation 55 | 56 | Download the pretrained model (audio backbone) into the pretrain folder: 57 | - BEATs: https://github.com/microsoft/unilm/tree/master/beats 58 | - BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt 59 | 60 | 61 | ## Training 62 | 63 | - Ensure that you check the .sh files and set `$ export CUDA_VISIBLE_DEVICES="**"` according to your hardware setup. 64 | - Make sure that `--model_name` corresponds to the configuration file located at `./config/model/{--model_name}.yaml`. 65 | - Model files (.pth) will be saved in the directory `{--save_path}/Train_record/{--model_name}_{--exp_name}/`. 66 | - Review the configuration settings in `./config/train/{--train_config}.yaml` to ensure they match your training requirements. 67 | - Choose one of the following methods to initiate training: 68 | 69 | ```bash 70 | $ sh SingleGPU_Experiment.sh # For single GPU setup 71 | $ sh Distributed_Experiment.sh # For multi-GPU setup (DDP) 72 | ``` 73 | 74 | ## Test 75 | 76 | - Before testing, please review the .sh file and set the `$ export CUDA_VISIBLE_DEVICES="**"` environment variable according to your hardware configuration. 77 | - Ensure that the `--model_name` parameter corresponds to the configuration file located at `./config/model/{--model_name}.yaml`. 78 | - Model files (.pth) located at `{--save_path}/Train_record/{--model_name}_{--exp_name}/Param_{--epochs}.pth` will be used for testing. 79 | - The `--epochs` parameter can accept either an integer or a list of integers (e.g., 1, 2, 3). 80 | - If `--epochs` is left unspecified (null), the default model file `{--save_path}/Train_record/{--model_name}_{--exp_name}/Param_best.pth` will be used for testing. 81 | 82 | ```bash 83 | $ sh Test_PTModels.sh 84 | ``` 85 | 86 | ## Pretrained models 87 | 88 | **Important Note:** After downloading the Param_best.pth file, move it to the directory `{--save_path}/Train_record/{--model_name}_{--exp_name}/` before use. 89 | 90 | - VGG-Sound 144k trained model: [[Link]](https://drive.google.com/file/d/1XnVrBES3IKjAcV0uCkvbIdEEclcOYJoR/view?usp=drive_link) 91 | - This model was trained using a 2-GPU setup. 92 | - Performance varies across seeds; the reported numbers are the best results, and the linked .pth file is the checkpoint that produced them. 93 | 94 | ## Citation 95 | 96 | If you use this project, please cite it as: 97 | 98 | ```latex 99 | @article{park2023clip, 100 | title={Can CLIP Help Sound Source Localization?}, 101 | author={Sooyoung Park and Arda Senocak and Joon Son Chung}, 102 | journal = {arXiv preprint arXiv:2311.04066}, 103 | year={2023}, 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /SingleGPU_Experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES="0" 4 | 5 | python Train_ACL.py \ 6 | --model_name ACL_ViT16 \ 7 | --exp_name aclifa_1gpu \ 8 | --train_config Exp_ACL_v1 \ 9 | --vggss_path {put dataset directory} \ 10 | --flickr_path {put dataset directory} \ 11 | --avs_path {put dataset directory} \ 12 | --save_path {put logging directory} 13 | -------------------------------------------------------------------------------- /Test_PTModels.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from VGGSS.VGGSS_Dataset import VGGSSDataset, ExtendVGGSSDataset 4 | from Flickr.Flickr_Dataset import FlickrDataset, ExtendFlickrDataset 5 | from AVSBench.AVSBench_Dataset import AVSBenchDataset 6 | 7 | from Eval import eval_vggss_agg, eval_flickr_agg, eval_avsbench_agg, eval_exvggss_agg, eval_exflickr_agg 8 | 9 | from modules.models import * 10 | from modules.arg_utils import int_or_int_list_or_none 11 | from typing import Union, List, Any 12 | 13 | 14 | @torch.no_grad() 15 | def main( 16 | model_name: str, 17 | exp_name: str, 18 | epochs: Union[int, List[Union[int, None]]], 19 | data_path_dict: dict, 20 | save_path: str) -> None: 21 | """ 22 | Main function for evaluating the sound source localization model. 
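    Loads each requested checkpoint and evaluates it on the VGG-SS, Flickr, AVSBench (S4 and MS3), and extended (plus-silent) VGG-SS/Flickr test sets, writing results under `save_path`.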
23 | 24 | Args: 25 | model_name (str): The name of the model, corresponding to the model config file in './config/model'. 26 | exp_name (str): The postfix for saving the experiment. 27 | epochs (Union[int, List[Union[int, None]]]): List of epochs to evaluate. 28 | data_path_dict (dict): The directory for dataset. 29 | save_path (str): The directory for saving evaluation results. 30 | """ 31 | 32 | USE_CUDA = torch.cuda.is_available() 33 | device = torch.device('cuda:0' if USE_CUDA else 'cpu') 34 | 35 | model_exp_name = f'{model_name}_{exp_name}' if exp_name != "" else model_name 36 | 37 | print(f"Exp_name: {model_exp_name}") 38 | 39 | for epoch in epochs: 40 | # Get model 41 | model_conf_file = f'./config/model/{model_name}.yaml' 42 | model = ACL(model_conf_file, device) 43 | model.train(False) 44 | 45 | # Load model 46 | postfix = str(epoch) if epoch is not None else 'best' 47 | model_dir = os.path.join(save_path, 'Train_record', model_exp_name, f'Param_{postfix}.pth') 48 | model.load(model_dir) 49 | 50 | # Set directory 51 | viz_dir_template = os.path.join(save_path, 'Visual_results', '{}', model_exp_name, f'epoch{postfix}') 52 | tensorboard_path = os.path.join(save_path, 'Train_record', model_exp_name) 53 | 54 | # Get dataloader 55 | exvggss_dataset = ExtendVGGSSDataset(data_path_dict['vggss'], input_resolution=352) 56 | exvggss_dataloader = torch.utils.data.DataLoader(exvggss_dataset, batch_size=1, shuffle=False, num_workers=1, 57 | pin_memory=True, drop_last=False) 58 | 59 | exflickr_dataset = ExtendFlickrDataset(data_path_dict['flickr'], input_resolution=352) 60 | exflickr_dataloader = torch.utils.data.DataLoader(exflickr_dataset, batch_size=1, shuffle=False, num_workers=1, 61 | pin_memory=True, drop_last=False) 62 | 63 | flickr_dataset = FlickrDataset(data_path_dict['flickr'], 'flickr_test', is_train=False, input_resolution=352) 64 | flickr_dataloader = torch.utils.data.DataLoader(flickr_dataset, batch_size=1, shuffle=False, num_workers=1, 65 | pin_memory=True, drop_last=False) 66 | 67 | vggss_dataset = VGGSSDataset(data_path_dict['vggss'], 'vggss_test', is_train=False, input_resolution=352) 68 | vggss_dataloader = torch.utils.data.DataLoader(vggss_dataset, batch_size=1, shuffle=False, num_workers=1, 69 | pin_memory=True, drop_last=False) 70 | 71 | avss4_dataset = AVSBenchDataset(data_path_dict['avs'], 'avs1_s4_test', is_train=False, input_resolution=352) 72 | avss4_dataloader = torch.utils.data.DataLoader(avss4_dataset, batch_size=5, shuffle=False, num_workers=1, 73 | pin_memory=True, drop_last=False) 74 | 75 | avsms3_dataset = AVSBenchDataset(data_path_dict['avs'], 'avs1_ms3_test', is_train=False, input_resolution=352) 76 | avsms3_dataloader = torch.utils.data.DataLoader(avsms3_dataset, batch_size=5, shuffle=False, num_workers=1, 77 | pin_memory=True, drop_last=False) 78 | 79 | # Evaluate 80 | eval_exflickr_agg(model, exflickr_dataloader, viz_dir_template.format('exflickr')) 81 | eval_exvggss_agg(model, exvggss_dataloader, viz_dir_template.format('exvggss')) 82 | eval_flickr_agg(model, flickr_dataloader, viz_dir_template.format('flickr'), tensorboard_path=tensorboard_path) 83 | eval_vggss_agg(model, vggss_dataloader, viz_dir_template.format('vggss'), tensorboard_path=tensorboard_path) 84 | eval_avsbench_agg(model, avss4_dataloader, viz_dir_template.format('s4'), tensorboard_path=tensorboard_path) 85 | eval_avsbench_agg(model, avsms3_dataloader, viz_dir_template.format('ms3'), tensorboard_path=tensorboard_path) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = 
argparse.ArgumentParser() 90 | parser.add_argument('--model_name', type=str, default='ACL_ViT16', help='Use model config file name') 91 | parser.add_argument('--exp_name', type=str, default='aclifa_2gpu', help='postfix for save experiment') 92 | parser.add_argument('--epochs', type=int_or_int_list_or_none, default=[None], help='epochs ([None] for released)') 93 | parser.add_argument('--vggss_path', type=str, default='', help='VGGSS dataset directory') 94 | parser.add_argument('--flickr_path', type=str, default='', help='Flickr dataset directory') 95 | parser.add_argument('--avs_path', type=str, default='', help='AVSBench dataset directory') 96 | parser.add_argument('--save_path', type=str, default='', help='Checkpoints directory') 97 | args = parser.parse_args() 98 | 99 | data_dict = {'vggss': args.vggss_path, 100 | 'flickr': args.flickr_path, 101 | 'avs': args.avs_path} 102 | 103 | # Run example 104 | main(args.model_name, args.exp_name, args.epochs, data_dict, args.save_path) 105 | -------------------------------------------------------------------------------- /Test_PTModels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES="0" 4 | 5 | python Test_PTModels.py \ 6 | --model_name ACL_ViT16 \ 7 | --exp_name aclifa_2gpu \ 8 | --vggss_path {put dataset directory} \ 9 | --flickr_path {put dataset directory} \ 10 | --avs_path {put dataset directory} \ 11 | --save_path {put dataset directory} \ 12 | --epochs None 13 | -------------------------------------------------------------------------------- /VGGSS/README.md: -------------------------------------------------------------------------------- 1 | # Directory guide for VGGSound 2 | ```commandline 3 | ├── VGGSound/ 4 | | ├── audio 5 | | | └── ... 6 | | | └── *** .wav 7 | | ├── frames 8 | | | └── ... 9 | | | └── *** .jpg 10 | | ├── extend_audio 11 | | | └── ... 12 | | | └── *** .wav 13 | | ├── extend_frames 14 | | | └── ... 15 | | | └── *** .jpg 16 | ``` 17 | All .wav files sampled 16k 18 | 19 | ## Important 20 | Official annotations (bounding box) are based on the 125th frame of a 25fps video for each file, not the exact center frame. -------------------------------------------------------------------------------- /VGGSS/VGGSS_Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | import numpy as np 5 | import torchaudio 6 | from torchvision import transforms as vt 7 | from PIL import Image 8 | import os 9 | import csv 10 | import json 11 | from typing import Dict, List, Optional, Union 12 | 13 | 14 | def load_all_bboxes(annotation_dir: str) -> Dict[str, List[np.ndarray]]: 15 | """ 16 | Load all bounding boxes from json annotation. 17 | 18 | Args: 19 | annotation_dir (str): json annotation file path. 20 | 21 | Returns: 22 | Dict[str, List[np.ndarray]]: Dictionary containing bounding boxes for each file. 23 | """ 24 | gt_bboxes = {} 25 | with open(annotation_dir) as json_file: 26 | annotations = json.load(json_file) 27 | for annotation in annotations: 28 | bboxes = [(np.clip(np.array(bbox), 0, 1) * 224).astype(int) for bbox in annotation['bbox']] 29 | gt_bboxes[annotation['file']] = bboxes 30 | 31 | return gt_bboxes 32 | 33 | 34 | def bbox2gtmap(bboxes: List[List[int]]) -> np.ndarray: 35 | """ 36 | Convert bounding boxes to numpy ground truth map. 37 | 38 | Args: 39 | bboxes (List[List[int]]): List of bounding boxes. 
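            Each box is given as [xmin, ymin, xmax, ymax] in 224x224 pixel coordinates, as produced by `load_all_bboxes`.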
40 | 41 | Returns: 42 | np.ndarray: Ground truth map. 43 | """ 44 | gt_map = np.zeros([224, 224]) 45 | for xmin, ymin, xmax, ymax in bboxes: 46 | temp = np.zeros([224, 224]) 47 | temp[ymin:ymax, xmin:xmax] = 1 48 | gt_map += temp 49 | gt_map[gt_map > 1] = 1 50 | 51 | return gt_map 52 | 53 | 54 | class VGGSSDataset(Dataset): 55 | def __init__(self, data_path: str, split: str, is_train: bool = True, set_length: int = 10, 56 | input_resolution: int = 224, hard_aug: bool = False): 57 | """ 58 | Initialize VGG-Sound Dataset for VGG-SS. 59 | 60 | Args: 61 | data_path (str): Path to the dataset. 62 | split (str): Dataset split (Use csv file name in metadata directory). 63 | is_train (bool, optional): Whether it's a training set. Default is True. 64 | set_length (int, optional): Duration of input audio. Default is 10. 65 | input_resolution (int, optional): Resolution of input images. Default is 224. 66 | hard_aug (bool, optional): Not used. 67 | """ 68 | super(VGGSSDataset, self).__init__() 69 | 70 | self.SAMPLE_RATE = 16000 71 | self.split = split 72 | self.set_length = set_length 73 | self.csv_dir = 'VGGSS/metadata/' + split + '.csv' 74 | 75 | ''' Audio files ''' 76 | self.audio_path = os.path.join(data_path, 'audio') 77 | audio_files = set([fn.split('.wav')[0] for fn in os.listdir(self.audio_path) if fn.endswith('.wav')]) 78 | 79 | ''' Image files ''' 80 | self.image_path = os.path.join(data_path, 'frames') 81 | image_files = set([fn.split('.jpg')[0] for fn in os.listdir(self.image_path) if fn.endswith('.jpg')]) 82 | 83 | ''' Ground truth (Bounding box) ''' 84 | gt_path = f'VGGSS/metadata/vggss.json' 85 | if is_train: 86 | self.bbox_dict = None 87 | else: 88 | self.bbox_dict = load_all_bboxes(gt_path) 89 | 90 | ''' Ground truth (Text label) ''' 91 | self.label_dict = {item[0]: item[1] for item in csv.reader(open(self.csv_dir))} 92 | 93 | ''' Available files''' 94 | subset = set([item[0] for item in csv.reader(open(self.csv_dir))]) 95 | self.file_list = list(audio_files.intersection(image_files).intersection(subset)) 96 | self.file_list = sorted(self.file_list) 97 | 98 | ''' Transform ''' 99 | if is_train: 100 | self.image_transform = vt.Compose([ 101 | vt.Resize((int(input_resolution * 1.1), int(input_resolution * 1.1)), vt.InterpolationMode.BICUBIC), 102 | vt.ToTensor(), 103 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 104 | vt.RandomCrop((input_resolution, input_resolution)), 105 | vt.RandomHorizontalFlip(), 106 | ]) 107 | if hard_aug: 108 | self.image_transform = vt.Compose([ 109 | vt.RandomResizedCrop((input_resolution, input_resolution)), 110 | vt.RandomApply([vt.GaussianBlur(5, [.1, 2.])], p=0.8), 111 | vt.RandomApply([vt.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 112 | vt.RandomGrayscale(p=0.2), 113 | vt.ToTensor(), 114 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 115 | vt.RandomHorizontalFlip(), 116 | ]) 117 | else: 118 | self.image_transform = vt.Compose([ 119 | vt.Resize((input_resolution, input_resolution), vt.InterpolationMode.BICUBIC), 120 | vt.ToTensor(), 121 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 122 | ]) 123 | 124 | self.is_train = is_train 125 | self.use_image = True 126 | if input_resolution is None: 127 | self.use_image = False 128 | 129 | def __len__(self): 130 | """ 131 | Return the number of items in the dataset. 
132 | """ 133 | return len(self.file_list) 134 | 135 | def get_audio(self, item: int) -> torch.Tensor: 136 | """ 137 | Get audio data for a given item. 138 | 139 | Args: 140 | item (int): Index of the item. 141 | 142 | Returns: 143 | torch.Tensor: Audio data. 144 | """ 145 | audio_file, _ = torchaudio.load(os.path.join(self.audio_path, self.file_list[item] + '.wav')) 146 | audio_file = audio_file.squeeze(0) 147 | 148 | # slicing or padding based on set_length 149 | # slicing 150 | if audio_file.shape[0] > (self.SAMPLE_RATE * self.set_length): 151 | audio_file = audio_file[:self.SAMPLE_RATE * self.set_length] 152 | # zero padding 153 | if audio_file.shape[0] < (self.SAMPLE_RATE * self.set_length): 154 | pad_len = (self.SAMPLE_RATE * self.set_length) - audio_file.shape[0] 155 | pad_val = torch.zeros(pad_len) 156 | audio_file = torch.cat((audio_file, pad_val), dim=0) 157 | 158 | return audio_file 159 | 160 | def get_image(self, item: int) -> Image.Image: 161 | """ 162 | Get image data for a given item. 163 | 164 | Args: 165 | item (int): Index of the item. 166 | 167 | Returns: 168 | Image.Image: Image data. 169 | """ 170 | image_file = Image.open(os.path.join(self.image_path, self.file_list[item] + '.jpg')) 171 | return image_file 172 | 173 | def get_bbox(self, item: int) -> Optional[torch.Tensor]: 174 | """ 175 | Get ground truth data for a given item. 176 | 177 | Args: 178 | item (int): Index of the item. 179 | 180 | Returns: 181 | Optional[torch.Tensor]: Ground truth data. 182 | """ 183 | # Bounding Box 184 | if self.bbox_dict is None: 185 | return None 186 | else: 187 | bbox = self.bbox_dict[self.file_list[item]] 188 | return vt.ToTensor()(bbox2gtmap(bbox)) 189 | 190 | def __getitem__(self, item: int) -> Dict[str, Union[torch.Tensor, torch.Tensor, Optional[torch.Tensor], str, str]]: 191 | """ 192 | Get item from the dataset. 193 | 194 | Args: 195 | item (int): Index of the item. 196 | 197 | Returns: 198 | Dict[str, Union[torch.Tensor, torch.Tensor, Optinal[torch.Tensor], str, str]]: Data example 199 | """ 200 | file_id = self.file_list[item] 201 | 202 | ''' Load data ''' 203 | audio_file = self.get_audio(item) if self.set_length != 0 else None 204 | image_file = self.get_image(item) if self.use_image else None 205 | label = self.label_dict[self.file_list[item]].replace('_', ' ') 206 | bboxes = self.get_bbox(item) if self.set_length != 0 and self.use_image else None 207 | 208 | ''' Transform ''' 209 | audio = audio_file if self.set_length != 0 else None 210 | image = self.image_transform(image_file) if self.use_image else None 211 | 212 | out = {'images': image, 'audios': audio, 'bboxes': bboxes, 'labels': label, 'ids': file_id} 213 | out = {key: value for key, value in out.items() if value is not None} 214 | return out 215 | 216 | 217 | class ExtendVGGSSDataset(Dataset): 218 | def __init__(self, data_path: str, set_length: int = 10, input_resolution: int = 224): 219 | """ 220 | Initialize Extended VGG-SS dataset. 221 | 222 | Args: 223 | data_path (str): Path to the dataset. 224 | set_length (int, optional): Duration of input audio. Default is 10. 225 | input_resolution (int, optional): Resolution of input images. Default is 224. 
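        Note: Besides the sounding VGG-SS test pairs, non-sounding image/audio pairs from vggss_test_plus_silent.csv are appended with empty bounding boxes and the label 'non-sounding'.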
226 | """ 227 | super(ExtendVGGSSDataset, self).__init__() 228 | 229 | self.SAMPLE_RATE = 16000 230 | self.set_length = set_length 231 | self.csv_dir = 'VGGSS/metadata/vggss_test.csv' 232 | self.extend_csv_dir = 'VGGSS/metadata/vggss_test_plus_silent.csv' 233 | self.split = 'exvggss' 234 | 235 | ''' Audio files ''' 236 | self.audio_path = os.path.join(data_path, 'extend_audio') 237 | audio_files = set([fn.split('.wav')[0] for fn in os.listdir(self.audio_path) if fn.endswith('.wav')]) 238 | 239 | ''' Image files ''' 240 | self.image_path = os.path.join(data_path, 'extend_frames') 241 | image_files = set([fn.split('.jpg')[0] for fn in os.listdir(self.image_path) if fn.endswith('.jpg')]) 242 | 243 | ''' Ground truth (Bounding box) ''' 244 | gt_path = f'VGGSS/metadata/vggss.json' 245 | self.bbox_dict = load_all_bboxes(gt_path) 246 | 247 | ''' Ground truth (Text label) ''' 248 | self.label_dict = {item[0]: item[1] for item in csv.reader(open(self.csv_dir))} 249 | 250 | ''' Available files''' 251 | subset = set([item[0] for item in csv.reader(open(self.csv_dir))]) 252 | file_list = sorted(list(audio_files.intersection(image_files).intersection(subset))) 253 | self.image_files = [dt + '.jpg' for dt in file_list] 254 | self.audio_files = [dt + '.wav' for dt in file_list] 255 | self.bboxes = [self.bbox_dict[dt] for dt in file_list] 256 | self.labels = [self.label_dict[dt] for dt in file_list] 257 | 258 | ''' Add non-sounding files ''' 259 | for item in csv.reader(open(self.extend_csv_dir)): 260 | if item[2] == 'non-sounding': 261 | self.image_files.append(f'{item[0]}.jpg') 262 | self.audio_files.append(f'{item[1]}.wav') 263 | self.bboxes.append([]) 264 | self.labels.append('non-sounding') 265 | 266 | ''' Transform ''' 267 | self.image_transform = vt.Compose([ 268 | vt.Resize((input_resolution, input_resolution), vt.InterpolationMode.BICUBIC), 269 | vt.ToTensor(), 270 | vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP 271 | ]) 272 | 273 | def __len__(self): 274 | """ 275 | Return the number of items in the dataset. 276 | """ 277 | return len(self.image_files) 278 | 279 | def get_audio(self, item: int) -> torch.Tensor: 280 | """ 281 | Get audio data for a given item. 282 | 283 | Args: 284 | item (int): Index of the item. 285 | 286 | Returns: 287 | torch.Tensor: Audio data. 288 | """ 289 | audio_file, _ = torchaudio.load(os.path.join(self.audio_path, self.audio_files[item])) 290 | audio_file = audio_file.squeeze(0) 291 | 292 | # slicing or padding based on set_length 293 | # slicing 294 | if audio_file.shape[0] > (self.SAMPLE_RATE * self.set_length): 295 | audio_file = audio_file[:self.SAMPLE_RATE * self.set_length] 296 | # zero padding 297 | if audio_file.shape[0] < (self.SAMPLE_RATE * self.set_length): 298 | pad_len = (self.SAMPLE_RATE * self.set_length) - audio_file.shape[0] 299 | pad_val = torch.zeros(pad_len) 300 | audio_file = torch.cat((audio_file, pad_val), dim=0) 301 | 302 | return audio_file 303 | 304 | def get_image(self, item: int) -> Image.Image: 305 | """ 306 | Get image data for a given item. 307 | 308 | Args: 309 | item (int): Index of the item. 310 | 311 | Returns: 312 | Image.Image: Image data. 313 | """ 314 | image_file = Image.open(os.path.join(self.image_path, self.image_files[item])) 315 | return image_file 316 | 317 | def get_bbox(self, item: int) -> Optional[torch.Tensor]: 318 | """ 319 | Get ground truth data for a given item. 320 | 321 | Args: 322 | item (int): Index of the item. 
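        Note: For non-sounding samples (empty bounding-box list), an all-zero 224x224 map is returned.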
323 | 324 | Returns: 325 | Optional[torch.Tensor]: Ground truth data. 326 | """ 327 | # Bounding Box 328 | if len(self.bboxes[item]) == 0: 329 | return vt.ToTensor()(np.zeros([224, 224])) 330 | else: 331 | bbox = self.bboxes[item] 332 | return vt.ToTensor()(bbox2gtmap(bbox)) 333 | 334 | def __getitem__(self, item: int) -> Dict[str, Union[torch.Tensor, torch.Tensor, Optional[torch.Tensor], str, str]]: 335 | """ 336 | Get item from the dataset. 337 | 338 | Args: 339 | item (int): Index of the item. 340 | 341 | Returns: 342 | Dict[str, Union[torch.Tensor, torch.Tensor, Optinal[torch.Tensor], str, str]]: Data example 343 | """ 344 | ''' Load data ''' 345 | audio_file = self.get_audio(item) 346 | image_file = self.get_image(item) 347 | bboxes = self.get_bbox(item) 348 | label = self.labels[item] 349 | file_id = self.image_files[item].split('.')[0] + '_' + self.audio_files[item].split('.')[0] 350 | 351 | ''' Transform ''' 352 | audio = audio_file 353 | image = self.image_transform(image_file) 354 | 355 | out = {'images': image, 'audios': audio, 'bboxes': bboxes, 'labels': label, 'ids': file_id} 356 | out = {key: value for key, value in out.items() if value is not None} 357 | return out 358 | -------------------------------------------------------------------------------- /VGGSS/eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn import metrics 4 | from typing import List, Optional, Tuple, Dict 5 | 6 | 7 | class Evaluator(object): 8 | def __init__(self) -> None: 9 | """ 10 | Initialize the VGG-Sound Source (VGG-SS) Evaluator. 11 | 12 | Attributes: 13 | ciou (List[float]): Buffer of cIoU values. 14 | AUC (List[float]): Buffer of AUC values. 15 | N (int): Counter for the number of evaluations. 16 | metrics (List[str]): List of metric names. 17 | """ 18 | super(Evaluator, self).__init__() 19 | self.ciou = [] 20 | self.AUC = [] 21 | self.N = 0 22 | self.metrics = ['cIoU', 'AUC'] 23 | 24 | def evaluate_batch(self, pred: torch.Tensor, target: torch.Tensor, thr: Optional[float] = None) -> None: 25 | """ 26 | Evaluate a batch of predictions against ground truth. 27 | 28 | Args: 29 | pred (torch.Tensor): Model predictions. 30 | target (torch.Tensor): Ground truth maps. 31 | thr (Optional[float]): Threshold for binary classification. If None, dynamically determined. 32 | 33 | Returns: 34 | None 35 | """ 36 | for j in range(pred.size(0)): 37 | infer = pred[j] 38 | gt = target[j] 39 | if thr is None: 40 | thr = np.sort(infer.detach().cpu().numpy().flatten())[int(infer.shape[1] * infer.shape[2] / 2)] 41 | self.cal_CIOU(infer, gt, thr) 42 | 43 | def cal_CIOU(self, infer: torch.Tensor, gtmap: torch.Tensor, thres: float = 0.01) -> List[float]: 44 | """ 45 | Calculate cIoU (consensus Intersection over Union). 46 | 47 | Args: 48 | infer (torch.Tensor): Model prediction. 49 | gtmap (torch.Tensor): Ground truth map. 50 | thres (float): Threshold for binary classification. 51 | 52 | Returns: 53 | List[float]: List of cIoU values for each instance in the batch. 54 | """ 55 | infer_map = torch.zeros_like(gtmap) 56 | infer_map[infer >= thres] = 1 57 | ciou = (infer_map * gtmap).sum(2).sum(1) / (gtmap.sum(2).sum(1) + (infer_map * (gtmap == 0)).sum(2).sum(1)) 58 | 59 | for i in range(gtmap.size(0)): 60 | self.ciou.append(ciou[i].detach().cpu()) 61 | return ciou 62 | 63 | def finalize_AUC(self) -> float: 64 | """ 65 | Calculate the Area Under the Curve (AUC). 66 | 67 | Returns: 68 | float: AUC value. 
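        Note: Computed as the area under the cIoU success-rate curve evaluated at thresholds 0.0, 0.05, ..., 1.0.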
69 | """ 70 | cious = [np.sum(np.array(self.ciou) >= 0.05 * i) / len(self.ciou) 71 | for i in range(21)] 72 | thr = [0.05 * i for i in range(21)] 73 | auc = metrics.auc(thr, cious) 74 | return auc 75 | 76 | def finalize_AP50(self) -> float: 77 | """ 78 | Calculate Average Precision (cIoU@0.5). 79 | 80 | Returns: 81 | float: cIoU@0.5 value. 82 | """ 83 | ap50 = np.mean(np.array(self.ciou) >= 0.5) 84 | return ap50 85 | 86 | def finalize_cIoU(self) -> float: 87 | """ 88 | Calculate mean cIoU. 89 | 90 | Returns: 91 | float: Mean cIoU value. 92 | """ 93 | ciou = np.mean(np.array(self.ciou)) 94 | return ciou 95 | 96 | def finalize(self) -> Tuple[List[str], Dict[str, float]]: 97 | """ 98 | Finalize and return evaluation metrics. 99 | 100 | Returns: 101 | Tuple[List[str], Dict[str, float]]: List of metric names and corresponding values. 102 | """ 103 | ap50 = self.finalize_AP50() * 100 104 | auc = self.finalize_AUC() * 100 105 | return self.metrics, {self.metrics[0]: ap50, self.metrics[1]: auc} 106 | 107 | -------------------------------------------------------------------------------- /VGGSS/extend_eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from sklearn import metrics 5 | 6 | import util 7 | 8 | 9 | class Evaluator(object): 10 | def __init__(self, iou_thrs=(0.5, ), default_conf_thr=0.5, pred_size=0.5, pred_thr=0.5, 11 | results_dir='./results'): 12 | """ 13 | Initialize the Extended Flickr evaluator. 14 | 15 | Notes: 16 | Taking computation speed into consideration, it is set to output only the 'all' subset. (AP, Max-F1) 17 | """ 18 | super(Evaluator, self).__init__() 19 | self.iou_thrs = iou_thrs 20 | self.default_conf_thr = default_conf_thr 21 | self.min_sizes = {'small': 0, 'medium': 32 ** 2, 'large': 96 ** 2, 'huge': 144 ** 2} 22 | self.max_sizes = {'small': 32 ** 2, 'medium': 96 ** 2, 'large': 144 ** 2, 'huge': 10000 ** 2} 23 | 24 | self.ciou_list = [] 25 | self.area_list = [] 26 | self.confidence_list = [] 27 | self.name_list = [] 28 | self.bb_list = [] 29 | # self.metrics = ['AP', 'Max-F1', 'LocAcc'] 30 | self.metrics = ['AP', 'Max-F1'] 31 | 32 | self.results_dir = results_dir 33 | self.viz_save_dir = f"{results_dir}/viz_conf" + str(default_conf_thr) + "_predsize" + str( 34 | pred_size) + "_predthr" + str(pred_thr) 35 | self.results_save_dir = f"{results_dir}/results_conf" + str(default_conf_thr) + "_predsize" + str( 36 | pred_size) + "_predthr" + str(pred_thr) 37 | 38 | @staticmethod 39 | def calc_precision_recall(bb_list, ciou_list, confidence_list, confidence_thr, ciou_thr=0.5): 40 | assert len(bb_list) == len(ciou_list) == len(confidence_list) 41 | true_pos, false_pos, false_neg = 0, 0, 0 42 | for bb, ciou, confidence in zip(bb_list, ciou_list, confidence_list): 43 | if bb == 0: 44 | # no sounding objects in frame 45 | if confidence >= confidence_thr: 46 | # sounding object detected 47 | false_pos += 1 48 | else: 49 | # sounding objects in frame 50 | if confidence >= confidence_thr: 51 | # sounding object detected... 52 | if ciou >= ciou_thr: # ...in correct place 53 | true_pos += 1 54 | else: # ...in wrong place 55 | false_pos += 1 56 | else: 57 | # no sounding objects detected 58 | false_neg += 1 59 | 60 | precision = 1. if true_pos + false_pos == 0 else true_pos / (true_pos + false_pos) 61 | recall = 1. 
if true_pos + false_neg == 0 else true_pos / (true_pos + false_neg) 62 | 63 | return precision, recall 64 | 65 | def calc_ap(self, bb_list_full, ciou_list_full, confidence_list_full, iou_thr=0.5): 66 | 67 | assert len(bb_list_full) == len(ciou_list_full) == len(confidence_list_full) 68 | 69 | # for visible objects 70 | # ss = [i for i, bb in enumerate(bb_list_full) if bb > 0] 71 | # bb_list = [bb_list_full[i] for i in ss] 72 | # ciou_list = [ciou_list_full[i] for i in ss] 73 | # confidence_list = [confidence_list_full[i] for i in ss] 74 | 75 | precision, recall, skip_thr = [], [], max(1, len(ciou_list_full) // 200) 76 | for thr in np.sort(np.array(confidence_list_full))[:-1][::-skip_thr]: 77 | p, r = self.calc_precision_recall(bb_list_full, ciou_list_full, confidence_list_full, thr, iou_thr) 78 | precision.append(p) 79 | recall.append(r) 80 | precision_max = [np.max(precision[i:]) for i in range(len(precision))] 81 | ap = sum([precision_max[i] * (recall[i + 1] - recall[i]) 82 | for i in range(len(precision_max) - 1)]) 83 | return ap 84 | 85 | def cal_auc(self, bb_list, ciou_list): 86 | ss = [i for i, bb in enumerate(bb_list) if bb > 0] 87 | ciou = [ciou_list[i] for i in ss] 88 | cious = [np.sum(np.array(ciou) >= 0.05 * i) / len(ciou) 89 | for i in range(21)] 90 | thr = [0.05 * i for i in range(21)] 91 | auc = metrics.auc(thr, cious) 92 | return auc 93 | 94 | def filter_subset(self, subset, name_list, area_list, bb_list, ciou_list, conf_list): 95 | if subset == 'visible': 96 | ss = [i for i, bb in enumerate(bb_list) if bb > 0] 97 | elif subset == 'non-visible/non-audible': 98 | ss = [i for i, bb in enumerate(bb_list) if bb == 0] 99 | elif subset == 'all': 100 | ss = [i for i, bb in enumerate(bb_list) if bb >= 0] 101 | else: 102 | ss = [i for i, sz in enumerate(area_list) 103 | if self.min_sizes[subset] <= sz < self.max_sizes[subset] and bb_list[i] > 0] 104 | 105 | if len(ss) == 0: 106 | return [], [], [], [], [] 107 | 108 | name = [name_list[i] for i in ss] 109 | area = [area_list[i] for i in ss] 110 | bbox = [bb_list[i] for i in ss] 111 | ciou = [ciou_list[i] for i in ss] 112 | conf = [conf_list[i] for i in ss] 113 | 114 | return name, area, bbox, ciou, conf 115 | 116 | def finalize_stats(self): 117 | name_full_list, area_full_list, bb_full_list, ciou_full_list, confidence_full_list = self.gather_results() 118 | 119 | metrics = {} 120 | for iou_thr in self.iou_thrs: 121 | # for subset in ['all', 'visible']: 122 | for subset in ['all']: 123 | _, _, bb_list, ciou_list, conf_list = self.filter_subset(subset, name_full_list, area_full_list, 124 | bb_full_list, ciou_full_list, 125 | confidence_full_list) 126 | subset_name = f'{subset}@{int(iou_thr * 100)}' if subset is not None else f'@{int(iou_thr * 100)}' 127 | if len(ciou_list) == 0: 128 | p, r, ap, f1, auc = np.nan, np.nan, np.nan, np.nan, np.nan 129 | else: 130 | p, r = self.calc_precision_recall(bb_list, ciou_list, conf_list, -1000, iou_thr) 131 | ap = self.calc_ap(bb_list, ciou_list, conf_list, iou_thr) 132 | auc = self.cal_auc(bb_list, ciou_list) 133 | 134 | conf_thr = list(sorted(conf_list))[::max(1, len(conf_list) // 10)] 135 | pr = [self.calc_precision_recall(bb_list, ciou_list, conf_list, thr, iou_thr) for thr in conf_thr] 136 | f1 = [2 * r * p / (r + p) if r + p > 0 else 0. 
for p, r in pr] 137 | if subset == 'all' and iou_thr == 0.5: 138 | ef1 = max(f1) 139 | eap = ap 140 | metrics['ef1'] = ef1 141 | metrics['eap'] = eap 142 | if subset == 'visible' and iou_thr == 0.5: 143 | eloc = self.precision_at_50() 144 | eauc = auc 145 | metrics['eloc'] = eloc 146 | metrics['eauc'] = eauc 147 | metrics[f'Precision-{subset_name}'] = p 148 | # metrics[f'Recall-{subset_name}'] = r 149 | if np.isnan(f1).any(): 150 | metrics[f'F1-{subset_name}'] = f1 151 | else: 152 | metrics[f'F1-{subset_name}'] = ' '.join([f'{f * 100:.1f}' for f in f1]) 153 | metrics[f'AP-{subset_name}'] = ap 154 | metrics[f'AUC-{subset_name}'] = auc 155 | 156 | return metrics 157 | 158 | def gather_results(self): 159 | # import torch.distributed as dist 160 | # if not dist.is_initialized(): 161 | return self.name_list, self.area_list, self.bb_list, self.ciou_list, self.confidence_list 162 | # world_size = dist.get_world_size() 163 | # 164 | # bb_list = [None for _ in range(world_size)] 165 | # dist.all_gather_object(bb_list, self.bb_list) 166 | # bb_list = [x for bb in bb_list for x in bb] 167 | # 168 | # area_list = [None for _ in range(world_size)] 169 | # dist.all_gather_object(area_list, self.area_list) 170 | # area_list = [x for area in area_list for x in area] 171 | # 172 | # ciou_list = [None for _ in range(world_size)] 173 | # dist.all_gather_object(ciou_list, self.ciou_list) 174 | # ciou_list = [x for ciou in ciou_list for x in ciou] 175 | # 176 | # confidence_list = [None for _ in range(world_size)] 177 | # dist.all_gather_object(confidence_list, self.confidence_list) 178 | # confidence_list = [x for conf in confidence_list for x in conf] 179 | # 180 | # name_list = [None for _ in range(world_size)] 181 | # dist.all_gather_object(name_list, self.name_list) 182 | # name_list = [x for name in name_list for x in name] 183 | # 184 | # return name_list, area_list, bb_list, ciou_list, confidence_list 185 | 186 | def precision_at_50(self): 187 | ss = [i for i, bb in enumerate(self.bb_list) if bb > 0] 188 | return np.mean(np.array([self.ciou_list[i] for i in ss]) > 0.5) 189 | 190 | def precision_at_50_object(self): 191 | max_num_obj = max(self.bb_list) 192 | for num_obj in range(1, max_num_obj + 1): 193 | ss = [i for i, bb in enumerate(self.bb_list) if bb == num_obj] 194 | precision = np.mean(np.array([self.ciou_list[i] for i in ss]) > 0.5) 195 | print('\n' + f'num_obj:{num_obj}, precision:{precision}') 196 | 197 | def f1_at_50(self): 198 | # conf_thr = np.array(self.confidence_list).mean() 199 | p, r = self.calc_precision_recall(self.bb_list, self.ciou_list, self.confidence_list, self.default_conf_thr, 200 | 0.5) 201 | return 2 * p * r / (p + r) if (p + r) > 0 else 0. 
202 | 203 | def ap_at_50(self): 204 | return self.calc_ap(self.bb_list, self.ciou_list, self.confidence_list, 0.5) 205 | 206 | def clear(self): 207 | self.ciou_list = [] 208 | self.area_list = [] 209 | self.confidence_list = [] 210 | self.name_list = [] 211 | self.bb_list = [] 212 | 213 | def update(self, bb, gt, conf, pred, pred_thr, name): 214 | if isinstance(conf, torch.Tensor): 215 | conf = conf.detach().cpu().numpy() 216 | if isinstance(pred, torch.Tensor): 217 | pred = pred.detach().cpu().numpy() 218 | if isinstance(gt, torch.Tensor): 219 | gt = gt.detach().cpu().numpy() 220 | 221 | # Compute binary prediction map 222 | infer = np.zeros((224, 224)) 223 | infer[pred >= pred_thr] = 1 224 | 225 | # Compute ciou between prediction and ground truth 226 | ciou = np.sum(infer * gt) / (np.sum(gt) + np.sum(infer * (gt == 0)) + 1e-12) 227 | 228 | # Compute ground truth size 229 | area = gt.sum() 230 | 231 | # Save 232 | self.confidence_list.append(conf) 233 | self.ciou_list.append(ciou) 234 | self.area_list.append(area) 235 | self.name_list.append(name) 236 | self.bb_list.append(bb) 237 | 238 | def evaluate_batch(self, output, gt, label, conf, name, thr=None): 239 | for i in range(output.shape[0]): 240 | pred = output[i, 0].detach().cpu().numpy() 241 | if thr is None: 242 | thr = np.sort(pred.flatten())[int(pred.shape[0] * pred.shape[1] * 0.5)] 243 | 244 | bb = 1 if label[i] != 'non-sounding' else 0 245 | 246 | self.update(bb, gt[i, 0], conf[i], pred, thr, name[i]) 247 | 248 | def finalize(self): 249 | metric_extend = self.finalize_stats() 250 | eap = metric_extend['AP-all@50'] 251 | ef1 = metric_extend['F1-all@50'] 252 | # eloc = metric_extend['Precision-visible@50'] 253 | emaxf1 = max([float(num) for num in ef1.split(' ')]) 254 | return self.metrics, {self.metrics[0]: eap*100, self.metrics[1]: emaxf1} # , self.metrics[2]: eloc*100} 255 | -------------------------------------------------------------------------------- /asset/summary_wacv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swimmiing/ACL-SSL/57f4a6c3b93e03992806da483c143c8802eba21d/asset/summary_wacv.png -------------------------------------------------------------------------------- /config/model/ACL_ViT16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | clip: ViT16 3 | vision_backbone: null 4 | audio_backbone: BEATs 5 | audio_proj: FGA512 6 | 7 | pretrain: 8 | vision_backbone: null 9 | audio_backbone: ./pretrain/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt 10 | audio_proj: null 11 | 12 | fga_conf: 13 | FGA: 14 | input_size: 768 15 | output_size: 768 16 | 17 | FGA512: 18 | input_size: 768 19 | output_size: 512 20 | 21 | clip_conf: 22 | RN50: 23 | name: RN50 24 | vision: 25 | image_resolution: 224 26 | vision_layers: [3, 4, 6, 3] 27 | vision_width: 64 28 | heads: 8 29 | vision_patch_size: null 30 | text: 31 | transformer_layers: 12 32 | transformer_width: 512 33 | transformer_heads: 8 34 | vocab_size: 49408 35 | context_length: 77 36 | embedding_dim: 1024 37 | 38 | ViT16: 39 | name: ViT-B/16 40 | vision: 41 | image_resolution: 224 42 | vision_layers: 12 43 | vision_width: 768 44 | heads: 12 45 | vision_patch_size: 16 46 | text: 47 | transformer_layers: 12 48 | transformer_width: 512 49 | transformer_heads: 8 50 | vocab_size: 49408 51 | context_length: 77 52 | embedding_dim: 512 53 | 54 | ViT14: 55 | name: ViT-L/14 56 | vision: 57 | image_resolution: 224 58 | vision_layers: 24 59 | vision_width: 1024 60 | 
heads: 16 61 | vision_patch_size: 14 62 | text: 63 | transformer_layers: 12 64 | transformer_width: 768 65 | transformer_heads: 12 66 | vocab_size: 49408 67 | context_length: 77 68 | embedding_dim: 768 69 | 70 | vision_backbone_conf: 71 | maskclip_plus_rn50_512: 72 | name: maskclip_plus_rn50_512 73 | image_resolution: 512 74 | vision_layers: [ 3, 4, 6, 3 ] 75 | vision_width: 2048 76 | aspp: 77 | dilations: [ 6, 12, 18, 24 ] 78 | in_channels: 2048 79 | channels: 512 80 | 81 | maskclip_plus_rn101_512: 82 | name: maskclip_plus_rn101_512 83 | image_resolution: 512 84 | vision_layers: [ 3, 4, 23, 3 ] 85 | vision_width: 2048 86 | aspp: 87 | dilations: [ 6, 12, 18, 24 ] 88 | in_channels: 2048 89 | channels: 1024 90 | -------------------------------------------------------------------------------- /config/train/Exp_ACL_v1.yaml: -------------------------------------------------------------------------------- 1 | model: ACL 2 | 3 | common: 4 | train_data: vggss 5 | epoch: 20 6 | batch_size: 8 7 | input_resolution: 352 8 | num_workers: 4 9 | seed: 0 10 | loss: 11 | - acl_i 12 | - acl_f 13 | - area_reg 14 | loss_w: 15 | - 1 16 | - 1 17 | - 1 18 | 19 | optimizer: Adam 20 | scheduler: null 21 | amp: True 22 | 23 | optim_conf: 24 | Adam: 25 | module_path: torch.optim 26 | module_name: Adam 27 | lr: 0.0001 28 | weight_decay: 0.0001 29 | 30 | AdamW: 31 | module_path: torch.optim 32 | module_name: AdamW 33 | lr: 0.001 34 | 35 | SGDR: 36 | module_path: torch.optim 37 | module_name: SGD 38 | lr: 0.5 39 | weight_decay: 0.00001 40 | 41 | sched_conf: 42 | Cosine: 43 | module_path: torch.optim.lr_scheduler 44 | module_name: CosineAnnealingLR 45 | eta_ratio: 0.0 -------------------------------------------------------------------------------- /loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def infonce(pred: torch.Tensor, target: torch.Tensor, beta: float = 1/0.07, **kwargs) -> torch.Tensor: 6 | ''' 7 | Compute the InfoNCE (Noise Contrastive Estimation) loss. 8 | 9 | Args: 10 | pred (torch.Tensor): The predicted tensor. 11 | target (torch.Tensor): The target tensor. 12 | beta (float, optional): Temperature parameter. Default is 1/0.07. 13 | 14 | Returns: 15 | torch.Tensor: InfoNCE loss. 16 | ''' 17 | B = pred.shape[0] 18 | logits = torch.einsum('nc,mc->nm', F.normalize(pred), F.normalize(target)) * beta 19 | labels = torch.arange(B).long().to(pred.device) 20 | loss = F.cross_entropy(logits, labels) 21 | 22 | return loss 23 | 24 | 25 | def area_reg(p_area: torch.Tensor, n_area: torch.Tensor, p_thr: float = 0.4, n_thr: float = 0.0, 26 | **kwargs) -> torch.Tensor: 27 | ''' 28 | Compute the area regularization loss. 29 | 30 | Args: 31 | p_area (torch.Tensor): Positive area tensor. 32 | n_area (torch.Tensor): Negative area tensor. 33 | p_thr (float, optional): Expected positive area. Default is 0.4. 34 | n_thr (float, optional): Expected negative area. Default is 0.0. 35 | 36 | Returns: 37 | torch.Tensor: Area regularization loss. 38 | ''' 39 | loss = torch.abs(p_area - p_thr) + torch.abs(n_area - n_thr) 40 | return loss 41 | 42 | 43 | def acl_i(v_i: torch.Tensor, pred_emb: torch.Tensor, beta: float = 1 / 0.07, **kwargs) -> torch.Tensor: 44 | ''' 45 | Compute the image-level audio-grounded contrastive learning (ACL_I) loss. 46 | 47 | Args: 48 | v_i (torch.Tensor): Image-level audio-grounded visual embedding tensor. 49 | pred_emb (torch.Tensor): Audio-driven embedding tensor. 
50 | beta (float, optional): Temperature parameter. Default is 1/0.07. 51 | 52 | Returns: 53 | torch.Tensor: Image-level ACL loss 54 | ''' 55 | loss = 0.5 * (infonce(pred_emb, v_i, beta=beta) + infonce(v_i, pred_emb, beta=beta)) 56 | 57 | return loss 58 | 59 | 60 | def acl_f(v_f: torch.Tensor, pred_emb: torch.Tensor, beta: float = 1 / 0.07, **kwargs) -> torch.Tensor: 61 | ''' 62 | Compute the feature-level audio-grounded contrastive learning (ACL_F) loss. 63 | 64 | Args: 65 | v_f (torch.Tensor): Feature-level audio-grounded visual embedding tensor. 66 | pred_emb (torch.Tensor): Audio-driven embedding tensor. 67 | beta (float, optional): Temperature parameter. Default is 1/0.07. 68 | 69 | Returns: 70 | torch.Tensor: Feature-level ACL loss 71 | ''' 72 | B, _, C = v_f.size() 73 | logits = torch.einsum('bnc,bc ->bn', F.normalize(v_f, dim=2), F.normalize(pred_emb)) 74 | 75 | labels = torch.arange(B).long().to(pred_emb.device) 76 | loss = 0.5 * (F.cross_entropy(logits * beta, labels) + F.cross_entropy(logits.T * beta, labels)) 77 | 78 | return loss 79 | -------------------------------------------------------------------------------- /modules/AudioToken/AudioToken.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.loaders import AttnProcsLayers 3 | 4 | from modules.BEATs.BEATs import BEATs, BEATsConfig 5 | from modules.AudioToken.embedder import FGAEmbedder 6 | from diffusers import AutoencoderKL, UNet2DConditionModel 7 | from diffusers.models.attention_processor import LoRAAttnProcessor 8 | 9 | 10 | class AudioTokenWrapper(torch.nn.Module): 11 | """Simple wrapper module for Stable Diffusion that holds all the models together""" 12 | 13 | def __init__( 14 | self, 15 | args, 16 | accelerator, 17 | ): 18 | 19 | super().__init__() 20 | # Load scheduler and models 21 | from modules.clip_text_model.modeling_clip import CLIPTextModel 22 | self.text_encoder = CLIPTextModel.from_pretrained( 23 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 24 | ) 25 | self.unet = UNet2DConditionModel.from_pretrained( 26 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 27 | ) 28 | self.vae = AutoencoderKL.from_pretrained( 29 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision 30 | ) 31 | 32 | checkpoint = torch.load( 33 | 'models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt') 34 | cfg = BEATsConfig(checkpoint['cfg']) 35 | self.aud_encoder = BEATs(cfg) 36 | self.aud_encoder.load_state_dict(checkpoint['model']) 37 | self.aud_encoder.predictor = None 38 | input_size = 768 * 3 39 | 40 | if args.pretrained_model_name_or_path == "CompVis/stable-diffusion-v1-4": 41 | self.embedder = FGAEmbedder(input_size=input_size, output_size=768) 42 | 43 | else: 44 | self.embedder = FGAEmbedder(input_size=input_size, output_size=1024) 45 | 46 | self.vae.eval() 47 | self.unet.eval() 48 | self.text_encoder.eval() 49 | self.aud_encoder.eval() 50 | 51 | if 'lora' in args and args.lora: 52 | # Set correct lora layers 53 | lora_attn_procs = {} 54 | for name in self.unet.attn_processors.keys(): 55 | cross_attention_dim = None if name.endswith( 56 | "attn1.processor") else self.unet.config.cross_attention_dim 57 | if name.startswith("mid_block"): 58 | hidden_size = self.unet.config.block_out_channels[-1] 59 | elif name.startswith("up_blocks"): 60 | block_id = int(name[len("up_blocks.")]) 61 | hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] 62 | elif 
name.startswith("down_blocks"): 63 | block_id = int(name[len("down_blocks.")]) 64 | hidden_size = self.unet.config.block_out_channels[block_id] 65 | 66 | lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, 67 | cross_attention_dim=cross_attention_dim) 68 | 69 | self.unet.set_attn_processor(lora_attn_procs) 70 | self.lora_layers = AttnProcsLayers(self.unet.attn_processors) 71 | 72 | if args.data_set == 'train': 73 | 74 | # Freeze vae, unet, text_enc and aud_encoder 75 | self.vae.requires_grad_(False) 76 | self.unet.requires_grad_(False) 77 | self.text_encoder.requires_grad_(False) 78 | self.aud_encoder.requires_grad_(False) 79 | self.embedder.requires_grad_(True) 80 | self.embedder.train() 81 | 82 | if 'lora' in args and args.lora: 83 | self.unet.train() 84 | 85 | if args.data_set == 'test': 86 | 87 | from transformers import CLIPTextModel 88 | self.text_encoder = CLIPTextModel.from_pretrained( 89 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 90 | ) 91 | 92 | self.embedder.eval() 93 | embedder_learned_embeds = args.learned_embeds 94 | self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=accelerator.device)) 95 | 96 | if 'lora' in args and args.lora: 97 | self.lora_layers.eval() 98 | lora_layers_learned_embeds = args.lora_learned_embeds 99 | self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=accelerator.device)) 100 | self.unet.load_attn_procs(lora_layers_learned_embeds) 101 | -------------------------------------------------------------------------------- /modules/AudioToken/embedder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from modules.FGA.atten import Atten 3 | 4 | class FGAEmbedder(nn.Module): 5 | def __init__(self, input_size=768*3, output_size=768): 6 | super(FGAEmbedder, self).__init__() 7 | self.fc1 = nn.Linear(input_size, input_size) 8 | self.fc2 = nn.Linear(input_size, output_size) 9 | self.gelu = nn.GELU() 10 | self.fga = Atten(util_e=[output_size], pairwise_flag=False) 11 | 12 | def forward(self, audio_embs): 13 | audio_embs = self.fc1(audio_embs) 14 | audio_embs = self.gelu(audio_embs) 15 | audio_embs = self.fc2(audio_embs) 16 | attend = self.fga([audio_embs])[0] 17 | return attend 18 | -------------------------------------------------------------------------------- /modules/BEATs/BEATs.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import LayerNorm 14 | import torchaudio.compliance.kaldi as ta_kaldi 15 | from torch.cuda.amp import autocast 16 | 17 | from modules.BEATs.backbone import ( 18 | TransformerEncoder, 19 | ) 20 | 21 | import logging 22 | from typing import Optional 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class BEATsConfig: 28 | def __init__(self, cfg=None): 29 | self.input_patch_size: int = -1 # path size of patch embedding 30 | self.embed_dim: int = 512 # patch embedding dimension 31 | 
self.conv_bias: bool = False # include bias in conv encoder 32 | 33 | self.encoder_layers: int = 12 # num encoder layers in the transformer 34 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 35 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 36 | self.encoder_attention_heads: int = 12 # num encoder attention heads 37 | self.activation_fn: str = "gelu" # activation function to use 38 | 39 | self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay 40 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 41 | self.deep_norm: bool = False # apply deep_norm first in the transformer 42 | 43 | # dropouts 44 | self.dropout: float = 0.1 # dropout probability for the transformer 45 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 46 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 47 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 48 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 49 | 50 | # positional embeddings 51 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 52 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 53 | 54 | # relative position embedding 55 | self.relative_position_embedding: bool = False # apply relative position embedding 56 | self.num_buckets: int = 320 # number of buckets for relative position embedding 57 | self.max_distance: int = 1280 # maximum distance for relative position embedding 58 | self.gru_rel_pos: bool = False # apply gated relative position embedding 59 | 60 | # label predictor 61 | self.finetuned_model: bool = False # whether the model is a fine-tuned model. 
62 | self.predictor_dropout: float = 0.1 # dropout probability for the predictor 63 | self.predictor_class: int = 527 # target class number for the predictor 64 | 65 | if cfg is not None: 66 | self.update(cfg) 67 | 68 | def update(self, cfg: dict): 69 | self.__dict__.update(cfg) 70 | 71 | 72 | class BEATs(nn.Module): 73 | def __init__( 74 | self, 75 | cfg: BEATsConfig, 76 | ) -> None: 77 | super().__init__() 78 | logger.info(f"BEATs Config: {cfg.__dict__}") 79 | 80 | self.cfg = cfg 81 | 82 | self.embed = cfg.embed_dim 83 | self.post_extract_proj = ( 84 | nn.Linear(self.embed, cfg.encoder_embed_dim) 85 | if self.embed != cfg.encoder_embed_dim 86 | else None 87 | ) 88 | 89 | self.input_patch_size = cfg.input_patch_size 90 | self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, 91 | bias=cfg.conv_bias) 92 | 93 | self.dropout_input = nn.Dropout(cfg.dropout_input) 94 | 95 | assert not cfg.deep_norm or not cfg.layer_norm_first 96 | self.encoder = TransformerEncoder(cfg) 97 | self.layer_norm = LayerNorm(self.embed) 98 | 99 | if cfg.finetuned_model: 100 | self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) 101 | self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) 102 | else: 103 | self.predictor = None 104 | 105 | def forward_padding_mask( 106 | self, 107 | features: torch.Tensor, 108 | padding_mask: torch.Tensor, 109 | ) -> torch.Tensor: 110 | extra = padding_mask.size(1) % features.size(1) 111 | if extra > 0: 112 | padding_mask = padding_mask[:, :-extra] 113 | padding_mask = padding_mask.view( 114 | padding_mask.size(0), features.size(1), -1 115 | ) 116 | padding_mask = padding_mask.all(-1) 117 | return padding_mask 118 | 119 | @autocast(enabled=False) 120 | def preprocess( 121 | self, 122 | source: torch.Tensor, 123 | fbank_mean: float = 15.41663, 124 | fbank_std: float = 6.55582, 125 | ) -> torch.Tensor: 126 | fbanks = [] 127 | for waveform in source: 128 | waveform = waveform.unsqueeze(0) * 2 ** 15 129 | fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) 130 | fbanks.append(fbank) 131 | fbank = torch.stack(fbanks, dim=0) 132 | fbank = (fbank - fbank_mean) / (2 * fbank_std) 133 | return fbank 134 | 135 | def extract_features( 136 | self, 137 | source: torch.Tensor, 138 | padding_mask: Optional[torch.Tensor] = None, 139 | fbank_mean: float = 15.41663, 140 | fbank_std: float = 6.55582, 141 | ): 142 | fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) 143 | if padding_mask is not None: 144 | padding_mask = self.forward_padding_mask(fbank, padding_mask) 145 | # ToDo Aug here 146 | fbank = fbank.unsqueeze(1) 147 | features = self.patch_embedding(fbank) 148 | features = features.reshape(features.shape[0], features.shape[1], -1) 149 | features = features.transpose(1, 2) 150 | features = self.layer_norm(features) 151 | 152 | if padding_mask is not None: 153 | padding_mask = self.forward_padding_mask(features, padding_mask) 154 | 155 | if self.post_extract_proj is not None: 156 | features = self.post_extract_proj(features) 157 | 158 | x = self.dropout_input(features) 159 | 160 | x, layers_sum, layers = self.encoder( 161 | x, 162 | padding_mask=padding_mask, 163 | ) 164 | 165 | if self.predictor is not None: 166 | x = self.predictor_dropout(x) 167 | logits = self.predictor(x) 168 | 169 | if padding_mask is not None and padding_mask.any(): 170 | logits[padding_mask] = 0 171 | logits = logits.sum(dim=1) 172 | logits = logits / 
(~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) 173 | else: 174 | logits = logits.mean(dim=1) 175 | 176 | lprobs = torch.sigmoid(logits) 177 | 178 | return lprobs, padding_mask 179 | else: 180 | return x, layers_sum, layers 181 | -------------------------------------------------------------------------------- /modules/BEATs/Tokenizers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import LayerNorm 14 | import torchaudio.compliance.kaldi as ta_kaldi 15 | 16 | from modules.BEATs.backbone import ( 17 | TransformerEncoder, 18 | ) 19 | from modules.BEATs.quantizer import ( 20 | NormEMAVectorQuantizer, 21 | ) 22 | 23 | import logging 24 | from typing import Optional 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class TokenizersConfig: 30 | def __init__(self, cfg=None): 31 | self.input_patch_size: int = -1 # path size of patch embedding 32 | self.embed_dim: int = 512 # patch embedding dimension 33 | self.conv_bias: bool = False # include bias in conv encoder 34 | 35 | self.encoder_layers: int = 12 # num encoder layers in the transformer 36 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 37 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 38 | self.encoder_attention_heads: int = 12 # num encoder attention heads 39 | self.activation_fn: str = "gelu" # activation function to use 40 | 41 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 42 | self.deep_norm: bool = False # apply deep_norm first in the transformer 43 | 44 | # dropouts 45 | self.dropout: float = 0.1 # dropout probability for the transformer 46 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 47 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 48 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 49 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 50 | 51 | # positional embeddings 52 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 53 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 54 | 55 | # relative position embedding 56 | self.relative_position_embedding: bool = False # apply relative position embedding 57 | self.num_buckets: int = 320 # number of buckets for relative position embedding 58 | self.max_distance: int = 1280 # maximum distance for relative position embedding 59 | self.gru_rel_pos: bool = False # apply gated relative position embedding 60 | 61 | # quantizer 62 | self.quant_n: int = 1024 # codebook number in quantizer 63 | self.quant_dim: int = 256 # codebook dimension in quantizer 64 | 65 | if cfg is not None: 66 | self.update(cfg) 67 | 68 | def update(self, cfg: dict): 69 | self.__dict__.update(cfg) 70 | 71 | 72 | class Tokenizers(nn.Module): 73 | def __init__( 74 | self, 75 | cfg: TokenizersConfig, 76 | ) -> None: 77 | super().__init__() 
78 | logger.info(f"Tokenizers Config: {cfg.__dict__}") 79 | 80 | self.cfg = cfg 81 | 82 | self.embed = cfg.embed_dim 83 | self.post_extract_proj = ( 84 | nn.Linear(self.embed, cfg.encoder_embed_dim) 85 | if self.embed != cfg.encoder_embed_dim 86 | else None 87 | ) 88 | 89 | self.input_patch_size = cfg.input_patch_size 90 | self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, 91 | bias=cfg.conv_bias) 92 | 93 | self.dropout_input = nn.Dropout(cfg.dropout_input) 94 | 95 | assert not cfg.deep_norm or not cfg.layer_norm_first 96 | self.encoder = TransformerEncoder(cfg) 97 | self.layer_norm = LayerNorm(self.embed) 98 | 99 | self.quantize = NormEMAVectorQuantizer( 100 | n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, 101 | ) 102 | self.quant_n = cfg.quant_n 103 | self.quantize_layer = nn.Sequential( 104 | nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), 105 | nn.Tanh(), 106 | nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize 107 | ) 108 | 109 | def forward_padding_mask( 110 | self, 111 | features: torch.Tensor, 112 | padding_mask: torch.Tensor, 113 | ) -> torch.Tensor: 114 | extra = padding_mask.size(1) % features.size(1) 115 | if extra > 0: 116 | padding_mask = padding_mask[:, :-extra] 117 | padding_mask = padding_mask.view( 118 | padding_mask.size(0), features.size(1), -1 119 | ) 120 | padding_mask = padding_mask.all(-1) 121 | return padding_mask 122 | 123 | def preprocess( 124 | self, 125 | source: torch.Tensor, 126 | fbank_mean: float = 15.41663, 127 | fbank_std: float = 6.55582, 128 | ) -> torch.Tensor: 129 | fbanks = [] 130 | for waveform in source: 131 | waveform = waveform.unsqueeze(0) * 2 ** 15 132 | fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) 133 | fbanks.append(fbank) 134 | fbank = torch.stack(fbanks, dim=0) 135 | fbank = (fbank - fbank_mean) / (2 * fbank_std) 136 | return fbank 137 | 138 | def extract_labels( 139 | self, 140 | source: torch.Tensor, 141 | padding_mask: Optional[torch.Tensor] = None, 142 | fbank_mean: float = 15.41663, 143 | fbank_std: float = 6.55582, 144 | ): 145 | fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) 146 | 147 | if padding_mask is not None: 148 | padding_mask = self.forward_padding_mask(fbank, padding_mask) 149 | 150 | fbank = fbank.unsqueeze(1) 151 | features = self.patch_embedding(fbank) 152 | features = features.reshape(features.shape[0], features.shape[1], -1) 153 | features = features.transpose(1, 2) 154 | features = self.layer_norm(features) 155 | 156 | if padding_mask is not None: 157 | padding_mask = self.forward_padding_mask(features, padding_mask) 158 | 159 | if self.post_extract_proj is not None: 160 | features = self.post_extract_proj(features) 161 | 162 | x = self.dropout_input(features) 163 | 164 | x, layer_results = self.encoder( 165 | x, 166 | padding_mask=padding_mask, 167 | ) 168 | 169 | quantize_input = self.quantize_layer(x) 170 | quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) 171 | 172 | return embed_ind 173 | -------------------------------------------------------------------------------- /modules/BEATs/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: 
https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import warnings 12 | import torch 13 | from torch import Tensor, nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class GradMultiply(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, x, scale): 20 | ctx.scale = scale 21 | res = x.new(x) 22 | return res 23 | 24 | @staticmethod 25 | def backward(ctx, grad): 26 | return grad * ctx.scale, None 27 | 28 | 29 | class SamePad(nn.Module): 30 | def __init__(self, kernel_size, causal=False): 31 | super().__init__() 32 | if causal: 33 | self.remove = kernel_size - 1 34 | else: 35 | self.remove = 1 if kernel_size % 2 == 0 else 0 36 | 37 | def forward(self, x): 38 | if self.remove > 0: 39 | x = x[:, :, : -self.remove] 40 | return x 41 | 42 | 43 | class Swish(nn.Module): 44 | def __init__(self): 45 | super(Swish, self).__init__() 46 | self.act = torch.nn.Sigmoid() 47 | 48 | def forward(self, x): 49 | return x * self.act(x) 50 | 51 | 52 | class GLU_Linear(nn.Module): 53 | def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): 54 | super(GLU_Linear, self).__init__() 55 | 56 | self.glu_type = glu_type 57 | self.output_dim = output_dim 58 | 59 | if glu_type == "sigmoid": 60 | self.glu_act = torch.nn.Sigmoid() 61 | elif glu_type == "swish": 62 | self.glu_act = Swish() 63 | elif glu_type == "relu": 64 | self.glu_act = torch.nn.ReLU() 65 | elif glu_type == "gelu": 66 | self.glu_act = torch.nn.GELU() 67 | 68 | if bias_in_glu: 69 | self.linear = nn.Linear(input_dim, output_dim * 2, True) 70 | else: 71 | self.linear = nn.Linear(input_dim, output_dim * 2, False) 72 | 73 | def forward(self, x): 74 | # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case 75 | x = self.linear(x) 76 | 77 | if self.glu_type == "bilinear": 78 | x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) 79 | else: 80 | x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) 81 | 82 | return x 83 | 84 | 85 | def gelu_accurate(x): 86 | if not hasattr(gelu_accurate, "_a"): 87 | gelu_accurate._a = math.sqrt(2 / math.pi) 88 | return ( 89 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 90 | ) 91 | 92 | 93 | def gelu(x: torch.Tensor) -> torch.Tensor: 94 | return torch.nn.functional.gelu(x.float()).type_as(x) 95 | 96 | 97 | def get_activation_fn(activation: str): 98 | """Returns the activation function corresponding to `activation`""" 99 | 100 | if activation == "relu": 101 | return F.relu 102 | elif activation == "gelu": 103 | return gelu 104 | elif activation == "gelu_fast": 105 | warnings.warn( 106 | "--activation-fn=gelu_fast has been renamed to gelu_accurate" 107 | ) 108 | return gelu_accurate 109 | elif activation == "gelu_accurate": 110 | return gelu_accurate 111 | elif activation == "tanh": 112 | return torch.tanh 113 | elif activation == "linear": 114 | return lambda x: x 115 | elif activation == "glu": 116 | return lambda x: x 117 | else: 118 | raise RuntimeError("--activation-fn {} not supported".format(activation)) 119 | 120 | 121 | def quant_noise(module, p, block_size): 122 | """ 123 | Wraps modules and applies 
quantization noise to the weights for 124 | subsequent quantization with Iterative Product Quantization as 125 | described in "Training with Quantization Noise for Extreme Model Compression" 126 | 127 | Args: 128 | - module: nn.Module 129 | - p: amount of Quantization Noise 130 | - block_size: size of the blocks for subsequent quantization with iPQ 131 | 132 | Remarks: 133 | - Module weights must have the right sizes wrt the block size 134 | - Only Linear, Embedding and Conv2d modules are supported for the moment 135 | - For more detail on how to quantize by blocks with convolutional weights, 136 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" 137 | - We implement the simplest form of noise here as stated in the paper 138 | which consists in randomly dropping blocks 139 | """ 140 | 141 | # if no quantization noise, don't register hook 142 | if p <= 0: 143 | return module 144 | 145 | # supported modules 146 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) 147 | 148 | # test whether module.weight has the right sizes wrt block_size 149 | is_conv = module.weight.ndim == 4 150 | 151 | # 2D matrix 152 | if not is_conv: 153 | assert ( 154 | module.weight.size(1) % block_size == 0 155 | ), "Input features must be a multiple of block sizes" 156 | 157 | # 4D matrix 158 | else: 159 | # 1x1 convolutions 160 | if module.kernel_size == (1, 1): 161 | assert ( 162 | module.in_channels % block_size == 0 163 | ), "Input channels must be a multiple of block sizes" 164 | # regular convolutions 165 | else: 166 | k = module.kernel_size[0] * module.kernel_size[1] 167 | assert k % block_size == 0, "Kernel size must be a multiple of block size" 168 | 169 | def _forward_pre_hook(mod, input): 170 | # no noise for evaluation 171 | if mod.training: 172 | if not is_conv: 173 | # gather weight and sizes 174 | weight = mod.weight 175 | in_features = weight.size(1) 176 | out_features = weight.size(0) 177 | 178 | # split weight matrix into blocks and randomly drop selected blocks 179 | mask = torch.zeros( 180 | in_features // block_size * out_features, device=weight.device 181 | ) 182 | mask.bernoulli_(p) 183 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) 184 | 185 | else: 186 | # gather weight and sizes 187 | weight = mod.weight 188 | in_channels = mod.in_channels 189 | out_channels = mod.out_channels 190 | 191 | # split weight matrix into blocks and randomly drop selected blocks 192 | if mod.kernel_size == (1, 1): 193 | mask = torch.zeros( 194 | int(in_channels // block_size * out_channels), 195 | device=weight.device, 196 | ) 197 | mask.bernoulli_(p) 198 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) 199 | else: 200 | mask = torch.zeros( 201 | weight.size(0), weight.size(1), device=weight.device 202 | ) 203 | mask.bernoulli_(p) 204 | mask = ( 205 | mask.unsqueeze(2) 206 | .unsqueeze(3) 207 | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) 208 | ) 209 | 210 | # scale weights and apply mask 211 | mask = mask.to( 212 | torch.bool 213 | ) # x.bool() is not currently supported in TorchScript 214 | s = 1 / (1 - p) 215 | mod.weight.data = s * weight.masked_fill(mask, 0) 216 | 217 | module.register_forward_pre_hook(_forward_pre_hook) 218 | return module 219 | -------------------------------------------------------------------------------- /modules/BEATs/quantizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio 
Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on VQGAN code bases 7 | # https://github.com/CompVis/taming-transformers 8 | # --------------------------------------------------------' 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.distributed as distributed 14 | 15 | try: 16 | from einops import rearrange, repeat 17 | except ImportError: 18 | pass 19 | 20 | 21 | def l2norm(t): 22 | return F.normalize(t, p=2, dim=-1) 23 | 24 | 25 | def ema_inplace(moving_avg, new, decay): 26 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 27 | 28 | 29 | def sample_vectors(samples, num): 30 | num_samples, device = samples.shape[0], samples.device 31 | 32 | if num_samples >= num: 33 | indices = torch.randperm(num_samples, device=device)[:num] 34 | else: 35 | indices = torch.randint(0, num_samples, (num,), device=device) 36 | 37 | return samples[indices] 38 | 39 | 40 | def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): 41 | dim, dtype, device = samples.shape[-1], samples.dtype, samples.device 42 | 43 | means = sample_vectors(samples, num_clusters) 44 | 45 | for _ in range(num_iters): 46 | if use_cosine_sim: 47 | dists = samples @ means.t() 48 | else: 49 | diffs = rearrange(samples, 'n d -> n () d') \ 50 | - rearrange(means, 'c d -> () c d') 51 | dists = -(diffs ** 2).sum(dim=-1) 52 | 53 | buckets = dists.max(dim=-1).indices 54 | bins = torch.bincount(buckets, minlength=num_clusters) 55 | zero_mask = bins == 0 56 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 57 | 58 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 59 | new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) 60 | new_means = new_means / bins_min_clamped[..., None] 61 | 62 | if use_cosine_sim: 63 | new_means = l2norm(new_means) 64 | 65 | means = torch.where(zero_mask[..., None], means, new_means) 66 | 67 | return means, bins 68 | 69 | 70 | class EmbeddingEMA(nn.Module): 71 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): 72 | super().__init__() 73 | self.num_tokens = num_tokens 74 | self.codebook_dim = codebook_dim 75 | self.decay = decay 76 | self.eps = eps 77 | if codebook_init_path == '': 78 | if not kmeans_init: 79 | weight = torch.randn(num_tokens, codebook_dim) 80 | weight = l2norm(weight) 81 | else: 82 | weight = torch.zeros(num_tokens, codebook_dim) 83 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 84 | else: 85 | print(f"load init codebook weight from {codebook_init_path}") 86 | codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') 87 | weight = codebook_ckpt_weight.clone() 88 | self.register_buffer('initted', torch.Tensor([True])) 89 | 90 | self.weight = nn.Parameter(weight, requires_grad=False) 91 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) 92 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) 93 | # self.register_buffer('initted', torch.Tensor([not kmeans_init])) 94 | self.update = True 95 | 96 | @torch.jit.ignore 97 | def init_embed_(self, data): 98 | if self.initted: 99 | return 100 | print("Performing Kemans init for codebook") 101 | embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) 102 | self.weight.data.copy_(embed) 103 | 
self.cluster_size.data.copy_(cluster_size) 104 | self.initted.data.copy_(torch.Tensor([True])) 105 | 106 | def forward(self, embed_id): 107 | return F.embedding(embed_id, self.weight) 108 | 109 | def cluster_size_ema_update(self, new_cluster_size): 110 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) 111 | 112 | def embed_avg_ema_update(self, new_embed_avg): 113 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 114 | 115 | def weight_update(self, num_tokens): 116 | n = self.cluster_size.sum() 117 | smoothed_cluster_size = ( 118 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 119 | ) 120 | # normalize embedding average with smoothed cluster size 121 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 122 | # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) 123 | self.weight.data.copy_(embed_normalized) 124 | 125 | 126 | def norm_ema_inplace(moving_avg, new, decay): 127 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 128 | moving_avg.data.copy_(l2norm(moving_avg.data)) 129 | 130 | 131 | class NormEMAVectorQuantizer(nn.Module): 132 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, 133 | statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): 134 | super().__init__() 135 | self.codebook_dim = embedding_dim 136 | self.num_tokens = n_embed 137 | self.beta = beta 138 | self.decay = decay 139 | 140 | # learnable = True if orthogonal_reg_weight > 0 else False 141 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) 142 | 143 | self.statistic_code_usage = statistic_code_usage 144 | if statistic_code_usage: 145 | self.register_buffer('cluster_size', torch.zeros(n_embed)) 146 | if distributed.is_available() and distributed.is_initialized(): 147 | print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") 148 | self.all_reduce_fn = distributed.all_reduce 149 | else: 150 | self.all_reduce_fn = nn.Identity() 151 | 152 | def reset_cluster_size(self, device): 153 | if self.statistic_code_usage: 154 | self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) 155 | self.cluster_size = self.cluster_size.to(device) 156 | 157 | def forward(self, z): 158 | # reshape z -> (batch, height, width, channel) and flatten 159 | # z, 'b c h w -> b h w c' 160 | # z = rearrange(z, 'b c h w -> b h w c') 161 | # z = z.transpose(1, 2) 162 | z = l2norm(z) 163 | z_flattened = z.reshape(-1, self.codebook_dim) 164 | 165 | self.embedding.init_embed_(z_flattened) 166 | 167 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 168 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 169 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 170 | 171 | encoding_indices = torch.argmin(d, dim=1) 172 | 173 | z_q = self.embedding(encoding_indices).view(z.shape) 174 | 175 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 176 | 177 | if not self.training: 178 | with torch.no_grad(): 179 | cluster_size = encodings.sum(0) 180 | self.all_reduce_fn(cluster_size) 181 | ema_inplace(self.cluster_size, cluster_size, self.decay) 182 | 183 | if self.training and self.embedding.update: 184 | # EMA cluster size 185 | 186 | bins = encodings.sum(0) 187 | self.all_reduce_fn(bins) 188 | 189 | # self.embedding.cluster_size_ema_update(bins) 190 | ema_inplace(self.cluster_size, bins, self.decay) 191 | 192 | zero_mask = (bins == 0) 
193 | bins = bins.masked_fill(zero_mask, 1.) 194 | 195 | embed_sum = z_flattened.t() @ encodings 196 | self.all_reduce_fn(embed_sum) 197 | 198 | embed_normalized = (embed_sum / bins.unsqueeze(0)).t() 199 | embed_normalized = l2norm(embed_normalized) 200 | 201 | embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, 202 | embed_normalized) 203 | norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) 204 | 205 | # compute loss for embedding 206 | loss = self.beta * F.mse_loss(z_q.detach(), z) 207 | 208 | # preserve gradients 209 | z_q = z + (z_q - z).detach() 210 | 211 | # reshape back to match original input shape 212 | # z_q, 'b h w c -> b c h w' 213 | # z_q = rearrange(z_q, 'b h w c -> b c h w') 214 | # z_q = z_q.transpose(1, 2) 215 | return z_q, loss, encoding_indices -------------------------------------------------------------------------------- /modules/CLIPSeg/clipseg_for_audio.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from typing import List, Tuple, Union, Optional 6 | import numpy as np 7 | from transformers.models.clipseg.modeling_clipseg import _expand_mask 8 | 9 | 10 | class CLIPSeg(transformers.CLIPSegForImageSegmentation): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | def encode_text(self, text: torch.Tensor) -> torch.Tensor: 15 | """ 16 | Encode textual input and return the text embeddings. 17 | 18 | Args: 19 | text (torch.Tensor): Input text tensor. 20 | 21 | Returns: 22 | torch.Tensor: Text embeddings. 23 | """ 24 | tokens = text 25 | if text.ndim == 3: 26 | tokens = torch.squeeze(text, dim=1) 27 | non_zero_index = torch.nonzero(tokens.sum(axis=0) == 0)[0] 28 | input_ids = tokens[:, :non_zero_index] 29 | attention_mask = (input_ids > 0).to(tokens.dtype) 30 | input_ids += torch.max(input_ids) * (1 - attention_mask) 31 | conditional_embeddings = self.clip.get_text_features(input_ids, attention_mask=attention_mask, 32 | position_ids=None) 33 | 34 | return conditional_embeddings 35 | 36 | def similarity(self, image: torch.Tensor, embeddings: List[torch.Tensor]) -> torch.Tensor: 37 | """ 38 | Calculate the similarity score between an image and a list of embeddings. 39 | 40 | Args: 41 | image (torch.Tensor): Input image tensor of shape (B, C, H, W). 42 | embeddings (List[torch.Tensor]): List of N embedding tensors of shape (dim,). 43 | 44 | Returns: 45 | torch.Tensor: Similarity scores of shape (B, N) for each batch. 
46 | """ 47 | B, c, h, w = image.shape 48 | if (h, w) != (352, 352): 49 | vision_outputs = self.clip.vision_model(pixel_values=F.interpolate(image, 352, mode='bicubic'), 50 | output_attentions=False, 51 | output_hidden_states=False, 52 | return_dict=False) 53 | img_embedding = self.clip.visual_projection(vision_outputs[1]) 54 | else: 55 | vision_outputs = self.clip.vision_model(pixel_values=image, 56 | output_attentions=False, 57 | output_hidden_states=False, 58 | return_dict=False) 59 | img_embedding = self.clip.visual_projection(vision_outputs[1]) 60 | 61 | paired_embedding = torch.cat(embeddings, dim=0) 62 | paired_embedding = paired_embedding.repeat(B, 1) # Batch-wise replication of embeddings 63 | paired_embedding = paired_embedding.view(B, -1, img_embedding.size(-1)) 64 | 65 | result = torch.matmul(F.normalize(paired_embedding, dim=-1), F.normalize(img_embedding, dim=-1).unsqueeze(-1)) 66 | result = result.squeeze(-1).view(B, -1) 67 | return F.softmax(result, dim=-1) 68 | 69 | def encode_audio(self, placeholder_token: torch.Tensor, audio_token: torch.Tensor, pos: int, 70 | length: int) -> torch.Tensor: 71 | """ 72 | Encode audio token into the audio-driven embeddings. (Audio-Driven Embedder) 73 | 74 | Args: 75 | placeholder_token (torch.Tensor): Placeholder text token tensor. 76 | audio_token (torch.Tensor): Audio token tensor. 77 | pos (int): Position index for audio token. 78 | length (int): Length of the input token. 79 | 80 | Returns: 81 | torch.Tensor: Audio-driven embeddings. 82 | 83 | Reference: 84 | "Can CLIP Help Sound Source Localization?" WACV 2024 85 | - https://arxiv.org/abs/2311.04066 86 | """ 87 | tokens = placeholder_token 88 | if placeholder_token.ndim == 3: 89 | tokens = torch.squeeze(placeholder_token, dim=1) 90 | 91 | inputs_embeds = self.clip.text_model.embeddings.token_embedding(tokens).type( 92 | self.dtype) # [batch_size, n_ctx, d_model] 93 | inputs_embeds = torch.cat((inputs_embeds[:, :pos, :], audio_token, inputs_embeds[:, pos:, :]), 94 | dim=1) # Inject Audio token 95 | inputs_embeds = inputs_embeds[:, :length, :] 96 | 97 | bsz, seq_len, _ = inputs_embeds.shape 98 | attention_mask = torch.ones((bsz, seq_len)).to(placeholder_token.device) 99 | position_ids = torch.arange(length).unsqueeze(0).to(placeholder_token.device) 100 | 101 | position_embeddings = self.clip.text_model.embeddings.position_embedding(position_ids) 102 | hidden_states = inputs_embeds + position_embeddings 103 | 104 | bsz, seq_len, _ = inputs_embeds.shape 105 | # CLIPSeg's text model uses causal mask, prepare it here. 
106 | # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324 107 | causal_attention_mask = self.clip.text_model._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( 108 | hidden_states.device 109 | ) 110 | # expand attention_mask 111 | if attention_mask is not None: 112 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 113 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype) 114 | 115 | encoder_outputs = self.clip.text_model.encoder( 116 | inputs_embeds=hidden_states, 117 | attention_mask=attention_mask, 118 | causal_attention_mask=causal_attention_mask, 119 | output_attentions=False, 120 | output_hidden_states=False, 121 | return_dict=True, 122 | ) 123 | 124 | last_hidden_state = encoder_outputs[0] 125 | last_hidden_state = self.clip.text_model.final_layer_norm(last_hidden_state) 126 | 127 | # text_embeds.shape = [batch_size, sequence_length, transformer.width] 128 | # take features from the eot embedding (eot_token is the highest number in each sequence) 129 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 130 | pooled_output = last_hidden_state[:, -1, :] 131 | audio_driven_embeddings = self.clip.text_projection(pooled_output) 132 | return audio_driven_embeddings 133 | 134 | def get_pixels(self, image: torch.Tensor) -> torch.Tensor: 135 | """ 136 | Extract spatial features (pixel-level) from the CLIP image encoder. 137 | 138 | Args: 139 | image (torch.Tensor): Input image tensor. 140 | 141 | Returns: 142 | torch.Tensor: Spatial visual features (pixel-level). 143 | """ 144 | vision_outputs = self.clip.vision_model(pixel_values=image, 145 | output_attentions=None, 146 | output_hidden_states=True, 147 | return_dict=True) 148 | last_layer = self.clip.vision_model.encoder.layers[-1] 149 | 150 | hidden_states = vision_outputs.hidden_states[-2] 151 | residual = hidden_states 152 | 153 | hidden_states = last_layer.layer_norm1(hidden_states) 154 | 155 | bsz, tgt_len, embed_dim = hidden_states.size() 156 | 157 | # get query proj 158 | # query_states = last_layer.self_attn.q_proj(hidden_states) * last_layer.self_attn.scale 159 | # key_states = last_layer.self_attn.k_proj(hidden_states) 160 | value_states = last_layer.self_attn.v_proj(hidden_states) 161 | 162 | value_states = last_layer.self_attn.out_proj(value_states) 163 | 164 | value_states += residual 165 | 166 | residual = value_states 167 | value_states = last_layer.layer_norm2(value_states) 168 | value_states = last_layer.mlp(value_states) 169 | value_states += residual 170 | 171 | value_states = self.clip.vision_model.post_layernorm(value_states) 172 | output = self.clip.visual_projection(value_states) 173 | 174 | width = int(np.sqrt(tgt_len - 1)) 175 | output = output[:, 1:] 176 | if output.ndim == 2: 177 | output = output.unsqueeze(0) 178 | 179 | output = output.permute(0, 2, 1) 180 | output = output.reshape(bsz, self.clip.visual_projection.out_features, width, width) 181 | 182 | return output 183 | -------------------------------------------------------------------------------- /modules/FGA/atten.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from itertools import product, permutations, combinations_with_replacement, chain 7 | 8 | 9 | class Unary(nn.Module): 10 | def __init__(self, embed_size): 11 | """ 12 | Captures 
local entity information 13 | :param embed_size: the embedding dimension 14 | """ 15 | super(Unary, self).__init__() 16 | self.embed = nn.Conv1d(embed_size, embed_size, 1) 17 | self.feature_reduce = nn.Conv1d(embed_size, 1, 1) 18 | 19 | def forward(self, X): 20 | X = X.transpose(1, 2) 21 | 22 | X_embed = self.embed(X) 23 | 24 | X_nl_embed = F.dropout(F.relu(X_embed), training=self.training) 25 | X_poten = self.feature_reduce(X_nl_embed) 26 | return X_poten.squeeze(1) 27 | 28 | 29 | class Pairwise(nn.Module): 30 | def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None): 31 | """ 32 | Captures interaction between utilities or entities of the same utility 33 | :param embed_x_size: the embedding dimension of the first utility 34 | :param x_spatial_dim: the spatial dimension of the first utility for batch norm and weighted marginalization 35 | :param embed_y_size: the embedding dimension of the second utility (none for self-interactions) 36 | :param y_spatial_dim: the spatial dimension of the second utility for batch norm and weighted marginalization 37 | """ 38 | 39 | super(Pairwise, self).__init__() 40 | embed_y_size = embed_y_size if y_spatial_dim is not None else embed_x_size 41 | self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim 42 | 43 | self.embed_size = max(embed_x_size, embed_y_size) 44 | self.x_spatial_dim = x_spatial_dim 45 | 46 | self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1) 47 | self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1) 48 | if x_spatial_dim is not None: 49 | self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim) 50 | 51 | self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1) 52 | self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1) 53 | 54 | def forward(self, X, Y=None): 55 | 56 | X_t = X.transpose(1, 2) 57 | Y_t = Y.transpose(1, 2) if Y is not None else X_t 58 | 59 | 60 | X_embed = self.embed_X(X_t) 61 | Y_embed = self.embed_Y(Y_t) 62 | 63 | X_norm = F.normalize(X_embed) 64 | Y_norm = F.normalize(Y_embed) 65 | 66 | S = X_norm.transpose(1, 2).bmm(Y_norm) 67 | if self.x_spatial_dim is not None: 68 | S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \ 69 | .view(-1, self.x_spatial_dim, self.y_spatial_dim) 70 | 71 | X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2) 72 | Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2) 73 | else: 74 | X_poten = S.mean(dim=2, keepdim=False) 75 | Y_poten = S.mean(dim=1, keepdim=False) 76 | 77 | if Y is None: 78 | return X_poten 79 | else: 80 | return X_poten, Y_poten 81 | 82 | 83 | class Atten(nn.Module): 84 | def __init__(self, util_e, sharing_factor_weights=[], prior_flag=False, 85 | sizes=[], size_force=False, pairwise_flag=True, 86 | unary_flag=True, self_flag=True): 87 | """ 88 | The class performs an attention on a given list of utilities representation. 89 | :param util_e: the embedding dimensions 90 | :param sharing_factor_weights: To share weights, provide a dict of tuples: 91 | {idx: (num_utils, connected utils) 92 | Note, for efficiency, the shared utils (i.e., history, are connected to ans 93 | and question only. 94 | TODO: connections between shared utils 95 | :param prior_flag: is prior factor provided 96 | :param sizes: the spatial simension (used for batch-norm and weighted marginalization) 97 | :param size_force: force spatial size with adaptive avg pooling. 
98 | :param pairwise_flag: use pairwise interaction between utilities 99 | :param unary_flag: use local information 100 | :param self_flag: use self interactions between utilitie's entities 101 | """ 102 | super(Atten, self).__init__() 103 | self.util_e = util_e 104 | 105 | self.prior_flag = prior_flag 106 | 107 | self.n_utils = len(util_e) 108 | 109 | self.spatial_pool = nn.ModuleDict() 110 | 111 | self.un_models = nn.ModuleList() 112 | 113 | self.self_flag = self_flag 114 | self.pairwise_flag = pairwise_flag 115 | self.unary_flag = unary_flag 116 | self.size_force = size_force 117 | 118 | if len(sizes) == 0: 119 | sizes = [None for _ in util_e] 120 | 121 | self.sharing_factor_weights = sharing_factor_weights 122 | 123 | #force the provided size 124 | for idx, e_dim in enumerate(util_e): 125 | self.un_models.append(Unary(e_dim)) 126 | if self.size_force: 127 | self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx]) 128 | 129 | #Pairwise 130 | self.pp_models = nn.ModuleDict() 131 | for ((idx1, e_dim_1), (idx2, e_dim_2)) \ 132 | in combinations_with_replacement(enumerate(util_e), 2): 133 | # self 134 | if self.self_flag and idx1 == idx2: 135 | self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1]) 136 | else: 137 | if pairwise_flag: 138 | if idx1 in self.sharing_factor_weights: 139 | # not connected 140 | if idx2 not in self.sharing_factor_weights[idx1][1]: 141 | continue 142 | if idx2 in self.sharing_factor_weights: 143 | # not connected 144 | if idx1 not in self.sharing_factor_weights[idx2][1]: 145 | continue 146 | self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1], e_dim_2, sizes[idx2]) 147 | 148 | # Handle reduce potentials (with scalars) 149 | self.reduce_potentials = nn.ModuleList() 150 | 151 | self.num_of_potentials = dict() 152 | 153 | self.default_num_of_potentials = 0 154 | 155 | if self.self_flag: 156 | self.default_num_of_potentials += 1 157 | if self.unary_flag: 158 | self.default_num_of_potentials += 1 159 | if self.prior_flag: 160 | self.default_num_of_potentials += 1 161 | for idx in range(self.n_utils): 162 | self.num_of_potentials[idx] = self.default_num_of_potentials 163 | 164 | ''' 165 | All other utilities 166 | ''' 167 | if pairwise_flag: 168 | for idx, (num_utils, connected_utils) in sharing_factor_weights: 169 | for c_u in connected_utils: 170 | self.num_of_potentials[c_u] += num_utils 171 | self.num_of_potentials[idx] += 1 172 | for k in self.num_of_potentials: 173 | if k not in self.sharing_factor_weights: 174 | self.num_of_potentials[k] += (self.n_utils - 1) \ 175 | - len(sharing_factor_weights) 176 | 177 | for idx in range(self.n_utils): 178 | self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx], 179 | 1, 1, bias=False)) 180 | 181 | def forward(self, utils, priors=None): 182 | assert self.n_utils == len(utils) 183 | assert (priors is None and not self.prior_flag) \ 184 | or (priors is not None 185 | and self.prior_flag 186 | and len(priors) == self.n_utils) 187 | b_size = utils[0].size(0) 188 | util_factors = dict() 189 | attention = list() 190 | 191 | #Force size, constant size is used for pairwise batch normalization 192 | if self.size_force: 193 | for i, (num_utils, _) in self.sharing_factor_weights.items(): 194 | if str(i) not in self.spatial_pool.keys(): 195 | continue 196 | else: 197 | high_util = utils[i] 198 | high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3)) 199 | high_util = high_util.transpose(1, 2) 200 | utils[i] = self.spatial_pool[str(i)](high_util).transpose(1, 2) 201 | 202 
| for i in range(self.n_utils): 203 | if i in self.sharing_factor_weights \ 204 | or str(i) not in self.spatial_pool.keys(): 205 | continue 206 | utils[i] = utils[i].transpose(1, 2) 207 | utils[i] = self.spatial_pool[str(i)](utils[i]).transpose(1, 2) 208 | if self.prior_flag and priors[i] is not None: 209 | priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1) 210 | 211 | # handle Shared weights 212 | for i, (num_utils, connected_list) in self.sharing_factor_weights: 213 | if self.unary_flag: 214 | util_factors.setdefault(i, []).append(self.un_models[i](utils[i])) 215 | 216 | if self.self_flag: 217 | util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i])) 218 | 219 | if self.pairwise_flag: 220 | for j in connected_list: 221 | other_util = utils[j] 222 | expanded_util = other_util.unsqueeze(1).expand(b_size, 223 | num_utils, 224 | other_util.size(1), 225 | other_util.size(2)).contiguous().view( 226 | b_size * num_utils, 227 | other_util.size(1), 228 | other_util.size(2)) 229 | 230 | if i < j: 231 | factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util) 232 | else: 233 | factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i]) 234 | util_factors[i].append(factor_ij) 235 | util_factors.setdefault(j, []).append(factor_ji.view(b_size, num_utils, factor_ji.size(1))) 236 | 237 | # handle local factors 238 | for i in range(self.n_utils): 239 | if i in self.sharing_factor_weights: 240 | continue 241 | if self.unary_flag: 242 | util_factors.setdefault(i, []).append(self.un_models[i](utils[i])) 243 | if self.self_flag: 244 | util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i])) 245 | 246 | # joint 247 | if self.pairwise_flag: 248 | for (i, j) in combinations_with_replacement(range(self.n_utils), 2): 249 | if i in self.sharing_factor_weights \ 250 | or j in self.sharing_factor_weights: 251 | continue 252 | if i == j: 253 | continue 254 | else: 255 | factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j]) 256 | util_factors.setdefault(i, []).append(factor_ij) 257 | util_factors.setdefault(j, []).append(factor_ji) 258 | 259 | # perform attention 260 | for i in range(self.n_utils): 261 | if self.prior_flag: 262 | prior = priors[i] \ 263 | if priors[i] is not None \ 264 | else torch.zeros_like(util_factors[i][0], requires_grad=False).cuda() 265 | 266 | util_factors[i].append(prior) 267 | 268 | util_factors[i] = torch.cat([p if len(p.size()) == 3 else p.unsqueeze(1) 269 | for p in util_factors[i]], dim=1) 270 | util_factors[i] = self.reduce_potentials[i](util_factors[i]).squeeze(1) 271 | util_factors[i] = F.softmax(util_factors[i], dim=1).unsqueeze(2) 272 | attention.append(torch.bmm(utils[i].transpose(1, 2), util_factors[i]).squeeze(2)) 273 | 274 | return attention 275 | 276 | 277 | class NaiveAttention(nn.Module): 278 | def __init__(self): 279 | """ 280 | Used for ablation analysis - removing attention. 
281 | """ 282 | super(NaiveAttention, self).__init__() 283 | 284 | def forward(self, utils, priors): 285 | atten = [] 286 | spatial_atten = [] 287 | for u, p in zip(utils, priors): 288 | if type(u) is tuple: 289 | u = u[1] 290 | num_elements = u.shape[0] 291 | if p is not None: 292 | u = u.view(-1, u.shape[-2], u.shape[-1]) 293 | p = p.view(-1, p.shape[-2], p.shape[-1]) 294 | spatial_atten.append( 295 | torch.bmm(p.transpose(1, 2), u).squeeze(2).view(num_elements, -1, u.shape[-2], u.shape[-1])) 296 | else: 297 | spatial_atten.append(u.mean(2)) 298 | continue 299 | if p is not None: 300 | atten.append(torch.bmm(u.transpose(1, 2), p.unsqueeze(2)).squeeze(2)) 301 | else: 302 | atten.append(u.mean(1)) 303 | return atten, spatial_atten -------------------------------------------------------------------------------- /modules/FGA/fga_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.FGA.atten import Atten 5 | 6 | 7 | class FGA(nn.Module): 8 | def __init__(self, vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim, 9 | hidden_hist_dim, hidden_cap_dim, hidden_img_dim): 10 | ''' 11 | Factor Graph Attention 12 | :param vocab_size: vocabulary size 13 | :param word_embed_dim 14 | :param hidden_ques_dim: 15 | :param hidden_ans_dim: 16 | :param hidden_hist_dim: 17 | :param img_features_dim: 18 | ''' 19 | super(FGA, self).__init__() 20 | 21 | print("Init FGA with vocab size %s, word embed %s, hidden ques %s, hidden ans %s," 22 | " hidden hist %s, hidden cap %s, hidden img %s" % (vocab_size, word_embed_dim, 23 | hidden_ques_dim, 24 | hidden_ans_dim, 25 | hidden_hist_dim, 26 | hidden_cap_dim, 27 | hidden_img_dim)) 28 | self.hidden_ques_dim = hidden_ques_dim 29 | self.hidden_ans_dim = hidden_ans_dim 30 | self.hidden_cap_dim = hidden_cap_dim 31 | self.hidden_img_dim = hidden_img_dim 32 | self.hidden_hist_dim = hidden_hist_dim 33 | 34 | # Vocab of History LSTMs is one more as we are keeping a stop id (the last id) 35 | self.word_embedddings = nn.Embedding(vocab_size+1+1, word_embed_dim, padding_idx=0) 36 | 37 | self.lstm_ques = nn.LSTM(word_embed_dim, self.hidden_ques_dim, batch_first=True) 38 | self.lstm_ans = nn.LSTM(word_embed_dim, self.hidden_ans_dim, batch_first=True) 39 | 40 | self.lstm_hist_ques = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True) 41 | self.lstm_hist_ans = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True) 42 | 43 | self.lstm_hist_cap = nn.LSTM(word_embed_dim, self.hidden_cap_dim, batch_first=True) 44 | 45 | 46 | self.qahistnet = nn.Sequential( 47 | nn.Linear(self.hidden_hist_dim*2, self.hidden_hist_dim), 48 | nn.ReLU(inplace=True) 49 | ) 50 | 51 | self.concat_dim = self.hidden_ques_dim + self.hidden_ans_dim + \ 52 | self.hidden_ans_dim + self.hidden_img_dim + \ 53 | self.hidden_cap_dim + self.hidden_hist_dim*9 54 | 55 | self.simnet = nn.Sequential( 56 | nn.Linear(self.concat_dim, (self.concat_dim)//2, bias=False), 57 | nn.BatchNorm1d((self.concat_dim) // 2), 58 | nn.ReLU(inplace=True), 59 | nn.Linear((self.concat_dim)//2, (self.concat_dim)//4, bias=False), 60 | nn.BatchNorm1d((self.concat_dim) // 4), 61 | nn.ReLU(inplace=True), 62 | nn.Dropout(0.5), 63 | nn.Linear((self.concat_dim)//4, 1) 64 | ) 65 | 66 | # To share weights, provide list of tuples: (idx, list of connected utils) 67 | # Note, for efficiency, the shared utils (i.e., history, are connected to ans and question only. 
68 | # connecting shared factors is not supported (!) 69 | sharing_factor_weights = {4: (9, [0, 1]), 70 | 5: (9, [0, 1])} 71 | 72 | self.mul_atten = Atten(util_e=[self.hidden_ans_dim, # Answer modal 73 | self.hidden_ques_dim, # Question modal 74 | self.hidden_cap_dim, # Caption modal 75 | self.hidden_img_dim, # Image modal 76 | self.hidden_hist_dim, # Question-history modal 77 | self.hidden_hist_dim # Answer-history modal 78 | ], 79 | sharing_factor_weights=sharing_factor_weights, 80 | sizes=[100, # 100 Answers 81 | 21, # Question length 82 | 41, # Caption length 83 | 37, # 36 Image regions 84 | 21, # History-Question length 85 | 21 # History-Answer length 86 | ] # The spatial dim used for pairwise normalization (use force for adaptive) 87 | , prior_flag=True, 88 | pairwise_flag=True) 89 | 90 | 91 | 92 | def forward(self, input_ques, input_ans, input_hist_ques, input_hist_ans, input_hist_cap, 93 | input_ques_length, input_ans_length, input_cap_length, i_e): 94 | """ 95 | 96 | :param input_ques: 97 | :param input_ans: 98 | :param input_hist_ques: 99 | :param input_hist_ans: 100 | :param input_hist_cap: 101 | :param input_ques_length: 102 | :param input_ans_length: 103 | :param input_cap_length: 104 | :param i_e: 105 | :return: 106 | """ 107 | 108 | 109 | n_options = input_ans.size()[1] 110 | batch_size = input_ques.size()[0] 111 | 112 | 113 | 114 | nqa_per_dial, nwords_per_qa = input_hist_ques.size()[1], input_hist_ques.size()[2] 115 | nwords_per_cap = input_hist_cap.size()[1] 116 | max_length_input_ans = input_ans.size()[-1] 117 | 118 | assert batch_size == input_hist_ques.size()[0] == input_hist_ans.size()[0] == input_ques.size()[0] == \ 119 | input_ans.size()[0] == input_hist_cap.size()[0] 120 | assert nqa_per_dial == input_hist_ques.size()[1] == input_hist_ans.size()[1] 121 | assert nwords_per_qa == input_hist_ques.size()[2] == input_hist_ans.size()[2] 122 | 123 | q_we = self.word_embedddings(input_ques) 124 | a_we = self.word_embedddings(input_ans.view(-1, max_length_input_ans)) 125 | hq_we = self.word_embedddings(input_hist_ques.view(-1, nwords_per_qa)) 126 | ha_we = self.word_embedddings(input_hist_ans.view(-1, nwords_per_qa)) 127 | c_we = self.word_embedddings(input_hist_cap.view(-1, nwords_per_cap)) 128 | 129 | 130 | 131 | ''' 132 | q_we = batch x 20 x embed_ques_dim 133 | a_we = 100*batch x 20 x embed_ans_dim 134 | hq_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim 135 | ha_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim 136 | c_we = batch*ncap_per_dial, nwords_per_cap, embed_hist_dim 137 | ''' 138 | self.lstm_ques.flatten_parameters() 139 | self.lstm_ans.flatten_parameters() 140 | self.lstm_hist_ques.flatten_parameters() 141 | self.lstm_hist_ans.flatten_parameters() 142 | self.lstm_hist_cap.flatten_parameters() 143 | 144 | 145 | i_feat = i_e 146 | 147 | q_seq, self.hidden_ques = self.lstm_ques(q_we) 148 | a_seq, self.hidden_ans = self.lstm_ans(a_we) 149 | hq_seq, self.hidden_hist_ques = self.lstm_hist_ques(hq_we) 150 | ha_seq, self.hidden_hist_ans = self.lstm_hist_ans(ha_we) 151 | cap_seq, self.hidden_cap = self.lstm_hist_cap(c_we) 152 | 153 | 154 | ''' 155 | length is used for attention prior 156 | ''' 157 | q_len = input_ques_length.data - 1 158 | c_len = input_cap_length.data.view(-1) - 1 159 | 160 | 161 | ans_index = torch.arange(0, n_options * batch_size).long().cuda() 162 | ans_len = input_ans_length.data.view(-1) - 1 163 | ans_seq = a_seq[ans_index, ans_len, :] 164 | ans_seq = ans_seq.view(batch_size, n_options, self.hidden_ans_dim) 165 | 166 | batch_index = 
torch.arange(0, batch_size).long().cuda() 167 | q_prior = torch.zeros(batch_size, q_seq.size(1)).cuda() 168 | q_prior[batch_index, q_len] = 100 169 | c_prior = torch.zeros(batch_size, cap_seq.size(1)).cuda() 170 | c_prior[batch_index, c_len] = 100 171 | ans_prior = torch.ones(batch_size, ans_seq.size(1)).cuda() 172 | img_prior = torch.ones(batch_size, i_feat.size(1)).cuda() 173 | 174 | (ans_atten, ques_atten, cap_atten, img_atten, hq_atten, ha_atten) = \ 175 | self.mul_atten([ans_seq, q_seq, cap_seq, i_feat, hq_seq, ha_seq], 176 | priors=[ans_prior, q_prior, c_prior, img_prior, None, None]) 177 | 178 | ''' 179 | expand to answers based 180 | ''' 181 | ques_atten = torch.unsqueeze(ques_atten, 1).expand(batch_size, 182 | n_options, 183 | self.hidden_ques_dim) 184 | cap_atten = torch.unsqueeze(cap_atten, 1).expand(batch_size, 185 | n_options, 186 | self.hidden_cap_dim) 187 | img_atten = torch.unsqueeze(img_atten, 1).expand(batch_size, n_options, 188 | self.hidden_img_dim) 189 | ans_atten = torch.unsqueeze(ans_atten, 1).expand(batch_size, n_options, 190 | self.hidden_ans_dim) 191 | 192 | 193 | ''' 194 | combine history 195 | ''' 196 | 197 | input_qahistnet = torch.cat((hq_atten, ha_atten), 1) 198 | # input_qahistnet: (nqa_per_dial*batch x 2*hidden_hist_dim) 199 | output_qahistnet = self.qahistnet(input_qahistnet) 200 | # output_qahistnet: (nqa_per_dial*batch x hidden_hist_dim) 201 | output_qahistnet = output_qahistnet.view(batch_size, 202 | nqa_per_dial * self.hidden_hist_dim) 203 | # output_qahistnet: (batch x nqa_per_dial*hidden_hist_dim) 204 | output_qahistnet = torch.unsqueeze(output_qahistnet, 1)\ 205 | .expand(batch_size, 206 | n_options, 207 | nqa_per_dial * self.hidden_hist_dim) 208 | 209 | input_qa = torch.cat((ans_seq, ques_atten, ans_atten, img_atten, 210 | output_qahistnet, cap_atten), 2) # Concatenate last dimension 211 | 212 | input_qa = input_qa.view(batch_size * n_options, self.concat_dim) 213 | 214 | out_scores = self.simnet(input_qa) 215 | 216 | out_scores = out_scores.squeeze(dim=1) 217 | out_scores = out_scores.view(batch_size, n_options) 218 | 219 | return out_scores -------------------------------------------------------------------------------- /modules/arg_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List, Optional, Union, Tuple 3 | 4 | 5 | def int_or_int_list_or_none(value: Optional[Union[int, str]]) -> List[Optional[int]]: 6 | """ 7 | Parse an input value into a list of integers or a single integer, or None. 8 | 9 | Args: 10 | value (Optional[Union[int, str]]): The input value to parse. 11 | 12 | Returns: 13 | List[Optional[int]]: A list containing either a single integer, a list of integers, 14 | or a single None value. 15 | 16 | Raises: 17 | argparse.ArgumentTypeError: If the input value cannot be parsed into the specified formats. 18 | """ 19 | if value in ['None', 'null']: 20 | return [None] 21 | try: 22 | # If the value contains commas, parse it as a comma-separated list of integers 23 | if ',' in value: 24 | return [int(x) for x in value.split(',')] 25 | # If it's a single integer, pack it into a list 26 | else: 27 | return [int(value)] 28 | except ValueError: 29 | raise argparse.ArgumentTypeError("Invalid format. Use an integer, a comma-separated list of integers, or None.") 30 | 31 | 32 | def int_or_float(value): 33 | if '.' 
in value: 34 | try: 35 | return float(value) 36 | except ValueError: 37 | raise argparse.ArgumentTypeError("Quality level must be an integer or a float") 38 | else: 39 | try: 40 | return int(value) 41 | except ValueError: 42 | raise argparse.ArgumentTypeError("Quality level must be an integer or a float") 43 | -------------------------------------------------------------------------------- /modules/mask_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False): 6 | """Samples from the Gumbel-Sigmoid distribution and optionally discretizes. 7 | References: 8 | - https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py 9 | - https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax 10 | Note: 11 | X - Y ~ Logistic(0,1) s.t. X, Y ~ Gumbel(0, 1). 12 | That is, we can implement gumbel_sigmoid using Logistic distribution. 13 | """ 14 | logistic = torch.rand_like(logits) 15 | logistic = logistic.div_(1. - logistic).log_() # ~Logistic(0,1) 16 | 17 | gumbels = (logits + logistic) / tau # ~Logistic(logits, tau) 18 | y_soft = gumbels.sigmoid_() 19 | 20 | if hard: 21 | # Straight through. 22 | y_hard = y_soft.gt(0.5).type(y_soft.dtype) 23 | # gt_ break gradient flow 24 | # y_hard = y_soft.gt_(0.5) # gt_() maintain dtype, different to gt() 25 | ret = y_hard - y_soft.detach() + y_soft 26 | else: 27 | # Reparametrization trick. 28 | ret = y_soft 29 | 30 | return ret 31 | 32 | 33 | class Sim2Mask(nn.Module): 34 | def __init__(self, init_w: float = 1.0, init_b: float = 0.0, gumbel_tau: float = 1.0, learnable: bool = True): 35 | """ 36 | Sim2Mask module for generating binary masks. 37 | 38 | Args: 39 | init_w (float): Initial value for weight. 40 | init_b (float): Initial value for bias. 41 | gumbel_tau (float): Gumbel-Softmax temperature. 42 | learnable (bool): If True, weight and bias are learnable parameters. 43 | 44 | Reference: 45 | "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs" CVPR 2023 46 | - https://github.com/kakaobrain/tcl 47 | - https://arxiv.org/abs/2212.00785 48 | """ 49 | super().__init__() 50 | self.init_w = init_w 51 | self.init_b = init_b 52 | self.gumbel_tau = gumbel_tau 53 | self.learnable = learnable 54 | 55 | assert not ((init_w is None) ^ (init_b is None)) 56 | if learnable: 57 | self.w = nn.Parameter(torch.full([], float(init_w))) 58 | self.b = nn.Parameter(torch.full([], float(init_b))) 59 | else: 60 | self.w = init_w 61 | self.b = init_b 62 | 63 | def forward(self, x, deterministic=False): 64 | logits = x * self.w + self.b 65 | 66 | soft_mask = torch.sigmoid(logits) 67 | if deterministic: 68 | hard_mask = soft_mask.gt(0.5).type(logits.dtype) 69 | else: 70 | hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau) 71 | 72 | return hard_mask, soft_mask 73 | 74 | def extra_repr(self): 75 | return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}' 76 | 77 | 78 | def norm_img_tensor(tensor: torch.Tensor) -> torch.Tensor: 79 | """ 80 | Normalize image tensor to the range [0, 1]. 81 | 82 | Args: 83 | tensor (torch.Tensor): Input image tensor. 84 | 85 | Returns: 86 | torch.Tensor: Normalized image tensor. 
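
    Example (illustrative; assumes a 4-D [B, C, H, W] tensor, since the min and
    max are taken over the last two dimensions):
        >>> x = torch.rand(2, 1, 8, 8)
        >>> y = norm_img_tensor(x)  # each (sample, channel) slice now spans roughly [0, 1]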
87 | """ 88 | vmin = tensor.amin((2, 3), keepdims=True) - 1e-7 89 | vmax = tensor.amax((2, 3), keepdims=True) + 1e-7 90 | tensor = (tensor - vmin) / (vmax - vmin) 91 | return tensor 92 | 93 | 94 | class ImageMasker(Sim2Mask): 95 | def forward(self, x: torch.Tensor, infer: bool = False) -> torch.Tensor: 96 | """ 97 | Forward pass for generating image-level binary masks. 98 | 99 | Args: 100 | x (torch.Tensor): Input tensor. 101 | infer (bool): True for only inference stage. 102 | 103 | Returns: 104 | torch.Tensor: Binary mask. 105 | 106 | Reference: 107 | "Can CLIP Help Sound Source Localization?" WACV 2024 108 | - https://arxiv.org/abs/2311.04066 109 | """ 110 | if self.training or not infer: 111 | output = super().forward(x, False)[0] 112 | else: 113 | output = torch.sigmoid(x + self.b / self.w) 114 | return output 115 | 116 | 117 | class FeatureMasker(nn.Module): 118 | def __init__(self, thr: float = 0.5, tau: float = 0.07): 119 | """ 120 | Masker module for generating feature-level masks. 121 | 122 | Args: 123 | thr (float): Threshold for generating the mask. 124 | tau (float): Temperature for the sigmoid function. 125 | 126 | Reference: 127 | "Can CLIP Help Sound Source Localization?" WACV 2024 128 | - https://arxiv.org/abs/2311.04066 129 | """ 130 | super().__init__() 131 | self.thr = thr 132 | self.tau = tau 133 | 134 | def forward(self, x: torch.Tensor) -> torch.Tensor: 135 | """ 136 | Forward pass for generating feature-level masks 137 | 138 | Args: 139 | x (torch.Tensor): Input tensor. 140 | 141 | Returns: 142 | torch.Tensor: Generated mask. 143 | """ 144 | return torch.sigmoid((norm_img_tensor(x) - self.thr) / self.tau) 145 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | import yaml 6 | import argparse 7 | 8 | from modules.BEATs.BEATs import BEATs, BEATsConfig 9 | from modules.AudioToken.embedder import FGAEmbedder 10 | from modules.CLIPSeg.clipseg_for_audio import CLIPSeg 11 | from modules.mask_utils import ImageMasker, FeatureMasker 12 | from transformers import AutoTokenizer 13 | 14 | 15 | class ACL(nn.Module): 16 | def __init__(self, conf_file: str, device: str): 17 | """ 18 | Audio-Grounded Contrastive Learning (ACL) model. 19 | 20 | Args: 21 | conf_file (str): Path to the configuration file. 22 | device (str): Device to move the model to. 
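
        Example (illustrative; the YAML path and device string are assumptions,
        not requirements):
            >>> model = ACL('config/model/ACL_ViT16.yaml', device='cuda')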
23 | """ 24 | super(ACL, self).__init__() 25 | 26 | # Get configuration 27 | with open(conf_file) as f: 28 | config = yaml.load(f, Loader=yaml.FullLoader) 29 | self.args = argparse.Namespace() 30 | self.args.model = argparse.Namespace(**config['model']) 31 | self.args.clip_embedding_dim = config['clip_conf'][self.args.model.clip]['embedding_dim'] 32 | self.args.clip_name = config['clip_conf'][self.args.model.clip]['name'] 33 | self.pretrain = argparse.Namespace(**config['pretrain']) 34 | self.args.audio_proj = argparse.Namespace(**config['fga_conf'][self.args.model.audio_proj]) 35 | 36 | # Init audio encoder 37 | checkpoint = torch.load(self.pretrain.audio_backbone) 38 | cfg = BEATsConfig(checkpoint['cfg']) 39 | self.audio_backbone = BEATs(cfg) 40 | 41 | # Text Tokenizer for placeholder prompt 42 | self.tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined") 43 | 44 | # Init audio projection layer 45 | self.audio_proj = FGAEmbedder(input_size=self.args.audio_proj.input_size * 3, 46 | output_size=self.args.audio_proj.output_size) 47 | 48 | # Init audio-visual grounder (Grounder: CLIPSeg) 49 | self.av_grounder = CLIPSeg.from_pretrained("CIDAS/clipseg-rd64-refined") 50 | 51 | # Init maskers 52 | self.masker_i = ImageMasker(10.0, 14.0, 1.0) 53 | self.masker_f = FeatureMasker(0.5, 0.07) 54 | 55 | # Load weights 56 | self.audio_backbone.load_state_dict(checkpoint['model']) 57 | self.audio_backbone.predictor = None 58 | 59 | if self.pretrain.audio_proj is not None: 60 | self.audio_proj.load_state_dict(torch.load(self.pretrain.audio_embedder)) 61 | 62 | # Set device 63 | self.device = device 64 | self.audio_backbone.to(device=self.device) 65 | self.av_grounder.to(device=self.device) 66 | self.audio_proj.to(device=self.device) 67 | self.masker_i.to(self.device) 68 | self.masker_f.to(self.device) 69 | 70 | def get_placeholder_token(self, prompt_text: str): 71 | """ 72 | Get placeholder token from prompt text 73 | 74 | Args: 75 | prompt_text (str): prompt text without '{}' 76 | 77 | Returns: 78 | CLIPTokenizerFast result with prompt text 79 | """ 80 | placeholder_token = self.tokenizer(prompt_text, return_tensors="pt").data['input_ids'] 81 | placeholder_token = F.pad(placeholder_token, (0, 77 - placeholder_token.shape[-1])).to(self.device) 82 | return placeholder_token 83 | 84 | def train(self, bool: bool = True): 85 | """ 86 | Set the module in training mode. 87 | 88 | Args: 89 | bool (bool): If True, set the module in training mode. 90 | """ 91 | super().train(bool) 92 | self.av_grounder.requires_grad_(False) 93 | self.audio_backbone.requires_grad_(False) 94 | 95 | def encode_audio(self, audio: torch.Tensor, placeholder_token: torch.Tensor, pos: int, 96 | prompt_size: int) -> torch.Tensor: 97 | """ 98 | Encode audio input into audio-driven embedding (Audio-Driven Embedder) 99 | 100 | Args: 101 | audio (torch.Tensor): Input audio tensor. 102 | placeholder_token (torch.Tensor): Placeholder token for CLIP Text encoder. 103 | pos (int): Position of audio token. 104 | prompt_size (int): Size of the placeholder prompt. 105 | 106 | Returns: 107 | torch.Tensor: Audio-driven embeddings. 
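
        Example (illustrative sketch; `model` is a constructed ACL instance,
        `audio` is a [B, 160000] waveform batch (assumed 10 s at 16 kHz), and
        `get_prompt_template` comes from util.py):
            >>> template, pos, length = get_prompt_template()
            >>> token = model.get_placeholder_token(template.format(''))
            >>> emb = model.encode_audio(audio, token, pos, length)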
108 | """ 109 | audio_feat = self.audio_backbone.extract_features(audio)[1] 110 | audio_token_emb = self.audio_proj(audio_feat).unsqueeze(1) 111 | audio_driven_embedding = self.av_grounder.encode_audio(placeholder_token, audio_token_emb, pos, 112 | prompt_size + audio_token_emb.shape[1]) 113 | 114 | return audio_driven_embedding 115 | 116 | def encode_vision(self, image: torch.Tensor) -> torch.Tensor: 117 | """ 118 | Encode visual input and generate visual embeddings. 119 | 120 | Args: 121 | image (torch.Tensor): Input image tensor. 122 | 123 | Returns: 124 | torch.Tensor: Visual embeddings. 125 | """ 126 | vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image, 127 | output_attentions=None, 128 | output_hidden_states=True, 129 | return_dict=True) 130 | pooled_output = self.av_grounder.clip.visual_projection(vision_outputs[1]) 131 | 132 | return pooled_output 133 | 134 | def forward_decoder(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> torch.Tensor: 135 | """ 136 | Forward pass of audio-visual grounder 137 | 138 | Args: 139 | image (torch.Tensor): Input image tensor. 140 | embedding (torch.Tensor): Condition embedding tensor for grounder. 141 | resolution (int): Resolution of the output. 142 | ignore_indices (list): List of indices to ignore. 143 | 144 | Returns: 145 | torch.Tensor: Logits from the decoder. 146 | """ 147 | # step 1: forward the query images through the frozen CLIP vision encoder 148 | vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image, 149 | output_attentions=None, 150 | output_hidden_states=True, 151 | return_dict=True) 152 | 153 | hidden_states = vision_outputs.hidden_states 154 | # we add +1 here as the hidden states also include the initial embeddings 155 | activations = [hidden_states[i + 1] for i in self.av_grounder.extract_layers] 156 | 157 | # step 2: compute conditional embeddings, either from text, images or an own provided embedding 158 | # Audio injected embedding from input argument 159 | 160 | # step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks 161 | decoder_outputs = self.av_grounder.decoder( 162 | activations, 163 | embedding, 164 | output_attentions=None, 165 | output_hidden_states=None, 166 | return_dict=True, 167 | ) 168 | logits = decoder_outputs.logits 169 | 170 | if logits.ndim == 2: 171 | logits = logits.unsqueeze(0).unsqueeze(1) 172 | else: 173 | logits = logits.unsqueeze(1) 174 | 175 | B, c, h, w = image.shape 176 | if (h, w) != (resolution, resolution): 177 | logits = F.interpolate(logits, resolution, mode='bicubic') 178 | 179 | return logits 180 | 181 | def forward_module(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224, 182 | force_comb: bool = False) -> torch.Tensor: 183 | """ 184 | Forward pass through the module. 185 | 186 | Args: 187 | image (torch.Tensor): Input image tensor. 188 | embedding (torch.Tensor): Condition embedding tensor for grounder. 189 | resolution (int): Resolution of the output tensor. 190 | force_comb (bool): If True, force to get logits with all combination audio and image. 191 | 192 | Returns: 193 | torch.Tensor: Logits from the decoder. 
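
        Example (illustrative; `image` is a [B, 3, 224, 224] batch and `emb` an
        audio-driven embedding, assumed [B, 512] for the ViT-B/16 grounder):
            >>> logits = model.forward_module(image, emb)                     # [B, 1, 224, 224]
            >>> pairwise = model.forward_module(image, emb, force_comb=True)  # [B, B, 224, 224]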
194 | """ 195 | # N image, 1 embedding case -> [B_i, h, w] 196 | if embedding.shape[0] != image.shape[0] and embedding.shape[0] == 1: 197 | embeddings = embedding.repeat(image.shape[0], 1) 198 | logits = self.forward_decoder(image, embeddings, resolution) 199 | 200 | # N image, M embedding case -> [B_i, B_e, h, w] 201 | elif embedding.shape[0] != image.shape[0] and embedding.shape[0] != 1 and image.shape[0] != 1 or force_comb: 202 | logit_list = [] 203 | for i in range(embedding.shape[0]): 204 | embeddings = embedding[i].unsqueeze(0).repeat(image.shape[0], 1) 205 | logit_list.append(self.forward_decoder(image, embeddings, resolution)) 206 | logits = torch.cat(logit_list, dim=1) 207 | 208 | # N image, N embedding or 1 image, N embedding -> [B_e, h, w] 209 | else: 210 | logits = self.forward_decoder(image, embedding, resolution) 211 | 212 | return logits 213 | 214 | def encode_masked_vision(self, image: torch.Tensor, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, float, float]: 215 | """ 216 | Encode masked visual feature both image-level and feature-level. 217 | 218 | Args: 219 | image (torch.Tensor): Input image tensor. 220 | embedding (torch.Tensor): Condition embedding tensor for grounder. 221 | 222 | Returns: 223 | tuple[torch.Tensor, torch.Tensor, float, float]: Feature masked embeddings, masked image embeddings, positive area, negative area. 224 | """ 225 | B, c, h, w = image.shape 226 | maskclip_feat = self.av_grounder.get_pixels(image) # v^D: [B, c, h, w] 227 | clipseg_mask = self.forward_module(image, embedding, h, force_comb=True) # M^G: [B, B, H, W] 228 | 229 | # Area 230 | area_matrix = self.masker_i(clipseg_mask).mean((2, 3)) 231 | positive_area = area_matrix.diagonal().mean() 232 | negative_area = area_matrix.mean() - positive_area / B 233 | 234 | # Feature level masker 235 | feature_mask = F.interpolate(self.masker_f(clipseg_mask), maskclip_feat.shape[2]) 236 | 237 | # Image level masker 238 | ind = torch.arange(B).to(image.device) 239 | image_mask = self.masker_i(clipseg_mask[ind, ind].unsqueeze(1)) # Positive pair only 240 | feature_masked_emb = torch.einsum('bchw,bnhw->bnc', maskclip_feat, feature_mask) / (feature_mask.sum() + 1e-6) 241 | 242 | # step 1: forward the query images through the frozen CLIP vision encoder 243 | masked_vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image * image_mask, 244 | output_attentions=None, 245 | output_hidden_states=True, 246 | return_dict=True) 247 | masked_image_emb = self.av_grounder.clip.visual_projection(masked_vision_outputs[1]) 248 | 249 | return feature_masked_emb, masked_image_emb, positive_area, negative_area 250 | 251 | def forward(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> dict: 252 | """ 253 | Forward pass of ACL model. 254 | 255 | Args: 256 | image (torch.Tensor): Input image tensor. 257 | embedding (torch.Tensor): Condition embedding tensor for grounder. 258 | resolution (int): Resolution of the output tensor. 259 | 260 | Returns: 261 | dict: Output dictionary containing relevant tensors. 
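
        Example (inference sketch; `image` and `embedding` form a matched batch,
        so the heatmap is [B, 1, 224, 224]; in training mode the dict instead
        holds 'v_f', 'v_i', 'p_area' and 'n_area'):
            >>> model.eval()
            >>> out = model(image, embedding)
            >>> heatmap = out['heatmap']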
262 | """ 263 | if self.training: 264 | # seg_logit = self.forward_module(image, embedding, resolution) 265 | v_f, v_i, p_area, n_area = self.encode_masked_vision(image, embedding) 266 | out_dict = {'v_f': v_f, 'v_i': v_i, 'p_area': p_area, 'n_area': n_area} 267 | 268 | else: 269 | seg_logit = self.forward_module(image, embedding, resolution) 270 | heatmap = self.masker_i(seg_logit, infer=True) 271 | out_dict = {'heatmap': heatmap} 272 | 273 | return out_dict 274 | 275 | def save(self, model_dir: str): 276 | """ 277 | Save model parameters to a file. (Only trainable parts) 278 | 279 | Args: 280 | model_dir (str): Directory to save the model. 281 | """ 282 | ckp = {'audio_proj': self.audio_proj.state_dict(), 'masker_i': self.masker_i.state_dict()} 283 | torch.save(ckp, model_dir) 284 | 285 | def load(self, model_dir: str): 286 | """ 287 | Load model parameters from a file. (Only trainable parts) 288 | 289 | Args: 290 | model_dir (str): Directory to load the model from. 291 | """ 292 | ckp = torch.load(model_dir, map_location=self.device) 293 | self.audio_proj.load_state_dict(ckp['audio_proj']) 294 | self.masker_i.load_state_dict(ckp['masker_i']) 295 | -------------------------------------------------------------------------------- /pretrain/README.md: -------------------------------------------------------------------------------- 1 | For pretrain model -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | import random 5 | import os 6 | from typing import Tuple, Optional 7 | 8 | 9 | def get_prompt_template(mode: str = 'default') -> Tuple[str, int, int]: 10 | ''' 11 | Generate a prompt template based on the specified mode. 12 | 13 | Args: 14 | mode (str, optional): The mode for selecting the prompt template. Default is 'default'. 15 | 16 | Returns: 17 | Tuple[str, int, int]: A tuple containing the generated prompt template, the position of the placeholder '{}', 18 | and the length of the prompt. 19 | 20 | Notes: 21 | If the mode is 'random', a random prompt template is chosen from a predefined list. 22 | ''' 23 | prompt_template = 'A photo of {}' 24 | 25 | if mode == 'random': 26 | prompt_templates = [ 27 | 'a photo of a {}', 'a photograph of a {}', 'an image of a {}', '{}', 28 | 'a cropped photo of a {}', 'a good photo of a {}', 'a photo of one {}', 29 | 'a bad photo of a {}', 'a photo of the {}', 'a photo of {}', 'a blurry photo of a {}', 30 | 'a picture of a {}', 'a photo of a scene where {}' 31 | ] 32 | prompt_template = random.choice(prompt_templates) 33 | 34 | # Calculate prompt length and text position 35 | prompt_length = 1 + len(prompt_template.split(' ')) + 1 - 1 # eos, sos => 1 + 1, {} => -1 36 | text_pos_at_prompt = 1 + prompt_template.split(' ').index('{}') 37 | 38 | return prompt_template, text_pos_at_prompt, prompt_length 39 | 40 | 41 | # Reproducibility 42 | def fix_seed(seed: int = 0) -> None: 43 | ''' 44 | Set seeds for random number generators to ensure reproducibility. 45 | 46 | Args: 47 | seed (int, optional): The seed value. Default is 0. 
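
    Example (illustrative; `dataset` is any map-style dataset and `seed_worker`
    below is intended as the DataLoader's `worker_init_fn`):
        >>> fix_seed(0)
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size=16,
        ...                                      worker_init_fn=seed_worker)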
48 | ''' 49 | np.random.seed(seed) 50 | random.seed(seed) 51 | torch.manual_seed(seed) 52 | torch.cuda.manual_seed_all(seed) # multi-GPU 53 | torch.backends.cudnn.deterministic = True 54 | torch.backends.cudnn.benchmark = False 55 | os.environ['PYTHONHASHSEED'] = str(seed) 56 | 57 | 58 | def seed_worker(worker_id: int) -> None: 59 | ''' 60 | Set a seed for a worker process to ensure reproducibility in PyTorch DataLoader. 61 | 62 | Args: 63 | worker_id (int): The ID of the worker process. 64 | ''' 65 | worker_seed = torch.initial_seed() % 2**32 66 | np.random.seed(worker_seed) 67 | random.seed(worker_seed) -------------------------------------------------------------------------------- /viz_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image, ImageDraw, ImageFont 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def draw_overall(result_dir: str, original_image: Image.Image, gt_image: Image.Image, heatmap_image: Image.Image, 8 | seg_image: Image.Image, label: str, name: str) -> None: 9 | ''' 10 | Draw an overall result figure with original, ground truth, heatmap, and binarized heatmap images. 11 | 12 | Args: 13 | result_dir (str): Directory to save the figure. 14 | original_image (Image.Image): Original image. 15 | gt_image (Image.Image): Ground truth image. 16 | heatmap_image (Image.Image): Heatmap image. 17 | seg_image (Image.Image): Binarized heatmap image. 18 | label (str): Label information. 19 | name (str): Name identifier. 20 | 21 | Returns: 22 | None 23 | ''' 24 | result_box_shape = (2, 2) 25 | 26 | # Calculate IoU 27 | np_gt = 1 - (np.array(gt_image) / 255) 28 | np_seg = 1 - np.array(seg_image) / 255 29 | seg_iou = (np_seg * np_gt).sum() / (((np_seg + np_gt) > 0).sum() + 1e-6) 30 | 31 | # Draw overall result figure 32 | image_width, image_height = 224, 224 33 | padding = 10 34 | canvas_width = (image_width * result_box_shape[1]) + (padding * (result_box_shape[1] + 1)) 35 | canvas_height = (image_height * result_box_shape[0]) + (padding * (result_box_shape[0] + 1)) 36 | canvas = Image.new('RGB', (canvas_width, canvas_height)) 37 | 38 | draw = ImageDraw.Draw(canvas) 39 | font = ImageFont.load_default() 40 | out_text = [f'Label: {label}', f'IoU: {seg_iou:.2f}'] 41 | 42 | resized_images = [original_image, gt_image, heatmap_image, seg_image] 43 | for i in range(np.prod(result_box_shape)): 44 | row = i % 2 45 | col = i // 2 46 | x = (image_width + padding) * col 47 | y = (image_height + padding) * row 48 | canvas.paste(resized_images[i], (x, y)) 49 | 50 | if row == 1: 51 | text = out_text[i // 2] 52 | text_x = (image_width + padding) * col 53 | text_y = (image_height + padding) * row + image_height + padding 54 | draw.text((text_x, text_y), text, font=font, fill=(255, 255, 255)) 55 | 56 | # save fig 57 | output_path = os.path.join(result_dir, 'overall') 58 | os.makedirs(output_path, exist_ok=True) 59 | canvas.save(os.path.join(output_path, f'{name}.jpg')) 60 | 61 | 62 | def draw_overlaid(result_dir: str, original_image: Image.Image, heatmap_image: Image.Image, name: str) -> None: 63 | ''' 64 | Draw an overlaid figure with the original image and heatmap. 65 | 66 | Args: 67 | result_dir (str): Directory to save the figure. 68 | original_image (Image.Image): Original image. 69 | heatmap_image (Image.Image): Heatmap image. 70 | name (str): Name identifier. 
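
    Example (illustrative; `heatmap_image` is assumed to be an 8-bit grayscale
    PIL image so that cv2.applyColorMap can consume it; the call below would
    write results/overlaid/sample_0.jpg):
        >>> draw_overlaid('results', original_image, heatmap_image, 'sample_0')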
71 | 72 | Returns: 73 | None 74 | ''' 75 | heatmap_image = cv2.applyColorMap(np.array(heatmap_image), cv2.COLORMAP_JET) 76 | overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0) 77 | overlaid_image = Image.fromarray(overlaid_image) 78 | 79 | # save fig 80 | output_path = os.path.join(result_dir, 'overlaid') 81 | os.makedirs(output_path, exist_ok=True) 82 | overlaid_image.save(os.path.join(output_path, f'{name}.jpg')) 83 | --------------------------------------------------------------------------------