├── utils ├── __init__.py └── functions.py ├── backbones ├── __init__.py ├── SubNets │ ├── transformers_encoder │ │ ├── __init__.py │ │ ├── position_embedding.py │ │ └── multihead_attention.py │ ├── __init__.py │ └── AlignNets.py ├── FusionNets │ ├── __init__.py │ ├── MIntOOD.py │ ├── MULT.py │ ├── sampler.py │ ├── SDIF.py │ └── AlignNets.py └── base.py ├── ood_detection ├── msp.py ├── maxlogit.py ├── energy.py ├── __init__.py ├── ma.py ├── residual.py └── vim.py ├── requirements.txt ├── losses ├── __init__.py └── SupConLoss.py ├── configs ├── __init__.py ├── base.py ├── spectra_MELD-DA.py ├── spectra_MIntRec.py ├── spectra_IEMOCAP-DA.py ├── mag_bert_MIntRec.py ├── text_IEMOCAP-DA.py ├── text_MELD-DA.py ├── text_MIntRec.py ├── mintood_MIntRec.py ├── mult_IEMOCAP-DA.py ├── mult_MELD-DA.py ├── mult_MIntRec.py ├── mmim_MELD-DA.py ├── tcl_map_MIntRec.py ├── tcl_map_IEMOCAP-DA.py ├── tcl_map_MELD-DA.py ├── sdif_MELD-DA.py ├── sdif_MIntRec.py ├── sdif_IEMOCAP-DA.py ├── mintood_MELD-DA.py ├── mmim_IEMOCAP-DA.py ├── mmim_MIntRec.py ├── mintood_IEMOCAP-DA.py ├── mag_bert_MELD-DA.py └── mag_bert_IEMOCAP-DA.py ├── methods ├── __init__.py ├── TCL_MAP │ └── loss.py ├── MULT │ └── manager.py ├── SDIF │ └── manager.py ├── MAG_BERT │ └── manager.py └── Spectra │ └── manager.py ├── examples ├── run_test.sh └── run_train.sh ├── data ├── mm_pre.py ├── __init__.py └── utils.py └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /backbones/SubNets/transformers_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ood_detection/msp.py: -------------------------------------------------------------------------------- 1 | 2 | def func(args, inputs): 3 | 4 | scores = inputs['y_prob'] 5 | 6 | return scores -------------------------------------------------------------------------------- /ood_detection/maxlogit.py: -------------------------------------------------------------------------------- 1 | 2 | def func(args, inputs): 3 | 4 | logits = inputs['y_logit'] 5 | scores = logits.max(1) 6 | 7 | return scores -------------------------------------------------------------------------------- /ood_detection/energy.py: -------------------------------------------------------------------------------- 1 | from scipy.special import logsumexp 2 | 3 | 4 | def func(args, inputs): 5 | 6 | logits = inputs['y_logit'] 7 | scores = logsumexp(logits, axis = -1) 8 | 9 | return scores -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.10 2 | numpy==1.23.5 3 | pandas==1.5.3 4 | scikit_learn==1.2.1 5 | scipy==1.13.0 6 | torch==1.13.1 7 | tqdm==4.65.0 8 | tqdm==4.64.1 9 | transformers==4.28.1 10 | transformers==4.26.1 11 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .SupConLoss import SupConLoss 3 | 4 | loss_map = { 5 | 'CrossEntropyLoss': 
nn.CrossEntropyLoss(), 6 | 'SupConLoss': SupConLoss() 7 | } 8 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | pretrained_models_path = { 2 | 'bert-base-uncased': '/home/sharing/disk1/pretrained_embedding/bert/uncased_L-12_H-768_A-12/', 3 | 'roberta-base': '/home/sharing/disk1/pretrained_embedding/roberta/roberta-base', 4 | } -------------------------------------------------------------------------------- /backbones/SubNets/__init__.py: -------------------------------------------------------------------------------- 1 | from .FeatureNets import BERTEncoder, RoBERTaEncoder 2 | 3 | text_backbones_map = { 4 | 'bert-base-uncased': BERTEncoder, 5 | 'bert-large-uncased': BERTEncoder, 6 | 'roberta-base': RoBERTaEncoder 7 | } -------------------------------------------------------------------------------- /backbones/FusionNets/__init__.py: -------------------------------------------------------------------------------- 1 | from .MAG_BERT import MAG_BERT 2 | from .MMIM import MMIM 3 | from .MULT import MULT 4 | from .TCL_MAP import TCL_MAP 5 | from .SDIF import SDIF 6 | from .Spectra import Spectra 7 | 8 | multimodal_methods_map = { 9 | 'mag_bert': MAG_BERT, 10 | 'mmim': MMIM, 11 | 'mult': MULT, 12 | 'tcl_map': TCL_MAP, 13 | 'sdif': SDIF, 14 | 'spectra': Spectra, 15 | } -------------------------------------------------------------------------------- /ood_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .energy import func as ENERGY 2 | from .ma import func as MA 3 | from .vim import func as VIM 4 | from .maxlogit import func as MAXLOGIT 5 | from .msp import func as MSP 6 | from .residual import func as RESIDUAL 7 | 8 | ood_detection_map = { 9 | 'energy': ENERGY, 10 | 'ma': MA, 11 | 'vim': VIM, 12 | 'maxlogit': MAXLOGIT, 13 | 'msp': MSP, 14 | 'residual': RESIDUAL 15 | } -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .MAG_BERT.manager import MAG_BERT 2 | from .TEXT.manager import TEXT 3 | from .MMIM.manager import MMIM 4 | from .MIntOOD.manager import MIntOOD 5 | from .MULT.manager import MULT 6 | from .TCL_MAP.manager import TCL_MAP 7 | from .SDIF.manager import SDIF 8 | from .Spectra.manager import Spectra 9 | 10 | 11 | 12 | method_map = { 13 | 'mag_bert': MAG_BERT, 14 | 'text': TEXT, 15 | 'mmim': MMIM, 16 | 'mintood': MIntOOD, 17 | 'mult': MULT, 18 | 'tcl_map': TCL_MAP, 19 | 'sdif': SDIF, 20 | 'spectra': Spectra, 21 | } -------------------------------------------------------------------------------- /configs/base.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from easydict import EasyDict 3 | from .__init__ import pretrained_models_path 4 | 5 | 6 | class ParamManager: 7 | 8 | def __init__(self, args): 9 | 10 | args.text_pretrained_model = pretrained_models_path[args.text_backbone] 11 | self.args = EasyDict(dict(vars(args))) 12 | 13 | def add_config_param(old_args, config_file_name = None): 14 | 15 | if config_file_name is None: 16 | config_file_name = old_args.config_file_name 17 | 18 | if config_file_name.endswith('.py'): 19 | module_name = '.' + config_file_name[:-3] 20 | else: 21 | module_name = '.' 
+ config_file_name 22 | 23 | config = importlib.import_module(module_name, 'configs') 24 | 25 | config_param = config.Param 26 | method_args = config_param(old_args) 27 | new_args = EasyDict(dict(old_args, **method_args.hyper_param)) 28 | 29 | return new_args 30 | -------------------------------------------------------------------------------- /ood_detection/ma.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.covariance import EmpiricalCovariance 3 | 4 | def cal_ma_dis(mean, prec, features): 5 | 6 | ma_score = -np.array([(((f - mean)@prec) * (f - mean)).sum(-1).min() for f in features]) 7 | 8 | return ma_score 9 | 10 | def func(args, inputs): 11 | 12 | train_feats = inputs['train_feats'] 13 | train_labels = inputs['train_labels'] 14 | print('1111111111111', np.unique(train_labels)) 15 | 16 | train_means = [] 17 | train_dis = [] 18 | 19 | for l in range(args.num_labels): 20 | fs = train_feats[train_labels == l] 21 | m = fs.mean(axis = 0) 22 | train_means.append(m) 23 | train_dis.extend(fs - m) 24 | 25 | ec = EmpiricalCovariance(assume_centered=True) 26 | ec.fit(np.array(train_dis).astype(np.float64)) 27 | 28 | mean = np.array(train_means) 29 | prec = ec.precision_ ## 协方差的伪逆矩阵 30 | 31 | features = inputs['y_feat'] 32 | scores = cal_ma_dis(mean, prec, features) 33 | 34 | return scores -------------------------------------------------------------------------------- /backbones/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from torch import nn 4 | from .FusionNets import multimodal_methods_map 5 | 6 | __all__ = ['ModelManager'] 7 | 8 | # base backbones 9 | class MIA(nn.Module): 10 | 11 | def __init__(self, args): 12 | 13 | super(MIA, self).__init__() 14 | 15 | fusion_method = multimodal_methods_map[args.multimodal_method] 16 | self.model = fusion_method(args) 17 | 18 | def forward(self, text_feats, video_data, audio_data, *args, **kwargs): 19 | 20 | mm_model = self.model(text_feats, video_data, audio_data, *args, **kwargs) 21 | 22 | return mm_model 23 | 24 | def vim(self): 25 | 26 | return self.model.vim() 27 | 28 | 29 | class ModelManager: 30 | 31 | def __init__(self, args): 32 | 33 | self.logger = logging.getLogger(args.logger_name) 34 | self.device = args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | 36 | def _set_model(self, args): 37 | 38 | model = MIA(args) 39 | model.to(self.device) 40 | return model 41 | -------------------------------------------------------------------------------- /ood_detection/residual.py: -------------------------------------------------------------------------------- 1 | # import numpy as np 2 | # from sklearn.covariance import EmpiricalCovariance 3 | # from numpy.linalg import norm, pinv 4 | 5 | # def func(args, inputs): 6 | 7 | # w = inputs['w'] 8 | # w = w.data.cpu().numpy() 9 | # b = inputs['b'] 10 | # b = b.data.cpu().numpy() 11 | 12 | # train_feats = inputs['train_feats'] 13 | # ## 计算P和α 14 | 15 | # u = -np.matmul(pinv(w), b) 16 | 17 | # ec = EmpiricalCovariance(assume_centered=True) 18 | # ec.fit(train_feats - u) 19 | # eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) 20 | # NS = np.ascontiguousarray((eigen_vectors.T[np.argsort(eig_vals * -1)[args.num_labels:]]).T) 21 | 22 | # # 主子空间p的正交补 23 | # features = inputs['y_feat'] 24 | # scores = -norm(np.matmul(features - u, NS), axis=-1) 25 | 26 | # return scores 27 | 28 | import numpy as np 29 | from sklearn.covariance 
import EmpiricalCovariance 30 | from numpy.linalg import norm, pinv 31 | 32 | def func(args, inputs): 33 | 34 | w = inputs['w'] 35 | w = w.data.cpu().numpy() 36 | b = inputs['b'] 37 | b = b.data.cpu().numpy() 38 | 39 | train_feats = inputs['train_feats'] 40 | ## 计算P和α 41 | 42 | if args.method == 'mmco': 43 | w /= np.linalg.norm(w, ord=2, axis=1, keepdims=True) 44 | train_feats /= np.linalg.norm(train_feats, ord=2, axis=1, keepdims=True) 45 | 46 | u = -np.matmul(pinv(w), b) 47 | 48 | ec = EmpiricalCovariance(assume_centered=True) 49 | ec.fit(train_feats - u) 50 | eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) 51 | NS = np.ascontiguousarray((eigen_vectors.T[np.argsort(eig_vals * -1)[args.num_labels:]]).T) 52 | 53 | # 主子空间p的正交补 54 | features = inputs['y_feat'] 55 | scores = -norm(np.matmul(features - u, NS), axis=-1) 56 | 57 | return scores -------------------------------------------------------------------------------- /configs/spectra_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('roberta'): 29 | 30 | hyper_parameters = { 31 | 'eval_monitor': ['f1'], 32 | 'train_batch_size': 16, 33 | 'eval_batch_size': 8, 34 | 'test_batch_size': 8, 35 | 'wait_patience': 8, 36 | 'num_train_epochs': 100, 37 | ################ 38 | 'warmup_proportion': [0.1], 39 | 'lr': [1e-5], 40 | 'weight_decay': [0.03], 41 | 'scale':32 42 | } 43 | else: 44 | raise ValueError('Not supported text backbone') 45 | 46 | if args.ood_detection_method in ood_detection_parameters.keys(): 47 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 48 | hyper_parameters.update(ood_parameters) 49 | 50 | return hyper_parameters -------------------------------------------------------------------------------- /configs/spectra_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 
13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('roberta'): 29 | 30 | hyper_parameters = { 31 | 'eval_monitor': ['f1'], 32 | 'train_batch_size': 16, 33 | 'eval_batch_size': 8, 34 | 'test_batch_size': 8, 35 | 'wait_patience': 8, 36 | 'num_train_epochs': 100, 37 | ################ 38 | 'warmup_proportion': [0.1], 39 | 'lr': [2e-5], 40 | 'weight_decay': [0.03], 41 | 'scale':32, 42 | } 43 | else: 44 | raise ValueError('Not supported text backbone') 45 | 46 | if args.ood_detection_method in ood_detection_parameters.keys(): 47 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 48 | hyper_parameters.update(ood_parameters) 49 | 50 | return hyper_parameters -------------------------------------------------------------------------------- /configs/spectra_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 
21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('roberta'): 29 | 30 | hyper_parameters = { 31 | 'eval_monitor': ['f1'], 32 | 'train_batch_size': 16, 33 | 'eval_batch_size': 8, 34 | 'test_batch_size': 8, 35 | 'wait_patience': 8, 36 | 'num_train_epochs': 100, 37 | ################ 38 | 'warmup_proportion': [0.1], 39 | 'lr': [2e-5], 40 | 'weight_decay': [0.03], 41 | 'scale':32 42 | } 43 | else: 44 | raise ValueError('Not supported text backbone') 45 | 46 | if args.ood_detection_method in ood_detection_parameters.keys(): 47 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 48 | hyper_parameters.update(ood_parameters) 49 | 50 | return hyper_parameters -------------------------------------------------------------------------------- /examples/run_test.sh: -------------------------------------------------------------------------------- 1 | ## Methods: mintood text mag_bert mult mmim tcl_map sdif spectra 2 | ## Dataset Configurations: MIntRec+MIntRec-OOD MELD-DA+MELD-DA-OOD IEMOCAP+IEMOCAP-DA-OOD 3 | ## OOD Detection methods: ma vim residual msp ma maxlogit 4 | ## Ablation Types: full text fusion_add fusion_concat sampler_beta wo_contrast wo_cosine wo_binary 5 | ## Note: If using SPECTRA, audio_feats and ood_audio_feats need to use features compatible with WavLM (replace audio_feats_path and ood_audio_feats_path with 'spectra_audio.pkl'). For details, refer to WavLM documentation at https://huggingface.co/docs/transformers/model_doc/wavlm. 6 | 7 | 8 | for method in 'mintood' 9 | do 10 | for text_backbone in 'bert-base-uncased' 11 | do 12 | for ood_dataset in 'MIntRec-OOD' 13 | do 14 | for dataset in 'MIntRec' 15 | do 16 | for ood_detection_method in 'ma' 17 | do 18 | for ablation_type in 'full' 19 | do 20 | python run.py \ 21 | --dataset $dataset \ 22 | --data_path 'Datasets' \ 23 | --ood_dataset $ood_dataset \ 24 | --logger_name ${method}_${ood_detection_method} \ 25 | --multimodal_method $method \ 26 | --method ${method}\ 27 | --ood_detection_method $ood_detection_method \ 28 | --ablation_type $ablation_type \ 29 | --ood \ 30 | --tune \ 31 | --save_results \ 32 | --save_model \ 33 | --gpu_id '0' \ 34 | --video_feats_path 'swin_feats.pkl' \ 35 | --audio_feats_path 'wavlm_feats.pkl' \ 36 | --ood_video_feats_path 'swin_feats.pkl' \ 37 | --ood_audio_feats_path 'wavlm_feats.pkl' \ 38 | --text_backbone $text_backbone \ 39 | --config_file_name ${method}_${dataset} \ 40 | --output_path "outputs" \ 41 | --results_file_name 'results_test.csv' 42 | done 43 | done 44 | done 45 | done 46 | done 47 | done 48 | -------------------------------------------------------------------------------- /examples/run_train.sh: -------------------------------------------------------------------------------- 1 | ## Methods: mintood text mag_bert mult mmim tcl_map sdif spectra 2 | ## Dataset Configurations: MIntRec+MIntRec-OOD MELD-DA+MELD-DA-OOD IEMOCAP+IEMOCAP-DA-OOD 3 | ## OOD Detection methods: ma vim residual msp ma maxlogit 4 | ## Ablation Types: full text fusion_add fusion_concat sampler_beta wo_contrast wo_cosine wo_binary 5 | ## Note: If using SPECTRA, audio_feats and ood_audio_feats need to use features compatible with WavLM (replace audio_feats_path and ood_audio_feats_path with 'spectra_audio.pkl'). For details, refer to WavLM documentation at https://huggingface.co/docs/transformers/model_doc/wavlm. 
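## Illustrative sketch (not part of the original script): following the note above, a SPECTRA run
## would point the two audio-feature arguments of the python call below at the WavLM-compatible
## features, e.g.
##   --audio_feats_path 'spectra_audio.pkl' \
##   --ood_audio_feats_path 'spectra_audio.pkl' \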
6 | 7 | 8 | for method in 'sdif' 9 | do 10 | for text_backbone in 'bert-base-uncased' 11 | do 12 | for ood_dataset in 'MIntRec-OOD' 13 | do 14 | for dataset in 'MIntRec' 15 | do 16 | for ood_detection_method in 'ma' 17 | do 18 | for ablation_type in 'full' 19 | do 20 | python run.py \ 21 | --dataset $dataset \ 22 | --data_path '/home/sharing/Datasets' \ 23 | --ood_dataset $ood_dataset \ 24 | --logger_name ${method}_${ood_detection_method} \ 25 | --multimodal_method $method \ 26 | --method ${method}\ 27 | --ood_detection_method $ood_detection_method \ 28 | --ablation_type $ablation_type \ 29 | --train \ 30 | --ood \ 31 | --tune \ 32 | --save_results \ 33 | --save_model \ 34 | --gpu_id '0' \ 35 | --video_feats_path 'swin_feats.pkl' \ 36 | --audio_feats_path 'wavlm_feats.pkl' \ 37 | --ood_video_feats_path 'swin_feats.pkl' \ 38 | --ood_audio_feats_path 'wavlm_feats.pkl' \ 39 | --text_backbone $text_backbone \ 40 | --config_file_name ${method}_${dataset} \ 41 | --output_path "outputs" \ 42 | --results_file_name 'results_mintood_train.csv' 43 | done 44 | done 45 | done 46 | done 47 | done 48 | done 49 | -------------------------------------------------------------------------------- /configs/mag_bert_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 
21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('bert'): 29 | 30 | hyper_parameters = { 31 | 'need_aligned': True, 32 | 'eval_monitor': ['f1'], 33 | 'train_batch_size': 16, 34 | 'eval_batch_size': 8, 35 | 'test_batch_size': 8, 36 | 'wait_patience': 8, 37 | 'num_train_epochs': 100, 38 | ################ 39 | 'beta_shift': [0.006], 40 | 'dropout_prob': [0.4], 41 | 'warmup_proportion': [0.1], 42 | 'lr': [2e-5], 43 | 'aligned_method': 'ctc', 44 | 'weight_decay': [0.03], 45 | 'scale':32 46 | } 47 | else: 48 | raise ValueError('Not supported text backbone') 49 | 50 | if args.ood_detection_method in ood_detection_parameters.keys(): 51 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 52 | hyper_parameters.update(ood_parameters) 53 | 54 | return hyper_parameters -------------------------------------------------------------------------------- /data/mm_pre.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import numpy as np 4 | 5 | __all__ = ['MMDataset'] 6 | 7 | class MMDataset(Dataset): 8 | 9 | def __init__(self, label_ids, text_data, video_data, audio_data): 10 | 11 | self.label_ids = label_ids 12 | self.text_data = text_data 13 | self.video_data = video_data 14 | self.audio_data = audio_data 15 | self.size = len(self.text_data) 16 | # print('111111111111', len(self.text_data)) 17 | # print('22222222222', len(self.video_data['feats'])) 18 | # print('3333333333', len(self.audio_data['feats'])) 19 | 20 | 21 | def __len__(self): 22 | return self.size 23 | 24 | def __getitem__(self, index): 25 | # print('1111111', index) 26 | # print('222222', self.label_ids) 27 | sample = { 28 | 'label_ids': torch.tensor(self.label_ids[index]), 29 | 'text_feats': torch.tensor(self.text_data[index]), 30 | 'video_feats': torch.tensor(np.array(self.video_data['feats'][index])), 31 | 'video_lengths': torch.tensor(np.array(self.video_data['lengths'][index])), 32 | 'audio_feats': torch.tensor(np.array(self.audio_data['feats'][index])), 33 | 'audio_lengths': torch.tensor(np.array(self.audio_data['lengths'][index])) 34 | } 35 | return sample 36 | 37 | 38 | class TCL_MAPDataset(Dataset): 39 | 40 | def __init__(self, label_ids, text_feats, video_feats, audio_feats, cons_text_feats, condition_idx): 41 | 42 | 43 | self.label_ids = label_ids 44 | self.text_feats = text_feats 45 | self.cons_text_feats = cons_text_feats 46 | self.condition_idx = condition_idx 47 | self.video_feats = video_feats 48 | self.audio_feats = audio_feats 49 | self.size = len(self.text_feats) 50 | 51 | def __len__(self): 52 | return self.size 53 | 54 | def __getitem__(self, index): 55 | 56 | sample = { 57 | 'label_ids': torch.tensor(self.label_ids[index]), 58 | 'text_feats': torch.tensor(self.text_feats[index]), 59 | 'video_feats': torch.tensor(self.video_feats['feats'][index]), 60 | 'audio_feats': torch.tensor(self.audio_feats['feats'][index]), 61 | 'cons_text_feats': torch.tensor(self.cons_text_feats[index]), 62 | 'condition_idx': torch.tensor(self.condition_idx[index]) 63 | } 64 | return sample -------------------------------------------------------------------------------- /configs/text_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, 
args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | lr (float): The learning rate of backbone. 17 | """ 18 | ood_detection_parameters = { 19 | 'sbm':{ 20 | 'temperature': [1e4], 21 | 'scale': [20] 22 | }, 23 | 'hub':{ 24 | 'temperature': [1e6], 25 | 'scale': [20], 26 | 'k': [10], 27 | 'alpha': [1.0] 28 | } 29 | } 30 | if args.text_backbone.startswith('bert'): 31 | hyper_parameters = { 32 | 'eval_monitor': ['f1'], 33 | 'train_batch_size': 16, 34 | 'eval_batch_size': 8, 35 | 'test_batch_size': 8, 36 | 'wait_patience': 8, 37 | 'num_train_epochs': [100], 38 | 'multiple_ood': 1, 39 | ################## 40 | 'warmup_proportion': 0.1, 41 | 'lr':0.00002, 42 | 'weight_decay': 0.1, 43 | } 44 | elif args.text_backbone.startswith('roberta'): 45 | hyper_parameters = { 46 | 'eval_monitor': ['acc'], 47 | 'train_batch_size': 16, 48 | 'eval_batch_size': 8, 49 | 'test_batch_size': 8, 50 | 'wait_patience': 8, 51 | 'num_train_epochs': 100, 52 | ################### 53 | 'warmup_proportion': 0.1, 54 | 'lr':0.00002, 55 | 'weight_decay': 0.1, 56 | } 57 | 58 | if args.ood_detection_method in ood_detection_parameters.keys(): 59 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 60 | hyper_parameters.update(ood_parameters) 61 | 62 | return hyper_parameters -------------------------------------------------------------------------------- /configs/text_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | lr (float): The learning rate of backbone. 
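            Note (illustrative addition, not part of the original docstring): the
            ood_detection_method chosen in the run scripts (e.g. 'ma', 'vim', 'msp')
            presumably selects its scoring function through ood_detection_map defined in
            ood_detection/__init__.py; a minimal sketch of that dispatch:

                from ood_detection import ood_detection_map
                score_func = ood_detection_map[args.ood_detection_method]
                scores = score_func(args, inputs)  # inputs holds e.g. 'y_prob', 'y_logit', 'y_feat'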
17 | """ 18 | ood_detection_parameters = { 19 | 'sbm':{ 20 | 'temperature': [1e4], 21 | 'scale': [20] 22 | }, 23 | 'hub':{ 24 | 'temperature': [1e6], 25 | 'scale': [20], 26 | 'k': [10], 27 | 'alpha': [1.0] 28 | } 29 | } 30 | if args.text_backbone.startswith('bert'): 31 | hyper_parameters = { 32 | 'eval_monitor': ['f1'], 33 | 'train_batch_size': 16, 34 | 'eval_batch_size': 8, 35 | 'test_batch_size': 8, 36 | 'wait_patience': 8, 37 | 'num_train_epochs': [100], 38 | 'multiple_ood': 1, 39 | ################## 40 | 'warmup_proportion': 0.1, 41 | 'lr':0.00001, 42 | 'weight_decay': 0.1, 43 | } 44 | elif args.text_backbone.startswith('roberta'): 45 | hyper_parameters = { 46 | 'eval_monitor': ['acc'], 47 | 'train_batch_size': 16, 48 | 'eval_batch_size': 8, 49 | 'test_batch_size': 8, 50 | 'wait_patience': 8, 51 | 'num_train_epochs': 100, 52 | ################### 53 | 'warmup_proportion': 0.1, 54 | 'lr':0.00001, 55 | 'weight_decay': 0.1, 56 | } 57 | 58 | if args.ood_detection_method in ood_detection_parameters.keys(): 59 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 60 | hyper_parameters.update(ood_parameters) 61 | 62 | return hyper_parameters -------------------------------------------------------------------------------- /configs/text_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | lr (float): The learning rate of backbone. 
17 | """ 18 | ood_detection_parameters = { 19 | 'sbm':{ 20 | 'temperature': [1e4], 21 | 'scale': [20] 22 | }, 23 | 'hub':{ 24 | 'temperature': [1e6], 25 | 'scale': [20], 26 | 'k': [10], 27 | 'alpha': [1.0] 28 | } 29 | } 30 | if args.text_backbone.startswith('bert'): 31 | hyper_parameters = { 32 | 'eval_monitor': ['f1'], 33 | 'train_batch_size': 16, 34 | 'eval_batch_size': 8, 35 | 'test_batch_size': 8, 36 | 'wait_patience': 8, 37 | 'num_train_epochs': [100], 38 | 'multiple_ood': 1, 39 | ################## 40 | 'warmup_proportion': 0.1, 41 | 'lr':0.00002, 42 | 'weight_decay': 0.1, 43 | } 44 | elif args.text_backbone.startswith('roberta'): 45 | hyper_parameters = { 46 | 'eval_monitor': ['acc'], 47 | 'train_batch_size': 16, 48 | 'eval_batch_size': 8, 49 | 'test_batch_size': 8, 50 | 'wait_patience': 8, 51 | 'num_train_epochs': 100, 52 | ################### 53 | 'warmup_proportion': 0.1, 54 | 'lr':0.00002, 55 | 'weight_decay': 0.1, 56 | } 57 | 58 | if args.ood_detection_method in ood_detection_parameters.keys(): 59 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 60 | hyper_parameters.update(ood_parameters) 61 | 62 | return hyper_parameters -------------------------------------------------------------------------------- /configs/mintood_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | ''' 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 
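        Note (illustrative addition, not part of the original docstring): these hyper-parameters
        are merged into the runtime args by add_config_param in configs/base.py, which roughly does:

            config = importlib.import_module('.mintood_MIntRec', 'configs')
            method_args = config.Param(old_args)
            new_args = EasyDict(dict(old_args, **method_args.hyper_param))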
21 | ''' 22 | if args.text_backbone.startswith('bert'): 23 | hyper_parameters = { 24 | 'need_aligned': True, 25 | 'freeze_parameters': False, 26 | 'eval_monitor': 'f1', 27 | 'eval_batch_size': 16, 28 | 'wait_patience': [8], 29 | 'binary_multiple_ood': 1.0, 30 | 'base_dim': [768], 31 | 'lr': [3e-5], #3e-5 32 | 'temperature': [2], # bigger is usually better 33 | 'alpha': [2], #0.5, 1 34 | 'mlp_hidden_size': [256], 35 | 'mlp_dropout': [0.1], 36 | 're_prob': [0.1], 37 | 'num_train_epochs': [100], # [30, 40, 50] 38 | 'train_batch_size': [32], # [32, 64, 128] 39 | 'weight_decay': [0.01], # [0.01, 0.05, 0.1] 40 | 'multiple_ood': [1.0], # try average number 41 | 'contrast_dropout': [0.1], 42 | 'select_number_min': [2], 43 | 'select_number_max': [3], 44 | 'weight_dropout': [0.5], 45 | 'weight_hidden_dim': [256], 46 | # 'weight': [2, 3], 47 | 'aligned_method': ['ctc'], 48 | 'warmup_proportion': [0.1], 49 | 'scale': [16], 50 | 'encoder_layers_a': [1], 51 | 'encoder_layers_v': [2], 52 | 'attn_dropout': [0.0], 53 | 'relu_dropout': [0.1], 54 | 'embed_dropout': [0.0], 55 | 'res_dropout': [0.2], #0 56 | 'attn_mask': [False], #True 57 | 'nheads': [2], #4 58 | } 59 | 60 | return hyper_parameters -------------------------------------------------------------------------------- /ood_detection/vim.py: -------------------------------------------------------------------------------- 1 | # import numpy as np 2 | # from sklearn.covariance import EmpiricalCovariance 3 | # from numpy.linalg import norm, pinv 4 | # from scipy.special import logsumexp 5 | 6 | # def func(args, inputs): 7 | 8 | # w = inputs['w'] 9 | # w = w.data.cpu().numpy() 10 | # b = inputs['b'] 11 | # b = b.data.cpu().numpy() 12 | 13 | # train_feats = inputs['train_feats'] 14 | # train_logits = train_feats @ w.T + b 15 | 16 | # u = -np.matmul(pinv(w), b) 17 | 18 | # ec = EmpiricalCovariance(assume_centered=True) 19 | # ec.fit(train_feats - u) 20 | # eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) 21 | # NS = np.ascontiguousarray((eigen_vectors.T[np.argsort(eig_vals * -1)[args.num_labels:]]).T) 22 | # # 主子空间p的正交补 23 | 24 | # vlogit_ind_train = norm(np.matmul(train_feats - u, NS), axis=-1) 25 | # alpha = train_logits.max(axis=-1).mean() / vlogit_ind_train.mean() 26 | 27 | # features = inputs['y_feat'] 28 | # logit = features @ w.T + b 29 | # vlogit = norm(np.matmul(features - u, NS), axis=-1) * alpha 30 | # energy = logsumexp(logit, axis=-1) 31 | # scores = -vlogit + energy 32 | 33 | # return scores 34 | import numpy as np 35 | from sklearn.covariance import EmpiricalCovariance 36 | from numpy.linalg import norm, pinv 37 | from scipy.special import logsumexp 38 | 39 | def func(args, inputs): 40 | 41 | w = inputs['w'] 42 | w = w.data.cpu().numpy() 43 | 44 | b = inputs['b'] 45 | b = b.data.cpu().numpy() 46 | 47 | train_feats = inputs['train_feats'] 48 | print(w.shape) 49 | if args.method == 'mmco': 50 | w /= np.linalg.norm(w, ord=2, axis=1, keepdims=True) 51 | train_feats /= np.linalg.norm(train_feats, ord=2, axis=1, keepdims=True) 52 | 53 | train_logits = (train_feats * args.scale) @ w.T + b # 54 | 55 | u = -np.matmul(pinv(w), b) 56 | 57 | ec = EmpiricalCovariance(assume_centered=True) 58 | ec.fit(train_feats - u) 59 | eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) 60 | NS = np.ascontiguousarray((eigen_vectors.T[np.argsort(eig_vals * -1)[args.num_labels:]]).T) 61 | # 主子空间p的正交补 62 | 63 | if args.method == 'mmco': 64 | vlogit_ind_train = norm(np.matmul((train_feats - u) * args.scale, NS), axis=-1) 65 | else: 66 | vlogit_ind_train = 
norm(np.matmul(train_feats - u, NS), axis=-1) 67 | 68 | alpha = train_logits.max(axis=-1).mean() / vlogit_ind_train.mean() 69 | 70 | features = inputs['y_feat'] 71 | if args.method == 'mmco': 72 | logit = (features * args.scale) @ w.T + b # 73 | vlogit = norm(np.matmul((features - u) * args.scale, NS), axis=-1) * alpha 74 | else: 75 | logit = features @ w.T + b # 76 | vlogit = norm(np.matmul(features - u, NS), axis=-1) * alpha 77 | energy = logsumexp(logit, axis=-1) 78 | scores = -vlogit + energy 79 | 80 | return scores -------------------------------------------------------------------------------- /configs/mult_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | dst_feature_dims (int): The destination dimensions (assume d(l) = d(v) = d(t)). 12 | nheads (int): The number of heads for the transformer network. 13 | n_levels (int): The number of layers in the network. 14 | attn_dropout (float): The attention dropout. 15 | attn_dropout_v (float): The attention dropout for the video modality. 16 | attn_dropout_a (float): The attention dropout for the audio modality. 17 | relu_dropout (float): The relu dropout. 18 | embed_dropout (float): The embedding dropout. 19 | res_dropout (float): The residual block dropout. 20 | output_dropout (float): The output layer dropout. 21 | text_dropout (float): The dropout for text features. 22 | grad_clip (float): The gradient clip value. 23 | attn_mask (bool): Whether to use attention mask for Transformer. 24 | conv1d_kernel_size_l (int): The kernel size for temporal convolutional layers (text modality). 25 | conv1d_kernel_size_v (int): The kernel size for temporal convolutional layers (video modality). 26 | conv1d_kernel_size_a (int): The kernel size for temporal convolutional layers (audio modality). 27 | lr (float): The learning rate of backbone. 
28 | """ 29 | 30 | ood_detection_parameters = { 31 | 'sbm':{ 32 | 'temperature': [1e6], 33 | 'scale': [20] 34 | }, 35 | 'hub':{ 36 | 'temperature': [1e6], 37 | 'scale': [20], 38 | 'k': [50], 39 | 'alpha': [1.0] 40 | } 41 | } 42 | 43 | hyper_parameters = { 44 | 'padding_mode': 'zero', 45 | 'padding_loc': 'end', 46 | 'need_aligned': False, 47 | 'eval_monitor': ['f1'], 48 | 'train_batch_size': 16, 49 | 'eval_batch_size': 8, 50 | 'test_batch_size': 8, 51 | 'wait_patience': 8, 52 | 'num_train_epochs': [100], 53 | 'dst_feature_dims': [120], # 80 54 | 'nheads': [12], #4 55 | 'n_levels': [2], #8 56 | 'attn_dropout': 0.0, 57 | 'attn_dropout_v': 0.2, #0.2 58 | 'attn_dropout_a': 0.1, #0.2 59 | 'relu_dropout': 0.0, 60 | 'embed_dropout': 0.1, 61 | 'res_dropout': 0.0, #0 62 | 'output_dropout': 0.2, #0.2 63 | 'text_dropout': [0.4], #0.4 64 | 'grad_clip': 0.5, 65 | 'attn_mask': [True], #True 66 | 'conv1d_kernel_size_l': 6, #5 67 | 'conv1d_kernel_size_v': 1, #1 68 | 'conv1d_kernel_size_a': 1, #1 69 | 'lr': [0.00002], # 5e-6 70 | 'scale':20 71 | } 72 | return hyper_parameters 73 | 74 | -------------------------------------------------------------------------------- /configs/mult_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | dst_feature_dims (int): The destination dimensions (assume d(l) = d(v) = d(t)). 12 | nheads (int): The number of heads for the transformer network. 13 | n_levels (int): The number of layers in the network. 14 | attn_dropout (float): The attention dropout. 15 | attn_dropout_v (float): The attention dropout for the video modality. 16 | attn_dropout_a (float): The attention dropout for the audio modality. 17 | relu_dropout (float): The relu dropout. 18 | embed_dropout (float): The embedding dropout. 19 | res_dropout (float): The residual block dropout. 20 | output_dropout (float): The output layer dropout. 21 | text_dropout (float): The dropout for text features. 22 | grad_clip (float): The gradient clip value. 23 | attn_mask (bool): Whether to use attention mask for Transformer. 24 | conv1d_kernel_size_l (int): The kernel size for temporal convolutional layers (text modality). 25 | conv1d_kernel_size_v (int): The kernel size for temporal convolutional layers (video modality). 26 | conv1d_kernel_size_a (int): The kernel size for temporal convolutional layers (audio modality). 27 | lr (float): The learning rate of backbone. 
28 | """ 29 | 30 | ood_detection_parameters = { 31 | 'sbm':{ 32 | 'temperature': [1e6], 33 | 'scale': [20] 34 | }, 35 | 'hub':{ 36 | 'temperature': [1e6], 37 | 'scale': [20], 38 | 'k': [50], 39 | 'alpha': [1.0] 40 | } 41 | } 42 | 43 | hyper_parameters = { 44 | 'padding_mode': 'zero', 45 | 'padding_loc': 'end', 46 | 'need_aligned': False, 47 | 'eval_monitor': ['f1'], 48 | 'train_batch_size': 16, 49 | 'eval_batch_size': 8, 50 | 'test_batch_size': 8, 51 | 'wait_patience': 8, 52 | 'num_train_epochs': [100], 53 | 'dst_feature_dims': [120], # 80 54 | 'nheads': [8], #4 55 | 'n_levels': [8], #8 56 | 'attn_dropout': 0.0, 57 | 'attn_dropout_v': 0.2, #0.2 58 | 'attn_dropout_a': 0.2, #0.2 59 | 'relu_dropout': 0.0, 60 | 'embed_dropout': 0.1, 61 | 'res_dropout': 0.0, #0 62 | 'output_dropout': 0.2, #0.2 63 | 'text_dropout': [0.4], #0.4 64 | 'grad_clip': 0.5, 65 | 'attn_mask': [True], #True 66 | 'conv1d_kernel_size_l': 5, #5 67 | 'conv1d_kernel_size_v': 1, #1 68 | 'conv1d_kernel_size_a': 1, #1 69 | 'lr': [0.00003], # 5e-6 70 | 'scale':20 71 | } 72 | return hyper_parameters 73 | 74 | -------------------------------------------------------------------------------- /configs/mult_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | dst_feature_dims (int): The destination dimensions (assume d(l) = d(v) = d(t)). 12 | nheads (int): The number of heads for the transformer network. 13 | n_levels (int): The number of layers in the network. 14 | attn_dropout (float): The attention dropout. 15 | attn_dropout_v (float): The attention dropout for the video modality. 16 | attn_dropout_a (float): The attention dropout for the audio modality. 17 | relu_dropout (float): The relu dropout. 18 | embed_dropout (float): The embedding dropout. 19 | res_dropout (float): The residual block dropout. 20 | output_dropout (float): The output layer dropout. 21 | text_dropout (float): The dropout for text features. 22 | grad_clip (float): The gradient clip value. 23 | attn_mask (bool): Whether to use attention mask for Transformer. 24 | conv1d_kernel_size_l (int): The kernel size for temporal convolutional layers (text modality). 25 | conv1d_kernel_size_v (int): The kernel size for temporal convolutional layers (video modality). 26 | conv1d_kernel_size_a (int): The kernel size for temporal convolutional layers (audio modality). 27 | lr (float): The learning rate of backbone. 
28 | """ 29 | 30 | ood_detection_parameters = { 31 | 'sbm':{ 32 | 'temperature': [1e6], 33 | 'scale': [20] 34 | }, 35 | 'hub':{ 36 | 'temperature': [1e6], 37 | 'scale': [20], 38 | 'k': [50], 39 | 'alpha': [1.0] 40 | } 41 | } 42 | 43 | hyper_parameters = { 44 | 'padding_mode': 'zero', 45 | 'padding_loc': 'end', 46 | 'need_aligned': False, 47 | 'eval_monitor': ['f1'], 48 | 'train_batch_size': 16, 49 | 'eval_batch_size': 8, 50 | 'test_batch_size': 8, 51 | 'wait_patience': 8, 52 | 'num_train_epochs': [100], 53 | 'dst_feature_dims': [120], # 80 54 | 'nheads': [8], #4 55 | 'n_levels': [8], #8 56 | 'attn_dropout': 0.0, 57 | 'attn_dropout_v': 0.2, #0.2 58 | 'attn_dropout_a': 0.2, #0.2 59 | 'relu_dropout': 0.0, 60 | 'embed_dropout': 0.1, 61 | 'res_dropout': 0.0, #0 62 | 'output_dropout': 0.2, #0.2 63 | 'text_dropout': [0.4], #0.4 64 | 'grad_clip': 0.5, 65 | 'attn_mask': [True], #True 66 | 'conv1d_kernel_size_l': 5, #5 67 | 'conv1d_kernel_size_v': 1, #1 68 | 'conv1d_kernel_size_a': 1, #1 69 | 'lr': [0.00003], # 5e- 70 | 'scale': 20 71 | } 72 | return hyper_parameters 73 | 74 | -------------------------------------------------------------------------------- /configs/mmim_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | self.hyper_param = self._get_hyper_parameters(args) 5 | 6 | def _get_hyper_parameters(self, args): 7 | ''' 8 | Args: 9 | num_train_epochs (int): The number of training epochs. 10 | num_labels (autofill): The output dimension. 11 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 12 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 13 | feat_dim (int): The feature dimension. 14 | warmup_proportion (float): The warmup ratio for learning rate. 15 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 16 | train_batch_size (int): The batch size for training. 17 | eval_batch_size (int): The batch size for evaluation. 18 | test_batch_size (int): The batch size for testing. 19 | wait_patient (int): Patient steps for Early Stop. 
20 | ''' 21 | ood_detection_parameters = { 22 | 'sbm':{ 23 | 'temperature': [1e6], 24 | 'scale': [20] 25 | }, 26 | 'hub':{ 27 | 'temperature': [1e6], 28 | 'scale': [20], 29 | 'k': [10], 30 | 'alpha': [0.5] 31 | } 32 | } 33 | if args.text_backbone.startswith('bert'): 34 | hyper_parameters = { 35 | 'need_aligned': False, 36 | 'eval_monitor': ['f1'], 37 | 'train_batch_size': 16, 38 | 'eval_batch_size': 8, 39 | 'test_batch_size': 8, 40 | 'wait_patience': 8, 41 | 'num_train_epochs': 100, 42 | ##################### 43 | 'gamma': [1], 44 | 'theta': [1], 45 | 'add_va': False, 46 | 'cpc_activation': 'Tanh', 47 | 'mmilb_mid_activation': 'ReLU', 48 | 'mmilb_last_activation': 'Tanh', 49 | 'optim': 'Adam', 50 | 'contrast': True, 51 | 'bidirectional': True, 52 | 'grad_clip': [0.9], 53 | 'lr_main': [2e-5], 54 | 'weight_decay_main': [0.0001], 55 | 'lr_bert': [2e-5], 56 | 'weight_decay_bert': [0.001], 57 | 'lr_mmilb': [0.001], 58 | 'weight_decay_mmilb': [0.0004], 59 | 'alpha': [0.4], 60 | 'beta': [0.1], 61 | 'dropout_a': [0.2], #0.1 62 | 'dropout_v': [0.2], #0.1 63 | 'dropout_prj': [0.2], 64 | 'n_layer': [2], 65 | 'cpc_layers': [4], 66 | 'd_vh': [8], 67 | 'd_ah': [4], 68 | 'd_vout': [4], 69 | 'd_aout': [4], 70 | 'd_prjh': [512], 71 | 'scale': [20] 72 | } 73 | # if args.ood_detection_method in ood_detection_parameters.keys(): 74 | # ood_parameters = ood_detection_parameters[args.ood_detection_method] 75 | # hyper_parameters.update(ood_parameters) 76 | 77 | return hyper_parameters -------------------------------------------------------------------------------- /configs/tcl_map_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 
21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('bert'): 29 | 30 | hyper_parameters = { 31 | 'feats_processing_type': 'padding', 32 | 'padding_mode': 'zero', 33 | 'padding_loc': 'end', 34 | 'need_aligned': True, 35 | 'eval_monitor': ['f1'], 36 | 'train_batch_size': 16, 37 | 'eval_batch_size': 8, 38 | 'test_batch_size': 8, 39 | 'wait_patience': 8, 40 | 'num_train_epochs': 100, 41 | # method parameters 42 | 'warmup_proportion': 0.1, 43 | 'grad_clip': [-1.0], 44 | 'lr': [2e-5], 45 | 'weight_decay': 0.1, 46 | 'mag_aligned_method': ['ctc'], 47 | # parameters of similarity-based modality alignment 48 | 'aligned_method': ['sim'], 49 | 'shared_dim': [256], 50 | 'eps': 1e-9, 51 | # parameters of NT-Xent 52 | 'loss': 'SupCon', 53 | 'temperature': [0.5], 54 | # parameters of multimodal fusion 55 | 'beta_shift': [0.006], 56 | 'dropout_prob': [0.5], 57 | # parameters of modality-aware prompting 58 | 'use_ctx': True, 59 | 'prompt_len': 3, 60 | 'nheads': [8], 61 | 'n_levels': [5], 62 | 'attn_dropout': [0.1], 63 | 'relu_dropout': 0.0, 64 | 'embed_dropout': [0.2], 65 | 'res_dropout': 0.1, 66 | 'attn_mask': True, 67 | 68 | } 69 | else: 70 | raise ValueError('Not supported text backbone') 71 | 72 | if args.ood_detection_method in ood_detection_parameters.keys(): 73 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 74 | hyper_parameters.update(ood_parameters) 75 | 76 | return hyper_parameters -------------------------------------------------------------------------------- /configs/tcl_map_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 
21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('bert'): 29 | 30 | hyper_parameters = { 31 | # common parameters 32 | 'feats_processing_type': 'padding', 33 | 'padding_mode': 'zero', 34 | 'padding_loc': 'end', 35 | 'need_aligned': True, 36 | 'eval_monitor': ['f1'], 37 | 'train_batch_size': 16, 38 | 'eval_batch_size': 8, 39 | 'test_batch_size': 8, 40 | 'wait_patience': 8, 41 | 'num_train_epochs': 100, 42 | # method parameters 43 | 'warmup_proportion': 0.1, 44 | 'grad_clip': [0.4], 45 | 'lr': [3e-5], 46 | 'weight_decay': 0.1, 47 | 'mag_aligned_method': ['ctc'], 48 | # parameters of similarity-based modality alignment 49 | 'aligned_method': ['sim'], 50 | 'shared_dim': [256], 51 | 'eps': 1e-9, 52 | # parameters of NT-Xent 53 | 'loss': 'SupCon', 54 | 'temperature': [0.07], 55 | # parameters of multimodal fusion 56 | 'beta_shift': [0.006], 57 | 'dropout_prob': [0.4], 58 | # parameters of modality-aware prompting 59 | 'use_ctx': True, 60 | 'ctx_len': 3, 61 | 'nheads': [8], 62 | 'n_levels': [5], 63 | 'attn_dropout': [0.1], 64 | 'relu_dropout': 0.0, 65 | 'embed_dropout': [0.2], 66 | 'res_dropout': 0.1, 67 | 'attn_mask': True, 68 | } 69 | else: 70 | raise ValueError('Not supported text backbone') 71 | 72 | if args.ood_detection_method in ood_detection_parameters.keys(): 73 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 74 | hyper_parameters.update(ood_parameters) 75 | 76 | return hyper_parameters -------------------------------------------------------------------------------- /configs/tcl_map_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | """ 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 
21 | """ 22 | ood_detection_parameters = { 23 | 'sbm':{ 24 | 'temperature': [1e4], 25 | 'scale': [20] 26 | } 27 | } 28 | if args.text_backbone.startswith('bert'): 29 | 30 | hyper_parameters = { 31 | # common parameters 32 | 'feats_processing_type': 'padding', 33 | 'padding_mode': 'zero', 34 | 'padding_loc': 'end', 35 | 'need_aligned': True, 36 | 'eval_monitor': ['f1'], 37 | 'train_batch_size': 16, 38 | 'eval_batch_size': 8, 39 | 'test_batch_size': 8, 40 | 'wait_patience': 8, 41 | 'num_train_epochs': 100, 42 | # method parameters 43 | 'warmup_proportion': 0.1, 44 | 'grad_clip': [0.4], 45 | 'lr': [3e-5], 46 | 'weight_decay': 0.1, 47 | 'mag_aligned_method': ['ctc'], 48 | # parameters of similarity-based modality alignment 49 | 'aligned_method': ['sim'], 50 | 'shared_dim': [256], 51 | 'eps': 1e-9, 52 | # parameters of NT-Xent 53 | 'loss': 'SupCon', 54 | 'temperature': [0.07], 55 | # parameters of multimodal fusion 56 | 'beta_shift': [0.006], 57 | 'dropout_prob': [0.4], 58 | # parameters of modality-aware prompting 59 | 'use_ctx': True, 60 | 'ctx_len': 3, 61 | 'nheads': [8], 62 | 'n_levels': [5], 63 | 'attn_dropout': [0.1], 64 | 'relu_dropout': 0.0, 65 | 'embed_dropout': [0.2], 66 | 'res_dropout': 0.1, 67 | 'attn_mask': True, 68 | } 69 | else: 70 | raise ValueError('Not supported text backbone') 71 | 72 | if args.ood_detection_method in ood_detection_parameters.keys(): 73 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 74 | hyper_parameters.update(ood_parameters) 75 | 76 | return hyper_parameters -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Classification and Out-of-distribution Detection for Multimodal Intent Understanding 2 | 3 | ## 1. Introduction 4 | 5 | This repository contains the official PyTorch implementation of the research paper [Multimodal Classification and Out-of-distribution Detection for Multimodal Intent Understanding (Accepted by IEEE TMM)](https://arxiv.org/abs/2412.12453). 6 | 7 | Multimodal intent understanding is a substantial field that needs to utilize nonverbal modalities effectively in analyzing human language. However, there is a lack of research focused on adapting these methods to real-world scenarios with out-of-distribution (OOD) samples, a key aspect in creating robust and secure systems. In this paper, we propose MIntOOD, a module for OOD detection in multimodal intent understanding. 8 | 9 | ## 2. Dependencies 10 | 11 | We use anaconda to create python environment: 12 | 13 | ``` 14 | conda create --name python=3.9 15 | ``` 16 | 17 | Install all required libraries: 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## 3. Usage 24 | 25 | The data can be downloaded through the following link: 26 | 27 | ``` 28 | https://cloud.tsinghua.edu.cn/d/58d20e316df24700ae8a/ 29 | ``` 30 | 31 | The downloaded data contains two folders for each dataset, representing the ID data and OOD data. Taking MIntRec as an example, its directory structure is as follows: 32 | 33 | ``` 34 | -TMM_MIntOOD 35 | -MIntRec 36 | -video_data 37 | -swin_feats.pkl 38 | -audio_data 39 | -wavlm_feats.pkl 40 | -spectra_audio.pkl 41 | -train.tsv 42 | -dev.tsv 43 | -test.tsv 44 | 45 | -MIntRec-OOD 46 | -video_data 47 | -swin_feats.pkl 48 | -audio_data 49 | -wavlm_feats.pkl 50 | -spectra_audio.pkl 51 | -test.tsv 52 | 53 | ... 
54 | ``` 55 | 56 | Notably, the `video_data` and `audio_data` in MELD-DA-OOD and IEMOCAP-OOD are the same as those in MELD-DA and IEMOCAP. Therefore, after decompression, the following commands need to be executed to copy `video_data` and `audio_data`: 57 | 58 | ``` 59 | cp -r MELD-DA/video_data MELD-DA-OOD/ 60 | cp -r MELD-DA/audio_data MELD-DA-OOD/ 61 | cp -r IEMOCAP/video_data IEMOCAP-OOD/ 62 | cp -r IEMOCAP/audio_data IEMOCAP-OOD/ 63 | ``` 64 | 65 | You can evaluate the performance of our proposed MIntOOD under different settings by running the following commands and changing the parameters in the scripts: 66 | 67 | ``` 68 | sh examples/run_train.sh 69 | sh examples/run_test.sh 70 | 71 | # Parameters in *.sh file 72 | ## Methods: mintood text mag_bert mult mmim tcl_map sdif spectra 73 | ## Dataset Configurations: MIntRec+MIntRec-OOD MELD-DA+MELD-DA-OOD IEMOCAP-DA+IEMOCAP-DA-OOD 74 | ## OOD Detection methods: ma vim residual msp energy maxlogit 75 | ## Ablation Types: full text fusion_add fusion_concat sampler_beta wo_contrast wo_cosine wo_binary 76 | ## Note: If using SPECTRA, audio_feats and ood_audio_feats need to use features compatible with WavLM (replace audio_feats_path and ood_audio_feats_path with 'spectra_audio.pkl'). For details, refer to the WavLM documentation at https://huggingface.co/docs/transformers/model_doc/wavlm. 77 | ``` 78 | 79 | You can change the hyper-parameters in the **configs** folder. The default hyper-parameters are the best-performing ones on the three datasets. 80 | 81 | ## Citations 82 | 83 | If you are interested in this work and want to use the code or results in this repository, please **star** this repository and **cite** the following work: 84 | ``` 85 | @article{zhang2024mintood, 86 | title={Multimodal Classification and Out-of-distribution Detection for Multimodal Intent Understanding}, 87 | author={Hanlei Zhang and Qianrui Zhou and Hua Xu and Jianhua Su and Roberto Evans and Kai Gao}, 88 | year={2025}, 89 | journal={IEEE Transactions on Multimedia}, 90 | } 91 | ``` 92 | 93 | -------------------------------------------------------------------------------- /methods/TCL_MAP/loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class SupConLoss(nn.Module): 6 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 7 | It also supports the unsupervised contrastive loss in SimCLR""" 8 | def __init__(self, temperature=0.07, contrast_mode='all'): 9 | super(SupConLoss, self).__init__() 10 | self.temperature = temperature 11 | self.contrast_mode = contrast_mode 12 | 13 | def forward(self, features, labels=None, mask=None): 14 | """Compute loss for model. If both `labels` and `mask` are None, 15 | it degenerates to SimCLR unsupervised loss: 16 | https://arxiv.org/pdf/2002.05709.pdf 17 | 18 | Args: 19 | features: hidden vector of shape [bsz, n_views, ...]. 20 | labels: ground truth of shape [bsz]. 21 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 22 | has the same class as sample i. Can be asymmetric. 23 | Returns: 24 | A loss scalar.
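For reference, the scalar returned by this `forward` corresponds to the supervised contrastive loss of the cited paper (this is a restatement of that formulation, not additional code from the repository). With L2-normalized view features $z$ (the forward normalizes them internally), temperature $\tau$ (`self.temperature`), anchor set $I$, contrast set $A(i)$ (all views except $i$), and positive set $P(i) \subseteq A(i)$ (views sharing anchor $i$'s label), it computes the mean over anchors $i \in I$ of

$$ -\frac{1}{|P(i)|}\sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p/\tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a/\tau)} . $$

When neither `labels` nor `mask` is given, $P(i)$ reduces to the other views of the same sample, i.e. the unsupervised SimCLR objective mentioned in the docstring.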
25 | """ 26 | device = (torch.device('cuda') 27 | if features.is_cuda 28 | else torch.device('cpu')) 29 | 30 | if len(features.shape) < 3: 31 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 32 | 'at least 3 dimensions are required') 33 | if len(features.shape) > 3: 34 | features = features.view(features.shape[0], features.shape[1], -1) 35 | 36 | features = F.normalize(features, dim=2) 37 | batch_size = features.shape[0] 38 | if labels is not None and mask is not None: 39 | raise ValueError('Cannot define both `labels` and `mask`') 40 | elif labels is None and mask is None: 41 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 42 | elif labels is not None: 43 | labels = labels.contiguous().view(-1, 1) 44 | if labels.shape[0] != batch_size: 45 | raise ValueError('Num of labels does not match num of features') 46 | mask = torch.eq(labels, labels.T).float().to(device) 47 | else: 48 | mask = mask.float().to(device) 49 | 50 | contrast_count = features.shape[1] 51 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 52 | if self.contrast_mode == 'one': 53 | anchor_feature = features[:, 0] 54 | anchor_count = 1 55 | elif self.contrast_mode == 'all': 56 | anchor_feature = contrast_feature 57 | anchor_count = contrast_count 58 | else: 59 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 60 | 61 | # compute logits 62 | anchor_dot_contrast = torch.div( 63 | torch.matmul(anchor_feature, contrast_feature.T), 64 | self.temperature) 65 | # for numerical stability 66 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 67 | logits = anchor_dot_contrast - logits_max.detach() 68 | 69 | 70 | # tile mask 71 | mask = mask.repeat(anchor_count, contrast_count) 72 | # mask-out self-contrast cases 73 | logits_mask = torch.scatter( 74 | torch.ones_like(mask), 75 | 1, 76 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 77 | 0 78 | ) 79 | mask = mask * logits_mask 80 | 81 | # compute log_prob 82 | exp_logits = torch.exp(logits) * logits_mask 83 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 84 | 85 | # compute mean of log-likelihood over positive 86 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 87 | 88 | # loss 89 | loss = - mean_log_prob_pos 90 | loss = loss.view(anchor_count, batch_size).mean() 91 | 92 | return loss 93 | -------------------------------------------------------------------------------- /losses/SupConLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class SupConLoss(nn.Module): 5 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 6 | It also supports the unsupervised contrastive loss in SimCLR""" 7 | def __init__(self, contrast_mode='all'): 8 | super(SupConLoss, self).__init__() 9 | self.contrast_mode = contrast_mode 10 | 11 | def forward(self, features, labels=None, mask=None, temperature = 0.07, device = None): 12 | """Compute loss for model. If both `labels` and `mask` are None, 13 | it degenerates to SimCLR unsupervised loss: 14 | https://arxiv.org/pdf/2002.05709.pdf 15 | Args: 16 | features: hidden vector of shape [bsz, n_views, ...]. 17 | labels: ground truth of shape [bsz]. 18 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 19 | has the same class as sample i. Can be asymmetric. 20 | Returns: 21 | A loss scalar. 
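A minimal usage sketch for this variant (sizes and labels are illustrative). Unlike the `methods/TCL_MAP/loss.py` version above, `temperature` and `device` are passed to `forward` rather than to the constructor, and this forward does not normalize internally, so the example normalizes the features explicitly:

```
import torch
import torch.nn.functional as F
from losses.SupConLoss import SupConLoss

criterion = SupConLoss(contrast_mode='all')

bsz, n_views, feat_dim = 8, 2, 768                                   # illustrative sizes
features = F.normalize(torch.randn(bsz, n_views, feat_dim), dim=-1)  # [bsz, n_views, ...] as documented
labels = torch.randint(0, 4, (bsz,))                                 # fake class ids for the supervised case

loss = criterion(features, labels=labels, temperature=0.07, device=features.device)
print(loss.item())
```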
22 | """ 23 | # device = (torch.device('cuda') 24 | # if features.is_cuda 25 | # else torch.device('cpu')) 26 | 27 | if len(features.shape) < 3: 28 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 29 | 'at least 3 dimensions are required') 30 | if len(features.shape) > 3: 31 | features = features.view(features.shape[0], features.shape[1], -1) 32 | 33 | batch_size = features.shape[0] 34 | 35 | if labels is not None and mask is not None: 36 | raise ValueError('Cannot define both `labels` and `mask`') 37 | elif labels is None and mask is None: 38 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 39 | elif labels is not None: 40 | labels = labels.contiguous().view(-1, 1) 41 | if labels.shape[0] != batch_size: 42 | raise ValueError('Num of labels does not match num of features') 43 | mask = torch.eq(labels, labels.T).float().to(device) 44 | else: 45 | mask = mask.float().to(device) 46 | 47 | contrast_count = features.shape[1] 48 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 49 | 50 | if self.contrast_mode == 'one': 51 | anchor_feature = features[:, 0] 52 | anchor_count = 1 53 | elif self.contrast_mode == 'all': 54 | anchor_feature = contrast_feature 55 | anchor_count = contrast_count 56 | else: 57 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 58 | 59 | # compute logits 60 | anchor_dot_contrast = torch.div( 61 | torch.matmul(anchor_feature, contrast_feature.T), 62 | temperature) 63 | 64 | # for numerical stability 65 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 66 | logits = anchor_dot_contrast - logits_max.detach() 67 | 68 | # tile mask 69 | mask = mask.repeat(anchor_count, contrast_count) 70 | # mask-out self-contrast cases 71 | logits_mask = torch.scatter( 72 | torch.ones_like(mask), 73 | 1, 74 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 75 | 0 76 | ) 77 | 78 | mask = mask * logits_mask 79 | 80 | # compute log_prob 81 | exp_logits = torch.exp(logits) * logits_mask 82 | # print("000000",exp_logits) 83 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 84 | # print("222222",torch.log(exp_logits.sum(1, keepdim=True))) 85 | 86 | # compute mean of log-likelihood over positive 87 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 88 | 89 | # loss 90 | loss = - mean_log_prob_pos 91 | loss = loss.view(anchor_count, batch_size).mean() 92 | 93 | return loss 94 | -------------------------------------------------------------------------------- /backbones/SubNets/transformers_encoder/position_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Code adapted from the fairseq repo. 7 | 8 | def make_positions(tensor, padding_idx, left_pad): 9 | """Replace non-padding symbols with their position numbers. 10 | Position numbers begin at padding_idx+1. 11 | Padding symbols are ignored, but it is necessary to specify whether padding 12 | is added on the left side (left_pad=True) or right side (left_pad=False). 
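A small worked example of the behaviour described above, using illustrative token ids with `padding_idx=0` and right-side padding:

```
import torch
from backbones.SubNets.transformers_encoder.position_embedding import make_positions

tokens = torch.tensor([[7, 7, 7, 0, 0]])                  # 0 is the padding symbol here
print(make_positions(tokens, padding_idx=0, left_pad=False))
# tensor([[1, 2, 3, 0, 0]]) -- positions start at padding_idx + 1, padded slots keep padding_idx
```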
13 | """ 14 | max_pos = padding_idx + 1 + tensor.size(1) 15 | device = tensor.get_device() 16 | buf_name = f'range_buf_{device}' 17 | if not hasattr(make_positions, buf_name): 18 | setattr(make_positions, buf_name, tensor.new()) 19 | setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor)) 20 | if getattr(make_positions, buf_name).numel() < max_pos: 21 | torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name)) 22 | mask = tensor.ne(padding_idx) 23 | positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor) 24 | if left_pad: 25 | positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1) 26 | new_tensor = tensor.clone() 27 | return new_tensor.masked_scatter_(mask, positions[mask]).long() 28 | 29 | 30 | class SinusoidalPositionalEmbedding(nn.Module): 31 | """This module produces sinusoidal positional embeddings of any length. 32 | Padding symbols are ignored, but it is necessary to specify whether padding 33 | is added on the left side (left_pad=True) or right side (left_pad=False). 34 | """ 35 | 36 | def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128): 37 | super().__init__() 38 | self.embedding_dim = embedding_dim 39 | self.padding_idx = padding_idx 40 | self.left_pad = left_pad 41 | self.weights = dict() # device --> actual weight; due to nn.DataParallel :-( 42 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 43 | 44 | @staticmethod 45 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 46 | """Build sinusoidal embeddings. 47 | This matches the implementation in tensor2tensor, but differs slightly 48 | from the description in Section 3.5 of "Attention Is All You Need". 49 | """ 50 | half_dim = embedding_dim // 2 51 | emb = math.log(10000) / (half_dim - 1) 52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 53 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 55 | if embedding_dim % 2 == 1: 56 | # zero pad 57 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 58 | if padding_idx is not None: 59 | emb[padding_idx, :] = 0 60 | return emb 61 | 62 | def forward(self, input): 63 | """Input is expected to be of size [bsz x seqlen].""" 64 | bsz, seq_len = input.size() 65 | max_pos = self.padding_idx + 1 + seq_len 66 | device = input.get_device() 67 | if device not in self.weights or max_pos > self.weights[device].size(0): 68 | # recompute/expand embeddings if needed 69 | self.weights[device] = SinusoidalPositionalEmbedding.get_embedding( 70 | max_pos, 71 | self.embedding_dim, 72 | self.padding_idx, 73 | ) 74 | self.weights[device] = self.weights[device].type_as(self._float_tensor).to(input.device) 75 | positions = make_positions(input, self.padding_idx, self.left_pad) 76 | return self.weights[device].index_select(0, positions.reshape(-1)).reshape(bsz, seq_len, -1).detach() 77 | 78 | def max_positions(self): 79 | """Maximum number of supported positions.""" 80 | return int(1e5) # an arbitrary large number -------------------------------------------------------------------------------- /configs/sdif_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | ''' 9 | Args: 10 | num_train_epochs (int): The number 
of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | ''' 22 | if args.text_backbone.startswith('bert'): 23 | 24 | hyper_parameters = { 25 | 'data_mode': 'multi-class', 26 | 'padding_mode': 'zero', 27 | 'padding_loc': 'end', 28 | 'need_aligned': False, 29 | 'eval_monitor': ['f1'], 30 | 'train_batch_size': [16], 31 | 'eval_batch_size': 8, 32 | 'test_batch_size': 8, 33 | 'wait_patience': [8], 34 | 'num_train_epochs': [100], 35 | 'dst_feature_dims': 768, 36 | 'n_levels_self': 1, 37 | 'n_levels_cross': 1, 38 | 'dropout_rate': 0.2, 39 | 'cross_dp_rate': 0.3, 40 | 'cross_num_heads': 12, 41 | 'self_num_heads': 8, 42 | 'grad_clip': 7, 43 | 'lr': [9e-6], #9e-6原始 44 | 'opt_patience': 8, 45 | 'factor': 0.5, 46 | 'weight_decay': 0.01, 47 | 'aug_lr': 1e-6, #1e-6原始 48 | 'aug_epoch': 1, 49 | 'aug_dp': 0.3, 50 | 'aug_weight_decay': 0.1, 51 | 'aug_grad_clip': -1.0, 52 | 'aug_batch_size': 16, 53 | 'aug': True, 54 | 'aug_num': 25000, 55 | 'use_wandb': False, 56 | 'scale': [16], 57 | } 58 | # hyper_parameters = { 59 | # 'need_aligned': True, 60 | # 'freeze_parameters': False, 61 | # 'eval_monitor': 'f1', 62 | # 'eval_batch_size': 16, 63 | # 'wait_patience': [8], 64 | # 'binary_multiple_ood': 1.0, 65 | # 'base_dim': [768], 66 | # 'lr': [3e-5], #3e-5 67 | # 'temperature': [2], # bigger is usually better 68 | # 'alpha': [2], #0.5, 1 69 | # 'mlp_hidden_size': [256], 70 | # 'mlp_dropout': [0.1], 71 | # 're_prob': [0.1], 72 | # 'num_train_epochs': [100], # [30, 40, 50] 73 | # 'train_batch_size': [32], # [32, 64, 128] 74 | # 'weight_decay': [0.01], # [0.01, 0.05, 0.1] 75 | # 'multiple_ood': [1.0], # try average number 76 | # 'contrast_dropout': [0.1], 77 | # 'select_number_min': [2], 78 | # 'select_number_max': [3], 79 | # 'weight_dropout': [0.5], 80 | # 'weight_hidden_dim': [256], 81 | # # 'weight': [2, 3], 82 | # 'aligned_method': ['ctc'], 83 | # 'warmup_proportion': [0.1], 84 | # 'scale': [16], 85 | # 'encoder_layers_a': [1], 86 | # 'encoder_layers_v': [2], 87 | # 'attn_dropout': [0.0], 88 | # 'relu_dropout': [0.1], 89 | # 'embed_dropout': [0.0], 90 | # 'res_dropout': [0.2], #0 91 | # 'attn_mask': [False], #True 92 | # 'nheads': [2], #4 93 | # } 94 | 95 | return hyper_parameters -------------------------------------------------------------------------------- /configs/sdif_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | ''' 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. 
Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | ''' 22 | if args.text_backbone.startswith('bert'): 23 | 24 | hyper_parameters = { 25 | 'data_mode': 'multi-class', 26 | 'padding_mode': 'zero', 27 | 'padding_loc': 'end', 28 | 'need_aligned': False, 29 | 'eval_monitor': ['f1'], 30 | 'train_batch_size': [16], 31 | 'eval_batch_size': 8, 32 | 'test_batch_size': 8, 33 | 'wait_patience': [8], 34 | 'num_train_epochs': [100], 35 | 'dst_feature_dims': 768, 36 | 'n_levels_self': 1, 37 | 'n_levels_cross': 1, 38 | 'dropout_rate': 0.2, 39 | 'cross_dp_rate': 0.3, 40 | 'cross_num_heads': 12, 41 | 'self_num_heads': 8, 42 | 'grad_clip': 7, 43 | 'lr': [9e-6], #9e-6原始 44 | 'opt_patience': 8, 45 | 'factor': 0.5, 46 | 'weight_decay': 0.01, 47 | 'aug_lr': 1e-6, #1e-6原始 48 | 'aug_epoch': 1, 49 | 'aug_dp': 0.3, 50 | 'aug_weight_decay': 0.1, 51 | 'aug_grad_clip': -1.0, 52 | 'aug_batch_size': 16, 53 | 'aug': True, 54 | 'aug_num': 25000, 55 | 'use_wandb': False, 56 | 'scale': [16], 57 | } 58 | # hyper_parameters = { 59 | # 'need_aligned': True, 60 | # 'freeze_parameters': False, 61 | # 'eval_monitor': 'f1', 62 | # 'eval_batch_size': 16, 63 | # 'wait_patience': [8], 64 | # 'binary_multiple_ood': 1.0, 65 | # 'base_dim': [768], 66 | # 'lr': [3e-5], #3e-5 67 | # 'temperature': [2], # bigger is usually better 68 | # 'alpha': [2], #0.5, 1 69 | # 'mlp_hidden_size': [256], 70 | # 'mlp_dropout': [0.1], 71 | # 're_prob': [0.1], 72 | # 'num_train_epochs': [100], # [30, 40, 50] 73 | # 'train_batch_size': [32], # [32, 64, 128] 74 | # 'weight_decay': [0.01], # [0.01, 0.05, 0.1] 75 | # 'multiple_ood': [1.0], # try average number 76 | # 'contrast_dropout': [0.1], 77 | # 'select_number_min': [2], 78 | # 'select_number_max': [3], 79 | # 'weight_dropout': [0.5], 80 | # 'weight_hidden_dim': [256], 81 | # # 'weight': [2, 3], 82 | # 'aligned_method': ['ctc'], 83 | # 'warmup_proportion': [0.1], 84 | # 'scale': [16], 85 | # 'encoder_layers_a': [1], 86 | # 'encoder_layers_v': [2], 87 | # 'attn_dropout': [0.0], 88 | # 'relu_dropout': [0.1], 89 | # 'embed_dropout': [0.0], 90 | # 'res_dropout': [0.2], #0 91 | # 'attn_mask': [False], #True 92 | # 'nheads': [2], #4 93 | # } 94 | 95 | return hyper_parameters -------------------------------------------------------------------------------- /configs/sdif_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | ''' 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 
15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | ''' 22 | if args.text_backbone.startswith('bert'): 23 | 24 | hyper_parameters = { 25 | 'data_mode': 'multi-class', 26 | 'padding_mode': 'zero', 27 | 'padding_loc': 'end', 28 | 'need_aligned': False, 29 | 'eval_monitor': ['f1'], 30 | 'train_batch_size': [16], 31 | 'eval_batch_size': 8, 32 | 'test_batch_size': 8, 33 | 'wait_patience': [8], 34 | 'num_train_epochs': [100], 35 | 'dst_feature_dims': 768, 36 | 'n_levels_self': 1, 37 | 'n_levels_cross': 1, 38 | 'dropout_rate': 0.2, 39 | 'cross_dp_rate': 0.3, 40 | 'cross_num_heads': 12, 41 | 'self_num_heads': 8, 42 | 'grad_clip': 7, 43 | 'lr': [9e-6], #9e-6原始 44 | 'opt_patience': 8, 45 | 'factor': 0.5, 46 | 'weight_decay': 0.01, 47 | 'aug_lr': 1e-6, #1e-6原始 48 | 'aug_epoch': 1, 49 | 'aug_dp': 0.3, 50 | 'aug_weight_decay': 0.1, 51 | 'aug_grad_clip': -1.0, 52 | 'aug_batch_size': 16, 53 | 'aug': True, 54 | 'aug_num': 25000, 55 | 'use_wandb': False, 56 | 'scale': [32], 57 | } 58 | # hyper_parameters = { 59 | # 'need_aligned': True, 60 | # 'freeze_parameters': False, 61 | # 'eval_monitor': 'f1', 62 | # 'eval_batch_size': 16, 63 | # 'wait_patience': [8], 64 | # 'binary_multiple_ood': 1.0, 65 | # 'base_dim': [768], 66 | # 'lr': [3e-5], #3e-5 67 | # 'temperature': [2], # bigger is usually better 68 | # 'alpha': [2], #0.5, 1 69 | # 'mlp_hidden_size': [256], 70 | # 'mlp_dropout': [0.1], 71 | # 're_prob': [0.1], 72 | # 'num_train_epochs': [100], # [30, 40, 50] 73 | # 'train_batch_size': [32], # [32, 64, 128] 74 | # 'weight_decay': [0.01], # [0.01, 0.05, 0.1] 75 | # 'multiple_ood': [1.0], # try average number 76 | # 'contrast_dropout': [0.1], 77 | # 'select_number_min': [2], 78 | # 'select_number_max': [3], 79 | # 'weight_dropout': [0.5], 80 | # 'weight_hidden_dim': [256], 81 | # # 'weight': [2, 3], 82 | # 'aligned_method': ['ctc'], 83 | # 'warmup_proportion': [0.1], 84 | # 'scale': [16], 85 | # 'encoder_layers_a': [1], 86 | # 'encoder_layers_v': [2], 87 | # 'attn_dropout': [0.0], 88 | # 'relu_dropout': [0.1], 89 | # 'embed_dropout': [0.0], 90 | # 'res_dropout': [0.2], #0 91 | # 'attn_mask': [False], #True 92 | # 'nheads': [2], #4 93 | # } 94 | 95 | return hyper_parameters -------------------------------------------------------------------------------- /configs/mintood_MELD-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | ''' 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 
18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | ''' 22 | if args.text_backbone.startswith('bert'): 23 | hyper_parameters = { 24 | 'need_aligned': True, 25 | 'freeze_parameters': False, 26 | 'eval_monitor': 'f1', 27 | 'eval_batch_size': 16, 28 | 'wait_patience': [3], 29 | 'binary_multiple_ood': 1.0, 30 | 'base_dim': [768], 31 | 'lr': [4e-6], #3e-5 32 | 'temperature': [1], # bigger is usually better 33 | 'alpha': [0.7], #0.5, 1 34 | 'mlp_hidden_size': [256], 35 | 'mlp_dropout': [0.2], 36 | 're_prob': [0.1], 37 | 'num_train_epochs': [100], # [30, 40, 50] 38 | 'train_batch_size': [32], # [32, 64, 128] 39 | 'weight_decay': [0.1], # [0.01, 0.05, 0.1] 40 | 'multiple_ood': [1.0], # try average number 41 | 'contrast_dropout': [0.1], 42 | 'select_number_min': [2], 43 | 'select_number_max': [3], 44 | 'weight_dropout': [0.5], 45 | 'weight_hidden_dim': [256], 46 | # 'weight': [2, 3], 47 | 'aligned_method': ['ctc'], 48 | 'warmup_proportion': [0.1], 49 | 'scale': [16], 50 | 'encoder_layers_a': [1], 51 | 'encoder_layers_v': [2], 52 | 'attn_dropout': [0.0], 53 | 'relu_dropout': [0.1], 54 | 'embed_dropout': [0.0], 55 | 'res_dropout': [0.2], #0 56 | 'attn_mask': [False], #True 57 | 'nheads': [2], #4 58 | 'grad_norm': [-1], 59 | } 60 | # hyper_parameters = { 61 | # 'need_aligned': True, 62 | # 'freeze_parameters': False, 63 | # 'eval_monitor': 'f1', 64 | # 'eval_batch_size': 16, 65 | # 'wait_patience': [3], 66 | # 'binary_multiple_ood': 1.0, 67 | # 'base_dim': [768], # ok 68 | # 'lr': [3e-6], # 69 | # 'temperature': [0.5], # bigger is usually better 70 | # 'alpha': [0.5], #0.5, 1 71 | # 'mlp_hidden_size': [256], 72 | # 'mlp_dropout': [0.1], 73 | # 'num_train_epochs': [100], # ok 74 | # 'train_batch_size': [32], # ok 75 | # 'weight_decay': [0.1], # ok 76 | # 'multiple_ood': [1.0], # try average number 77 | # 'select_number_min': [2], 78 | # 'select_number_max': [5], 79 | # 'weight_dropout': [0.1], 80 | # 'weight_hidden_dim': [256], 81 | # 'warmup_proportion': [0.1], 82 | # 'scale': [16], # ok 83 | # 'encoder_layers_a': [1], # ok 84 | # 'encoder_layers_v': [2], # ok 85 | # 'grad_norm': [-1], 86 | # 'attn_dropout': [0.2], # ok 87 | # 'relu_dropout': [0.2], # ok 88 | # 'embed_dropout': [0], 89 | # 'res_dropout': [0.3], 90 | # 'attn_mask': [False], #True 91 | # 'nheads': [8], # ok 92 | # } 93 | return hyper_parameters -------------------------------------------------------------------------------- /configs/mmim_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | self.hyper_param = self._get_hyper_parameters(args) 5 | 6 | def _get_hyper_parameters(self, args): 7 | ''' 8 | Args: 9 | num_train_epochs (int): The number of training epochs. 10 | num_labels (autofill): The output dimension. 11 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 12 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 13 | feat_dim (int): The feature dimension. 14 | warmup_proportion (float): The warmup ratio for learning rate. 15 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 16 | train_batch_size (int): The batch size for training. 17 | eval_batch_size (int): The batch size for evaluation. 
18 | test_batch_size (int): The batch size for testing. 19 | wait_patient (int): Patient steps for Early Stop. 20 | ''' 21 | ood_detection_parameters = { 22 | 'sbm':{ 23 | 'temperature': [1e6], 24 | 'scale': [20] 25 | }, 26 | 'hub':{ 27 | 'temperature': [1e6], 28 | 'scale': [20], 29 | 'k': [10], 30 | 'alpha': [0.5] 31 | } 32 | } 33 | if args.text_backbone.startswith('bert'): 34 | hyper_parameters = { 35 | 'need_aligned': False, 36 | 'eval_monitor': ['f1'], 37 | 'train_batch_size': 16, 38 | 'eval_batch_size': 8, 39 | 'test_batch_size': 8, 40 | 'wait_patience': 8, 41 | 'num_train_epochs': 100, 42 | ##################### 43 | 'add_va': False, 44 | 'cpc_activation': 'Tanh', 45 | 'mmilb_mid_activation': 'ReLU', 46 | 'mmilb_last_activation': 'Tanh', 47 | 'optim': 'Adam', 48 | 'contrast': True, 49 | 'bidirectional': True, 50 | 'grad_clip': 1.0, 51 | 'lr_main': 2e-5, 52 | 'weight_decay_main': 1e-4, 53 | 'lr_bert': 2e-5, 54 | 'weight_decay_bert': 4e-5, 55 | 'lr_mmilb': 0.003, 56 | 'weight_decay_mmilb': 0.0003, 57 | 'alpha': 0.1, 58 | 'dropout_a': 0.1, 59 | 'dropout_v': 0.1, 60 | 'dropout_prj': 0.1, 61 | 'n_layer': 1, 62 | 'cpc_layers': 1, 63 | 'd_vh': 8, 64 | 'd_ah': 32, 65 | 'd_vout': 16, 66 | 'd_aout': 16, 67 | 'd_prjh': 512, 68 | 'scale': 20, 69 | 'beta':0.5 70 | } 71 | elif args.text_backbone.startswith('roberta'): 72 | hyper_parameters = { 73 | 'need_aligned': False, 74 | 'eval_monitor': ['weighted_f1'], 75 | 'train_batch_size': 16, 76 | 'eval_batch_size': 8, 77 | 'test_batch_size': 8, 78 | 'wait_patience': 8, 79 | 'num_train_epochs': 100, 80 | ##################### 81 | 'add_va': False, 82 | 'cpc_activation': 'Tanh', 83 | 'mmilb_mid_activation': 'ReLU', 84 | 'mmilb_last_activation': 'Tanh', 85 | 'optim': 'Adam', 86 | 'contrast': True, 87 | 'bidirectional': True, 88 | 'grad_clip': 1.0, 89 | 'lr_main': 1e-4, 90 | 'weight_decay_main': 1e-4, 91 | 'lr_bert': 4e-5, 92 | 'weight_decay_bert': 8e-5, 93 | 'lr_mmilb': 0.001, 94 | 'weight_decay_mmilb': 0.0001, 95 | 'alpha': 0.1, 96 | 'beta': 0.01, 97 | 'dropout_a': 0.1, 98 | 'dropout_v': 0.1, 99 | 'dropout_prj': 0.1, 100 | 'n_layer': 1, 101 | 'cpc_layers': 1, 102 | 'd_vh': 32, 103 | 'd_ah': 32, 104 | 'd_vout': 16, 105 | 'd_aout': 16, 106 | 'd_prjh': 512, 107 | 'beta':0.5 108 | } 109 | # if args.ood_detection_method in ood_detection_parameters.keys(): 110 | # ood_parameters = ood_detection_parameters[args.ood_detection_method] 111 | # hyper_parameters.update(ood_parameters) 112 | 113 | return hyper_parameters -------------------------------------------------------------------------------- /configs/mmim_MIntRec.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | self.hyper_param = self._get_hyper_parameters(args) 5 | 6 | def _get_hyper_parameters(self, args): 7 | ''' 8 | Args: 9 | num_train_epochs (int): The number of training epochs. 10 | num_labels (autofill): The output dimension. 11 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 12 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 13 | feat_dim (int): The feature dimension. 14 | warmup_proportion (float): The warmup ratio for learning rate. 15 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 16 | train_batch_size (int): The batch size for training. 17 | eval_batch_size (int): The batch size for evaluation. 
18 | test_batch_size (int): The batch size for testing. 19 | wait_patient (int): Patient steps for Early Stop. 20 | ''' 21 | ood_detection_parameters = { 22 | 'sbm':{ 23 | 'temperature': [1e6], 24 | 'scale': [20] 25 | }, 26 | 'hub':{ 27 | 'temperature': [1e6], 28 | 'scale': [20], 29 | 'k': [10], 30 | 'alpha': [0.5] 31 | } 32 | } 33 | if args.text_backbone.startswith('bert'): 34 | hyper_parameters = { 35 | 'need_aligned': False, 36 | 'eval_monitor': ['f1'], 37 | 'train_batch_size': 16, 38 | 'eval_batch_size': 8, 39 | 'test_batch_size': 8, 40 | 'wait_patience': 8, 41 | 'num_train_epochs': 100, 42 | ##################### 43 | 'add_va': False, 44 | 'cpc_activation': 'Tanh', 45 | 'mmilb_mid_activation': 'ReLU', 46 | 'mmilb_last_activation': 'Tanh', 47 | 'optim': 'Adam', 48 | 'contrast': True, 49 | 'bidirectional': True, 50 | 'grad_clip': 1.0, 51 | 'lr_main': 1e-4, 52 | 'weight_decay_main': 1e-4, 53 | 'lr_bert': 4e-5, 54 | 'weight_decay_bert': 8e-5, 55 | 'lr_mmilb': 0.001, 56 | 'weight_decay_mmilb': 0.0001, 57 | 'alpha': 0.1, 58 | 'dropout_a': 0.1, 59 | 'dropout_v': 0.1, 60 | 'dropout_prj': 0.1, 61 | 'n_layer': 1, 62 | 'cpc_layers': 1, 63 | 'd_vh': 32, 64 | 'd_ah': 32, 65 | 'd_vout': 16, 66 | 'd_aout': 16, 67 | 'd_prjh': 512, 68 | 'scale': 20, 69 | 'beta':0.5 70 | } 71 | elif args.text_backbone.startswith('roberta'): 72 | hyper_parameters = { 73 | 'need_aligned': False, 74 | 'eval_monitor': ['weighted_f1'], 75 | 'train_batch_size': 16, 76 | 'eval_batch_size': 8, 77 | 'test_batch_size': 8, 78 | 'wait_patience': 8, 79 | 'num_train_epochs': 100, 80 | ##################### 81 | 'add_va': False, 82 | 'cpc_activation': 'Tanh', 83 | 'mmilb_mid_activation': 'ReLU', 84 | 'mmilb_last_activation': 'Tanh', 85 | 'optim': 'Adam', 86 | 'contrast': True, 87 | 'bidirectional': True, 88 | 'grad_clip': 1.0, 89 | 'lr_main': 1e-4, 90 | 'weight_decay_main': 1e-4, 91 | 'lr_bert': 4e-5, 92 | 'weight_decay_bert': 8e-5, 93 | 'lr_mmilb': 0.001, 94 | 'weight_decay_mmilb': 0.0001, 95 | 'alpha': 0.1, 96 | 'beta': 0.01, 97 | 'dropout_a': 0.1, 98 | 'dropout_v': 0.1, 99 | 'dropout_prj': 0.1, 100 | 'n_layer': 1, 101 | 'cpc_layers': 1, 102 | 'd_vh': 32, 103 | 'd_ah': 32, 104 | 'd_vout': 16, 105 | 'd_aout': 16, 106 | 'd_prjh': 512, 107 | 'beta':0.5 108 | } 109 | # if args.ood_detection_method in ood_detection_parameters.keys(): 110 | # ood_parameters = ood_detection_parameters[args.ood_detection_method] 111 | # hyper_parameters.update(ood_parameters) 112 | 113 | return hyper_parameters -------------------------------------------------------------------------------- /backbones/SubNets/AlignNets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['CTCModule', 'AlignSubNet'] 5 | 6 | class CTCModule(nn.Module): 7 | def __init__(self, in_dim, out_seq_len): 8 | ''' 9 | This module is performing alignment from A (e.g., audio) to B (e.g., text). 
10 | :param in_dim: Dimension for input modality A 11 | :param out_seq_len: Sequence length for output modality B 12 | From: https://github.com/yaohungt/Multimodal-Transformer 13 | ''' 14 | super(CTCModule, self).__init__() 15 | # Use LSTM for predicting the position from A to B 16 | self.pred_output_position_inclu_blank = nn.LSTM(in_dim, out_seq_len+1, num_layers=2, batch_first=True) # 1 denoting blank 17 | 18 | self.out_seq_len = out_seq_len 19 | 20 | self.softmax = nn.Softmax(dim=2) 21 | 22 | def forward(self, x): 23 | ''' 24 | :input x: Input with shape [batch_size x in_seq_len x in_dim] 25 | ''' 26 | # NOTE that the index 0 refers to blank. 27 | pred_output_position_inclu_blank, _ = self.pred_output_position_inclu_blank(x) 28 | 29 | prob_pred_output_position_inclu_blank = self.softmax(pred_output_position_inclu_blank) # batch_size x in_seq_len x out_seq_len+1 30 | prob_pred_output_position = prob_pred_output_position_inclu_blank[:, :, 1:] # batch_size x in_seq_len x out_seq_len 31 | prob_pred_output_position = prob_pred_output_position.transpose(1,2) # batch_size x out_seq_len x in_seq_len 32 | pseudo_aligned_out = torch.bmm(prob_pred_output_position, x) # batch_size x out_seq_len x in_dim 33 | 34 | # pseudo_aligned_out is regarded as the aligned A (w.r.t B) 35 | # return pseudo_aligned_out, (pred_output_position_inclu_blank) 36 | return pseudo_aligned_out 37 | 38 | class AlignSubNet(nn.Module): 39 | def __init__(self, args, mode): 40 | """ 41 | mode: the way of aligning 42 | avg_pool, ctc, conv1d 43 | """ 44 | super(AlignSubNet, self).__init__() 45 | assert mode in ['avg_pool', 'ctc', 'conv1d'] 46 | 47 | in_dim_t, in_dim_a, in_dim_v = args.text_feat_dim, args.audio_feat_dim, args.video_feat_dim 48 | seq_len_t, seq_len_a, seq_len_v = args.text_seq_len, args.audio_seq_len, args.video_seq_len 49 | self.dst_len = seq_len_t 50 | self.mode = mode 51 | 52 | self.ALIGN_WAY = { 53 | 'avg_pool': self.__avg_pool, 54 | 'ctc': self.__ctc, 55 | 'conv1d': self.__conv1d 56 | } 57 | 58 | if mode == 'conv1d': 59 | self.conv1d_T = nn.Conv1d(seq_len_t, self.dst_len, kernel_size=1, bias=False) 60 | self.conv1d_A = nn.Conv1d(seq_len_a, self.dst_len, kernel_size=1, bias=False) 61 | self.conv1d_V = nn.Conv1d(seq_len_v, self.dst_len, kernel_size=1, bias=False) 62 | elif mode == 'ctc': 63 | self.ctc_t = CTCModule(in_dim_t, self.dst_len) 64 | self.ctc_a = CTCModule(in_dim_a, self.dst_len) 65 | self.ctc_v = CTCModule(in_dim_v, self.dst_len) 66 | 67 | def get_seq_len(self): 68 | return self.dst_len 69 | 70 | def __ctc(self, text_x, audio_x, video_x): 71 | text_x = self.ctc_t(text_x) if text_x.size(1) != self.dst_len else text_x 72 | audio_x = self.ctc_a(audio_x) if audio_x.size(1) != self.dst_len else audio_x 73 | video_x = self.ctc_v(video_x) if video_x.size(1) != self.dst_len else video_x 74 | return text_x, audio_x, video_x 75 | 76 | def __avg_pool(self, text_x, audio_x, video_x): 77 | def align(x): 78 | raw_seq_len = x.size(1) 79 | if raw_seq_len == self.dst_len: 80 | return x 81 | if raw_seq_len // self.dst_len == raw_seq_len / self.dst_len: 82 | pad_len = 0 83 | pool_size = raw_seq_len // self.dst_len 84 | else: 85 | pad_len = self.dst_len - raw_seq_len % self.dst_len 86 | pool_size = raw_seq_len // self.dst_len + 1 87 | pad_x = x[:, -1, :].unsqueeze(1).expand([x.size(0), pad_len, x.size(-1)]) 88 | x = torch.cat([x, pad_x], dim=1).view(x.size(0), pool_size, self.dst_len, -1) 89 | x = x.mean(dim=1) 90 | return x 91 | text_x = align(text_x) 92 | audio_x = align(audio_x) 93 | video_x = align(video_x) 94 | 
return text_x, audio_x, video_x 95 | 96 | def __conv1d(self, text_x, audio_x, video_x): 97 | text_x = self.conv1d_T(text_x) if text_x.size(1) != self.dst_len else text_x 98 | audio_x = self.conv1d_A(audio_x) if audio_x.size(1) != self.dst_len else audio_x 99 | video_x = self.conv1d_V(video_x) if video_x.size(1) != self.dst_len else video_x 100 | return text_x, audio_x, video_x 101 | 102 | def forward(self, text_x, audio_x, video_x): 103 | # already aligned 104 | if text_x.size(1) == audio_x.size(1) == video_x.size(1): 105 | return text_x, audio_x, video_x 106 | return self.ALIGN_WAY[self.mode](text_x, audio_x, video_x) -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | benchmarks = { 4 | 'MIntRec':{ 5 | 'intent_labels': [ 6 | 'Complain', 'Praise', 'Apologise', 'Thank', 'Criticize', 7 | 'Agree', 'Taunt', 'Flaunt', 8 | 'Joke', 'Oppose', 9 | 'Comfort', 'Care', 'Inform', 'Advise', 'Arrange', 'Introduce', 'Leave', 10 | 'Prevent', 'Greet', 'Ask for help' 11 | ], 12 | 'binary_maps': { 13 | 'Complain': 'Emotion', 'Praise':'Emotion', 'Apologise': 'Emotion', 'Thank':'Emotion', 'Criticize': 'Emotion', 14 | 'Care': 'Emotion', 'Agree': 'Emotion', 'Taunt': 'Emotion', 'Flaunt': 'Emotion', 15 | 'Joke':'Emotion', 'Oppose': 'Emotion', 16 | 'Inform':'Goal', 'Advise':'Goal', 'Arrange': 'Goal', 'Introduce': 'Goal', 'Leave':'Goal', 17 | 'Prevent':'Goal', 'Greet': 'Goal', 'Ask for help': 'Goal', 'Comfort': 'Goal' 18 | }, 19 | 'binary_intent_labels': ['Emotion', 'Goal'], 20 | 'label_len': 4, 21 | 'max_seq_lengths': { 22 | 'text': 30, 23 | 'video': 230, 24 | 'audio': 480 25 | }, 26 | 'feat_dims': { 27 | 'text': 768, 28 | 'video': 1024, 29 | 'audio': 768 30 | }, 31 | 'ood_data':{ 32 | 'MIntRec-OOD': {'ood_label': 'UNK'}, 33 | 'TED-OOD': {'ood_label': 'UNK'}, 34 | 'IEMOCAP-DA': {'ood_label': 'oth'}, 35 | 'IEMOCAP-DA-OOD': {'ood_label': 'oth'}, 36 | 'MELD-DA': {'ood_label': 'oth'}, 37 | 'MELD-DA-OOD': {'ood_label': 'oth'} 38 | } 39 | }, 40 | 'MIntRec2.0':{ 41 | 'intent_labels': [ 42 | 'Acknowledge', 'Advise', 'Agree', 'Apologise', 'Arrange', 43 | 'Ask for help', 'Asking for opinions', 'Care', 'Comfort', 'Complain', 44 | 'Confirm', 'Criticize', 'Doubt', 'Emphasize', 'Explain', 45 | 'Flaunt', 'Greet', 'Inform', 'Introduce', 'Invite', 46 | 'Joke', 'Leave', 'Oppose', 'Plan', 'Praise', 47 | 'Prevent', 'Refuse', 'Taunt', 'Thank', 'Warn', 48 | ], 49 | 'max_seq_lengths': { 50 | 'text': 50, # truth: 51 (max), 23 (mean+3std) 51 | 'video': 180, # truth: 475 (max), 67 (avg), 181 (mean+3std) 52 | 'audio': 400, # truth: 992 (max), 387 (mean+3std), 53 | }, 54 | 'feat_dims': { 55 | 'text': 1024, 56 | 'video': 256, 57 | 'audio': 768 58 | }, 59 | 'ood_data': { 60 | 'MIntRec2.0-OOD': {'ood_label': 'UNK'} 61 | } 62 | }, 63 | 'MELD-DA':{ 64 | 'intent_labels': [ 65 | 'Greeting', 'Question', 'Answer', 'Statement Opinion', 'Statement Non Opinion', 66 | 'Apology', 'Command', 'Agreement', 'Disagreement', 67 | 'Acknowledge', 'Backchannel' 68 | ], 69 | 'label_maps': { 70 | 'g': 'Greeting', 'q': 'Question', 'ans': 'Answer', 'o': 'Statement Opinion', 's': 'Statement Non Opinion', 71 | 'ap': 'Apology', 'c': 'Command', 'ag': 'Agreement', 'dag': 'Disagreement', 72 | 'a': 'Acknowledge', 'b': 'Backchannel' 73 | }, 74 | 'max_seq_lengths': { 75 | 'text': 70, # max: 69, final: 27 76 | 'video': 250, ### max: 618, final: 242 77 | 'audio': 530 ### max 2052, final: 524 78 | }, 79 | 'label_len': 3, 80 | 'feat_dims': { 
81 | 'text': 768, 82 | 'video': 1024, 83 | 'audio': 768 84 | }, 85 | 'ood_data':{ 86 | 'MELD-DA-OOD': {'ood_label': 'oth'}, 87 | 'MIntRec-OOD': {'ood_label': 'UNK'}, 88 | 'TED-OOD': {'ood_label': 'UNK'}, 89 | 'IEMOCAP-DA': {'ood_label': 'oth'}, 90 | 'IEMOCAP-DA-OOD': {'ood_label': 'oth'}, 91 | 'MIntRec': {'ood_label': 'UNK'} 92 | } 93 | }, 94 | 'IEMOCAP-DA':{ 95 | 'intent_labels': [ 96 | 'Greeting', 'Question', 'Answer', 'Statement Opinion', 'Statement Non Opinion', 97 | 'Apology', 'Command', 'Agreement', 'Disagreement', 98 | 'Acknowledge', 'Backchannel' 99 | ], 100 | 'label_maps': { 101 | 'g': 'Greeting', 'q': 'Question', 'ans': 'Answer', 'o': 'Statement Opinion', 's': 'Statement Non Opinion', 102 | 'ap': 'Apology', 'c': 'Command', 'ag': 'Agreement', 'dag': 'Disagreement', 103 | 'a': 'Acknowledge', 'b': 'Backchannel' 104 | }, 105 | 'max_seq_lengths': { 106 | 'text': 44, 107 | 'video': 230, # mean+sigma 108 | 'audio': 380 109 | }, 110 | 'label_len': 3, 111 | 'feat_dims': { 112 | 'text': 768, 113 | 'video': 1024, 114 | 'audio': 768 115 | }, 116 | 'ood_data':{ 117 | 'IEMOCAP-DA-OOD': {'ood_label': 'oth'}, 118 | 'MELD-DA': {'ood_label': 'oth'}, 119 | 'MELD-DA-OOD': {'ood_label': 'oth'}, 120 | 'MIntRec-OOD': {'ood_label': 'UNK'}, 121 | 'TED-OOD': {'ood_label': 'UNK'}, 122 | 'MIntRec': {'ood_label': 'UNK'}, 123 | } 124 | }, 125 | } 126 | 127 | -------------------------------------------------------------------------------- /backbones/FusionNets/MIntOOD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import math 5 | from losses import loss_map 6 | from ..SubNets.FeatureNets import BERTEncoder, SubNet 7 | from ..SubNets.transformers_encoder.transformer import TransformerEncoder 8 | from .sampler import MIntOODSampler 9 | from torch import nn 10 | from ..SubNets import text_backbones_map 11 | from data.__init__ import benchmarks 12 | from ..SubNets.AlignNets import AlignSubNet 13 | from torch.nn.parameter import Parameter 14 | from .MIntOOD_Fusion import Fusion 15 | 16 | activation_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh()} 17 | __all__ = ['MIntOOD'] 18 | 19 | class CosNorm_Classifier(nn.Module): 20 | 21 | def __init__(self, in_dims, out_dims, scale = 32, device = None): 22 | 23 | super(CosNorm_Classifier, self).__init__() 24 | self.in_dims = in_dims 25 | self.out_dims = out_dims 26 | self.weight = Parameter(torch.Tensor(out_dims, in_dims).to(device)) 27 | self.scale = scale 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 32 | self.weight.data.uniform_(-stdv, stdv) 33 | 34 | def forward(self, input, binary_scores = None, *args): 35 | 36 | norm_x = torch.norm(input, 2, 1, keepdim=True) 37 | ex = input / norm_x 38 | ew = self.weight / torch.norm(self.weight, 2, 1, keepdim=True) 39 | 40 | return torch.mm(ex * self.scale, ew.t()) 41 | 42 | class MLP_head(nn.Module): 43 | 44 | def __init__(self, args, num_classes): 45 | 46 | super(MLP_head, self).__init__() 47 | self.args = args 48 | self.num_classes = num_classes 49 | 50 | if num_classes == 2: 51 | 52 | self.layer1 = nn.Linear(args.base_dim, args.mlp_hidden_size) 53 | self.layer2 = nn.Linear(args.mlp_hidden_size, args.mlp_hidden_size) 54 | 55 | self.relu = nn.ReLU() 56 | self.gelu = nn.GELU() 57 | self.dropout = nn.Dropout(p=args.mlp_dropout) 58 | self.output_layer = nn.Linear(args.mlp_hidden_size, num_classes) 59 | 60 | else: 61 | 62 | self.relu = nn.ReLU() 63 | self.gelu = nn.GELU() 64 | if args.ablation_type == 'wo_cosine': 65 | self.output_layer_1 = nn.Linear(args.base_dim, args.num_labels) 66 | else: 67 | self.output_layer_1 = CosNorm_Classifier(args.base_dim, args.num_labels, args.scale, args.device) 68 | 69 | if args.ablation_type != 'wo_contrast': 70 | self.contrast_head = nn.Sequential( 71 | nn.Linear(args.base_dim, args.num_labels) 72 | ) 73 | 74 | def adjust_scores(self, scores): 75 | eps = 1e-6 76 | adjusted_scores = scores / (1 - scores + eps) 77 | return adjusted_scores 78 | 79 | def forward(self, x, binary_scores = None, mode = 'ind', return_mlp=False): 80 | 81 | if self.num_classes == 2: 82 | 83 | x = self.relu(self.layer1(x)) 84 | x = self.dropout(x) 85 | x = self.relu(self.layer2(x)) 86 | x = self.dropout(x) 87 | x = self.output_layer(x) 88 | 89 | return x 90 | 91 | else: 92 | if binary_scores is not None: 93 | 94 | binary_scores = self.adjust_scores(binary_scores) 95 | binary_scores = binary_scores.unsqueeze(1).expand(-1, x.shape[1]) 96 | fusion_x = x * binary_scores 97 | 98 | logits = self.output_layer_1(fusion_x) 99 | 100 | if self.args.ablation_type != 'wo_contrast': 101 | contrast_logits = self.contrast_head(x) 102 | else: 103 | contrast_logits = logits 104 | 105 | 106 | if return_mlp: 107 | return fusion_x, logits, contrast_logits 108 | else: 109 | return logits, contrast_logits 110 | 111 | else: 112 | fusion_x = x 113 | logits = self.output_layer_1(fusion_x) 114 | contrast_logits = self.contrast_head(x) 115 | 116 | if return_mlp: 117 | return fusion_x, logits, contrast_logits 118 | else: 119 | return logits, contrast_logits 120 | 121 | def vim(self): 122 | w = self.output_layer_1.weight 123 | b = torch.zeros(w.size(0)) 124 | 125 | return w, b 126 | 127 | class MMEncoder(nn.Module): 128 | 129 | def __init__(self, args): 130 | 131 | super(MMEncoder, self).__init__() 132 | self.model = Fusion(args) 133 | self.sampler = MIntOODSampler(args) 134 | 135 | def forward(self, text_feats, video_feats, audio_feats, labels = None, ood_sampling = False, data_aug = False, probs = None, binary = False, ood_elems = None): 136 | 137 | if ood_sampling: 138 | pooled_output, mixed_labels = self.model(text_feats, video_feats, audio_feats, self.sampler, labels, ood_sampling = ood_sampling, data_aug = data_aug, probs = probs, binary = binary, ood_elems = ood_elems) 139 | return pooled_output, mixed_labels 140 | else: 141 | pooled_output = self.model(text_feats, video_feats, audio_feats, self.sampler, labels, ood_sampling = ood_sampling, data_aug = data_aug, probs = probs, binary = binary, ood_elems = ood_elems) 142 | return 
pooled_output 143 | 144 | 145 | -------------------------------------------------------------------------------- /configs/mintood_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | class Param(): 2 | 3 | def __init__(self, args): 4 | 5 | self.hyper_param = self._get_hyper_parameters(args) 6 | 7 | def _get_hyper_parameters(self, args): 8 | ''' 9 | Args: 10 | num_train_epochs (int): The number of training epochs. 11 | num_labels (autofill): The output dimension. 12 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 13 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 14 | feat_dim (int): The feature dimension. 15 | warmup_proportion (float): The warmup ratio for learning rate. 16 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 17 | train_batch_size (int): The batch size for training. 18 | eval_batch_size (int): The batch size for evaluation. 19 | test_batch_size (int): The batch size for testing. 20 | wait_patient (int): Patient steps for Early Stop. 21 | ''' 22 | if args.text_backbone.startswith('bert'): 23 | hyper_parameters = { 24 | 'need_aligned': True, 25 | 'freeze_parameters': False, 26 | 'eval_monitor': 'f1', 27 | 'eval_batch_size': 16, 28 | 'wait_patience': [3], 29 | 'binary_multiple_ood': 1.0, 30 | 'base_dim': [768], 31 | 'lr': [3e-6], #3e-5 32 | 'temperature': [0.7], # bigger is usually better 33 | 'alpha': [0.7], #0.5, 1 34 | 'mlp_hidden_size': [256], 35 | 'mlp_dropout': [0.2], 36 | 're_prob': [0.1], 37 | 'num_train_epochs': [100], # [30, 40, 50] 38 | 'train_batch_size': [32], # [32, 64, 128] 39 | 'weight_decay': [0.1], # [0.01, 0.05, 0.1] 40 | 'multiple_ood': [1.0], # try average number 41 | 'contrast_dropout': [0.1], 42 | 'select_number_min': [2], 43 | 'select_number_max': [3], 44 | 'weight_dropout': [0.1], 45 | 'weight_hidden_dim': [256], 46 | # 'weight': [2, 3], 47 | 'aligned_method': ['ctc'], 48 | 'warmup_proportion': [0.1], 49 | 'scale': [32], 50 | 'encoder_layers_a': [1], 51 | 'encoder_layers_v': [2], 52 | 'attn_dropout': [0.0], 53 | 'relu_dropout': [0.1], 54 | 'embed_dropout': [0.0], 55 | 'res_dropout': [0.2], #0 56 | 'attn_mask': [False], #True 57 | 'nheads': [8], #4 58 | 'grad_norm': [-1], 59 | } 60 | return hyper_parameters 61 | 62 | # class Param(): 63 | 64 | # def __init__(self, args): 65 | 66 | # self.hyper_param = self._get_hyper_parameters(args) 67 | 68 | # def _get_hyper_parameters(self, args): 69 | # ''' 70 | # Args: 71 | # num_train_epochs (int): The number of training epochs. 72 | # num_labels (autofill): The output dimension. 73 | # max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 74 | # freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 75 | # feat_dim (int): The feature dimension. 76 | # warmup_proportion (float): The warmup ratio for learning rate. 77 | # activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 78 | # train_batch_size (int): The batch size for training. 79 | # eval_batch_size (int): The batch size for evaluation. 80 | # test_batch_size (int): The batch size for testing. 81 | # wait_patient (int): Patient steps for Early Stop. 
82 | # ''' 83 | # if args.text_backbone.startswith('bert'): 84 | # hyper_parameters = { 85 | # 'need_aligned': True, 86 | # 'freeze_parameters': False, 87 | # 'eval_monitor': 'f1', 88 | # 'eval_batch_size': 16, 89 | # 'wait_patience': [3], 90 | # 'binary_multiple_ood': 1.0, 91 | # 'base_dim': [768], 92 | # 'lr': [3e-6], #3e-5 93 | # 'temperature': [1.0], # bigger is usually better 94 | # 'alpha': [0.7], #0.5, 1 95 | # 'mlp_hidden_size': [256], 96 | # 'mlp_dropout': [0.2], 97 | # 're_prob': [0.1], 98 | # 'num_train_epochs': [100], # [30, 40, 50] 99 | # 'train_batch_size': [32], # [32, 64, 128] 100 | # 'weight_decay': [0.1], # [0.01, 0.05, 0.1] 101 | # 'multiple_ood': [1.0], # try average number 102 | # 'contrast_dropout': [0.1], 103 | # 'select_number_min': [2], 104 | # 'select_number_max': [3], 105 | # 'weight_dropout': [0.1], 106 | # 'weight_hidden_dim': [256], 107 | # # 'weight': [2, 3], 108 | # 'aligned_method': ['ctc'], 109 | # 'warmup_proportion': [0.1], 110 | # 'scale': [16], 111 | # 'encoder_layers_a': [1], 112 | # 'encoder_layers_v': [2], 113 | # 'attn_dropout': [0.0], 114 | # 'relu_dropout': [0.1], 115 | # 'embed_dropout': [0.0], 116 | # 'res_dropout': [0.2], #0 117 | # 'attn_mask': [False], #True 118 | # 'nheads': [8], #4 119 | # 'grad_norm': [-1], 120 | # } 121 | # return hyper_parameters -------------------------------------------------------------------------------- /configs/mag_bert_MELD-DA.py: -------------------------------------------------------------------------------- 1 | # class Param(): 2 | 3 | # def __init__(self, args): 4 | 5 | # self.common_param = self._get_common_parameters(args) 6 | # self.hyper_param = self._get_hyper_parameters(args) 7 | 8 | # def _get_common_parameters(self, args): 9 | # """ 10 | # padding_mode (str): The mode for sequence padding ('zero' or 'normal'). 11 | # padding_loc (str): The location for sequence padding ('start' or 'end'). 12 | # eval_monitor (str): The monitor for evaluation ('loss' or metrics, e.g., 'f1', 'acc', 'precision', 'recall'). 13 | # need_aligned: (bool): Whether to perform data alignment between different modalities. 14 | # train_batch_size (int): The batch size for training. 15 | # eval_batch_size (int): The batch size for evaluation. 16 | # test_batch_size (int): The batch size for testing. 17 | # wait_patience (int): Patient steps for Early Stop. 18 | # """ 19 | # common_parameters = { 20 | # 'padding_mode': 'zero', 21 | # 'padding_loc': 'end', 22 | # 'need_aligned': False, 23 | # 'eval_monitor': 'f1', 24 | # 'train_batch_size': 16, 25 | # 'eval_batch_size': 8, 26 | # 'test_batch_size': 8, 27 | # 'wait_patience': 8 28 | # } 29 | # return common_parameters 30 | 31 | # def _get_hyper_parameters(self, args): 32 | # """ 33 | # Args: 34 | # num_train_epochs (int): The number of training epochs. 35 | # dst_feature_dims (int): The destination dimensions (assume d(l) = d(v) = d(t)). 36 | # nheads (int): The number of heads for the transformer network. 37 | # n_levels (int): The number of layers in the network. 38 | # attn_dropout (float): The attention dropout. 39 | # attn_dropout_v (float): The attention dropout for the video modality. 40 | # attn_dropout_a (float): The attention dropout for the audio modality. 41 | # relu_dropout (float): The relu dropout. 42 | # embed_dropout (float): The embedding dropout. 43 | # res_dropout (float): The residual block dropout. 44 | # output_dropout (float): The output layer dropout. 45 | # text_dropout (float): The dropout for text features. 
46 | # grad_clip (float): The gradient clip value. 47 | # attn_mask (bool): Whether to use attention mask for Transformer. 48 | # conv1d_kernel_size_l (int): The kernel size for temporal convolutional layers (text modality). 49 | # conv1d_kernel_size_v (int): The kernel size for temporal convolutional layers (video modality). 50 | # conv1d_kernel_size_a (int): The kernel size for temporal convolutional layers (audio modality). 51 | # lr (float): The learning rate of backbone. 52 | # """ 53 | # hyper_parameters = { 54 | # 'num_train_epochs': 100, 55 | # 'dst_feature_dims': 120, 56 | # 'nheads': 8, 57 | # 'n_levels': 8, 58 | # 'attn_dropout': 0.0, 59 | # 'attn_dropout_v': 0.2, 60 | # 'attn_dropout_a': 0.2, 61 | # 'relu_dropout': 0.0, 62 | # 'embed_dropout': 0.1, 63 | # 'res_dropout': 0.0, 64 | # 'output_dropout': 0.2, 65 | # 'text_dropout': 0.4, 66 | # 'grad_clip': 0.5, 67 | # 'attn_mask': True, 68 | # 'conv1d_kernel_size_l': 5, 69 | # 'conv1d_kernel_size_v': 1, 70 | # 'conv1d_kernel_size_a': 1, 71 | # 'lr': 0.00003, 72 | # } 73 | # return hyper_parameters 74 | 75 | class Param(): 76 | 77 | def __init__(self, args): 78 | 79 | self.hyper_param = self._get_hyper_parameters(args) 80 | 81 | def _get_hyper_parameters(self, args): 82 | """ 83 | Args: 84 | num_train_epochs (int): The number of training epochs. 85 | num_labels (autofill): The output dimension. 86 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 87 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 88 | feat_dim (int): The feature dimension. 89 | warmup_proportion (float): The warmup ratio for learning rate. 90 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 91 | train_batch_size (int): The batch size for training. 92 | eval_batch_size (int): The batch size for evaluation. 93 | test_batch_size (int): The batch size for testing. 94 | wait_patient (int): Patient steps for Early Stop. 95 | """ 96 | ood_detection_parameters = { 97 | 'sbm':{ 98 | 'temperature': [1e4], 99 | 'scale': [20] 100 | } 101 | } 102 | if args.text_backbone.startswith('bert'): 103 | 104 | hyper_parameters = { 105 | 'need_aligned': True, 106 | 'eval_monitor': ['f1'], 107 | 'train_batch_size': 16, 108 | 'eval_batch_size': 8, 109 | 'test_batch_size': 8, 110 | 'wait_patience': 8, 111 | 'num_train_epochs': 100, 112 | ################ 113 | 'beta_shift': 0.04, 114 | 'dropout_prob': 0.4, 115 | 'warmup_proportion': 0.3, 116 | 'lr': 4e-5, 117 | 'aligned_method': 'ctc', 118 | 'weight_decay': 0.06, 119 | 'scale':32 120 | 121 | } 122 | else: 123 | raise ValueError('Not supported text backbone') 124 | 125 | if args.ood_detection_method in ood_detection_parameters.keys(): 126 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 127 | hyper_parameters.update(ood_parameters) 128 | 129 | return hyper_parameters -------------------------------------------------------------------------------- /configs/mag_bert_IEMOCAP-DA.py: -------------------------------------------------------------------------------- 1 | # class Param(): 2 | 3 | # def __init__(self, args): 4 | 5 | # self.common_param = self._get_common_parameters(args) 6 | # self.hyper_param = self._get_hyper_parameters(args) 7 | 8 | # def _get_common_parameters(self, args): 9 | # """ 10 | # padding_mode (str): The mode for sequence padding ('zero' or 'normal'). 
11 | # padding_loc (str): The location for sequence padding ('start' or 'end'). 12 | # eval_monitor (str): The monitor for evaluation ('loss' or metrics, e.g., 'f1', 'acc', 'precision', 'recall'). 13 | # need_aligned: (bool): Whether to perform data alignment between different modalities. 14 | # train_batch_size (int): The batch size for training. 15 | # eval_batch_size (int): The batch size for evaluation. 16 | # test_batch_size (int): The batch size for testing. 17 | # wait_patience (int): Patient steps for Early Stop. 18 | # """ 19 | # common_parameters = { 20 | # 'padding_mode': 'zero', 21 | # 'padding_loc': 'end', 22 | # 'need_aligned': False, 23 | # 'eval_monitor': 'f1', 24 | # 'train_batch_size': 16, 25 | # 'eval_batch_size': 8, 26 | # 'test_batch_size': 8, 27 | # 'wait_patience': 8 28 | # } 29 | # return common_parameters 30 | 31 | # def _get_hyper_parameters(self, args): 32 | # """ 33 | # Args: 34 | # num_train_epochs (int): The number of training epochs. 35 | # dst_feature_dims (int): The destination dimensions (assume d(l) = d(v) = d(t)). 36 | # nheads (int): The number of heads for the transformer network. 37 | # n_levels (int): The number of layers in the network. 38 | # attn_dropout (float): The attention dropout. 39 | # attn_dropout_v (float): The attention dropout for the video modality. 40 | # attn_dropout_a (float): The attention dropout for the audio modality. 41 | # relu_dropout (float): The relu dropout. 42 | # embed_dropout (float): The embedding dropout. 43 | # res_dropout (float): The residual block dropout. 44 | # output_dropout (float): The output layer dropout. 45 | # text_dropout (float): The dropout for text features. 46 | # grad_clip (float): The gradient clip value. 47 | # attn_mask (bool): Whether to use attention mask for Transformer. 48 | # conv1d_kernel_size_l (int): The kernel size for temporal convolutional layers (text modality). 49 | # conv1d_kernel_size_v (int): The kernel size for temporal convolutional layers (video modality). 50 | # conv1d_kernel_size_a (int): The kernel size for temporal convolutional layers (audio modality). 51 | # lr (float): The learning rate of backbone. 52 | # """ 53 | # hyper_parameters = { 54 | # 'num_train_epochs': 100, 55 | # 'dst_feature_dims': 120, 56 | # 'nheads': 8, 57 | # 'n_levels': 8, 58 | # 'attn_dropout': 0.0, 59 | # 'attn_dropout_v': 0.2, 60 | # 'attn_dropout_a': 0.2, 61 | # 'relu_dropout': 0.0, 62 | # 'embed_dropout': 0.1, 63 | # 'res_dropout': 0.0, 64 | # 'output_dropout': 0.2, 65 | # 'text_dropout': 0.4, 66 | # 'grad_clip': 0.5, 67 | # 'attn_mask': True, 68 | # 'conv1d_kernel_size_l': 5, 69 | # 'conv1d_kernel_size_v': 1, 70 | # 'conv1d_kernel_size_a': 1, 71 | # 'lr': 0.00003, 72 | # } 73 | # return hyper_parameters 74 | 75 | class Param(): 76 | 77 | def __init__(self, args): 78 | 79 | self.hyper_param = self._get_hyper_parameters(args) 80 | 81 | def _get_hyper_parameters(self, args): 82 | """ 83 | Args: 84 | num_train_epochs (int): The number of training epochs. 85 | num_labels (autofill): The output dimension. 86 | max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. 87 | freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer. 88 | feat_dim (int): The feature dimension. 89 | warmup_proportion (float): The warmup ratio for learning rate. 90 | activation (str): The activation function of the hidden layer (support 'relu' and 'tanh'). 
91 | train_batch_size (int): The batch size for training. 92 | eval_batch_size (int): The batch size for evaluation. 93 | test_batch_size (int): The batch size for testing. 94 | wait_patient (int): Patient steps for Early Stop. 95 | """ 96 | ood_detection_parameters = { 97 | 'sbm':{ 98 | 'temperature': [1e4], 99 | 'scale': [20] 100 | } 101 | } 102 | if args.text_backbone.startswith('bert'): 103 | 104 | hyper_parameters = { 105 | 'need_aligned': True, 106 | 'eval_monitor': ['f1'], 107 | 'train_batch_size': 16, 108 | 'eval_batch_size': 8, 109 | 'test_batch_size': 8, 110 | 'wait_patience': 8, 111 | 'num_train_epochs': 100, 112 | ################ 113 | 'beta_shift': 0.005, 114 | 'dropout_prob': 0.5, 115 | 'warmup_proportion': 0.1, 116 | 'lr': 2e-5, 117 | 'aligned_method': 'ctc', 118 | 'weight_decay': 0.03, 119 | 'scale':32 120 | 121 | } 122 | else: 123 | raise ValueError('Not supported text backbone') 124 | 125 | if args.ood_detection_method in ood_detection_parameters.keys(): 126 | ood_parameters = ood_detection_parameters[args.ood_detection_method] 127 | hyper_parameters.update(ood_parameters) 128 | 129 | return hyper_parameters -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | # from .sampler_cycle import get_sampler 5 | from torch.utils.data import DataLoader, WeightedRandomSampler, RandomSampler 6 | from collections import Counter 7 | 8 | def get_dataloader(args, data, weighted = False): 9 | 10 | if weighted: 11 | train_label_ids = data['train'].label_ids 12 | class_counts = Counter(train_label_ids) 13 | class_weights = {class_label: 1.0 / count for class_label, count in class_counts.items()} 14 | sample_weights = [class_weights[class_label] for class_label in train_label_ids] 15 | sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_label_ids), replacement=True) 16 | if args.dataset == 'MELD-DA': 17 | train_dataloader = DataLoader(data['train'], sampler = sampler, batch_size = args.train_batch_size, num_workers = args.num_workers, pin_memory = True, drop_last=True) 18 | else: 19 | train_dataloader = DataLoader(data['train'], sampler = sampler, batch_size = args.train_batch_size, num_workers = args.num_workers, pin_memory = True) 20 | 21 | else: 22 | if args.dataset == 'MELD-DA': 23 | train_dataloader = DataLoader(data['train'], shuffle=True, batch_size = args.train_batch_size, num_workers = args.num_workers, pin_memory = True, drop_last=True) 24 | else: 25 | train_dataloader = DataLoader(data['train'], shuffle=True, batch_size = args.train_batch_size, num_workers = args.num_workers, pin_memory = True) 26 | 27 | 28 | dev_dataloader = DataLoader(data['dev'], batch_size = args.eval_batch_size, num_workers = args.num_workers, pin_memory = True) 29 | 30 | test_dataloader = DataLoader(data['test'], batch_size = args.eval_batch_size, num_workers = args.num_workers, pin_memory = True) 31 | 32 | 33 | dataloader = { 34 | 'train': train_dataloader, 35 | 'dev': dev_dataloader, 36 | 'test': test_dataloader 37 | } 38 | 39 | return dataloader 40 | 41 | # def get_cycle_dataloader(args, data): 42 | 43 | # sampler_dic = {'sampler': get_sampler(), 44 | # 'num_samples_cls': 4, 'num_classes': args.num_labels} 45 | # sampler = sampler_dic['sampler'](data['train'], num_samples_cls = sampler_dic['num_samples_cls'], num_classes = sampler_dic['num_classes']) 46 | # train_dataloader = DataLoader(data['train'], 
sampler = sampler, batch_size = args.train_batch_size, num_workers = args.num_workers, pin_memory = True) 47 | # dev_dataloader = DataLoader(data['dev'], batch_size = args.eval_batch_size, num_workers = args.num_workers, pin_memory = True) 48 | # test_dataloader = DataLoader(data['test'], batch_size = args.eval_batch_size, num_workers = args.num_workers, pin_memory = True) 49 | 50 | # return { 51 | # 'train': train_dataloader, 52 | # 'dev': dev_dataloader, 53 | # 'test': test_dataloader 54 | # } 55 | 56 | def get_v_a_data(data_args, feats_path, max_seq_len, ood = False): 57 | 58 | if not os.path.exists(feats_path): 59 | raise Exception('Error: The directory of features is empty.') 60 | 61 | feats = load_feats(data_args, feats_path, ood) 62 | if data_args['method'] == 'spectra' and 'audio_data' in feats_path: 63 | data = spectra_audio_process(feats) 64 | else: 65 | data = padding_feats(feats, max_seq_len) 66 | return data 67 | 68 | 69 | def spectra_audio_process(feats): 70 | p_feats = {} 71 | 72 | for dataset_type in feats.keys(): 73 | f = feats[dataset_type] 74 | 75 | tmp_list = [] 76 | length_list = [] 77 | 78 | for x in f: 79 | x_f = np.array(x['input_values']) 80 | attn_mask = np.array(x['attention_mask']) 81 | p_feat = np.stack((x_f, attn_mask), axis=0) 82 | 83 | length_list.append(np.sum(attn_mask)) 84 | tmp_list.append(p_feat) 85 | 86 | p_feats[dataset_type] = { 87 | 'feats': tmp_list, 88 | 'lengths': length_list 89 | } 90 | 91 | return p_feats 92 | 93 | 94 | def load_feats(data_args, feats_path, ood): 95 | 96 | with open(feats_path, 'rb') as f: 97 | feats = pickle.load(f) 98 | 99 | if ood: 100 | test_feats = [feats[x] for x in data_args['test_data_index']] 101 | outputs = { 102 | 'test': test_feats 103 | } 104 | else: 105 | # print('22222222', feats.keys()) 106 | 107 | train_feats = [feats[x] for x in data_args['train_data_index']] 108 | dev_feats = [feats[x] for x in data_args['dev_data_index']] 109 | test_feats = [feats[x] for x in data_args['test_data_index']] 110 | 111 | outputs = { 112 | 'train': train_feats, 113 | 'dev': dev_feats, 114 | 'test': test_feats 115 | } 116 | 117 | return outputs 118 | 119 | def padding(feat, max_length, padding_mode = 'zero', padding_loc = 'end'): 120 | """ 121 | padding_mode: 'zero' or 'normal' 122 | padding_loc: 'start' or 'end' 123 | """ 124 | assert padding_mode in ['zero', 'normal'] 125 | assert padding_loc in ['start', 'end'] 126 | 127 | length = feat.shape[0] 128 | if length > max_length: 129 | return feat[:max_length, :] 130 | 131 | if padding_mode == 'zero': 132 | pad = np.zeros([max_length - length, feat.shape[-1]]) 133 | elif padding_mode == 'normal': 134 | mean, std = feat.mean(), feat.std() 135 | pad = np.random.normal(mean, std, (max_length - length, feat.shape[1])) 136 | 137 | if padding_loc == 'start': 138 | feat = np.concatenate((pad, feat), axis = 0) 139 | else: 140 | feat = np.concatenate((feat, pad), axis = 0) 141 | 142 | return feat 143 | 144 | def padding_feats(feats, max_seq_len): 145 | 146 | p_feats = {} 147 | 148 | for dataset_type in feats.keys(): 149 | f = feats[dataset_type] 150 | 151 | tmp_list = [] 152 | length_list = [] 153 | 154 | for x in f: 155 | x_f = np.array(x) 156 | x_f = x_f.squeeze(1) if x_f.ndim == 3 else x_f 157 | 158 | length_list.append(min(len(x_f), max_seq_len)) 159 | p_feat = padding(x_f, max_seq_len) 160 | tmp_list.append(p_feat) 161 | 162 | p_feats[dataset_type] = { 163 | 'feats': tmp_list, 164 | 'lengths': length_list 165 | } 166 | 167 | return p_feats 168 | 
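# A minimal usage sketch for get_v_a_data / padding_feats, kept as comments so the module's behavior is
# unchanged. The pickle path, index lists, and max_seq_len value below are illustrative assumptions, not
# values taken from this repository's configs.
#
# data_args = {
#     'method': 'mintood',
#     'train_data_index': train_index,
#     'dev_data_index': dev_index,
#     'test_data_index': test_index,
# }
# video_data = get_v_a_data(data_args, '/path/to/video_feats.pkl', max_seq_len = 230)
# feats = video_data['train']['feats']       # list of arrays, each zero-padded / truncated to (230, feat_dim)
# lengths = video_data['train']['lengths']   # original sequence lengths, clipped to max_seq_len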
-------------------------------------------------------------------------------- /backbones/SubNets/transformers_encoder/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | # Code adapted from the fairseq repo. 8 | 9 | class MultiheadAttention(nn.Module): 10 | """Multi-headed attention. 11 | See "Attention Is All You Need" for more details. 12 | """ 13 | 14 | def __init__(self, embed_dim, num_heads, attn_dropout=0., 15 | bias=True, add_bias_kv=False, add_zero_attn=False): 16 | super().__init__() 17 | self.embed_dim = embed_dim 18 | self.num_heads = num_heads 19 | self.attn_dropout = attn_dropout 20 | self.head_dim = embed_dim // num_heads 21 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 22 | self.scaling = self.head_dim ** -0.5 23 | 24 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 25 | self.register_parameter('in_proj_bias', None) 26 | if bias: 27 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 28 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 29 | 30 | if add_bias_kv: 31 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 32 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 33 | else: 34 | self.bias_k = self.bias_v = None 35 | 36 | self.add_zero_attn = add_zero_attn 37 | 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | nn.init.xavier_uniform_(self.in_proj_weight) 42 | nn.init.xavier_uniform_(self.out_proj.weight) 43 | if self.in_proj_bias is not None: 44 | nn.init.constant_(self.in_proj_bias, 0.) 45 | nn.init.constant_(self.out_proj.bias, 0.) 46 | if self.bias_k is not None: 47 | nn.init.xavier_normal_(self.bias_k) 48 | if self.bias_v is not None: 49 | nn.init.xavier_normal_(self.bias_v) 50 | 51 | def forward(self, query, key, value, attn_mask=None): 52 | """Input shape: Time x Batch x Channel 53 | Self-attention can be implemented by passing in the same arguments for 54 | query, key and value. Timesteps can be masked by supplying a T x T mask in the 55 | `attn_mask` argument. Padding elements can be excluded from 56 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 57 | batch x src_len, where padding elements are indicated by 1s. 
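        Returns a tuple (attn, attn_weights): attn has shape tgt_len x batch x embed_dim after the output
        projection, and attn_weights is averaged over heads, with shape batch x tgt_len x src_len.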
58 | """ 59 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 60 | kv_same = key.data_ptr() == value.data_ptr() 61 | 62 | tgt_len, bsz, embed_dim = query.size() 63 | assert embed_dim == self.embed_dim 64 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 65 | assert key.size() == value.size() 66 | 67 | aved_state = None 68 | 69 | if qkv_same: 70 | # self-attention 71 | q, k, v = self.in_proj_qkv(query) 72 | elif kv_same: 73 | # encoder-decoder attention 74 | q = self.in_proj_q(query) 75 | 76 | if key is None: 77 | assert value is None 78 | k = v = None 79 | else: 80 | k, v = self.in_proj_kv(key) 81 | else: 82 | q = self.in_proj_q(query) 83 | k = self.in_proj_k(key) 84 | v = self.in_proj_v(value) 85 | q = q * self.scaling 86 | 87 | if self.bias_k is not None: 88 | assert self.bias_v is not None 89 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 90 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 91 | if attn_mask is not None: 92 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 93 | 94 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 95 | if k is not None: 96 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 97 | if v is not None: 98 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 99 | 100 | src_len = k.size(1) 101 | 102 | if self.add_zero_attn: 103 | src_len += 1 104 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 105 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 106 | if attn_mask is not None: 107 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 108 | 109 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 110 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 111 | 112 | if attn_mask is not None: 113 | try: 114 | attn_weights += attn_mask.unsqueeze(0) 115 | except: 116 | print(attn_weights.shape) 117 | print(attn_mask.unsqueeze(0).shape) 118 | assert False 119 | 120 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) 121 | # attn_weights = F.relu(attn_weights) 122 | # attn_weights = attn_weights / torch.max(attn_weights) 123 | attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training) 124 | 125 | attn = torch.bmm(attn_weights, v) 126 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 127 | 128 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 129 | attn = self.out_proj(attn) 130 | 131 | # average attention weights over heads 132 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 133 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 134 | return attn, attn_weights 135 | 136 | def in_proj_qkv(self, query): 137 | return self._in_proj(query).chunk(3, dim=-1) 138 | 139 | def in_proj_kv(self, key): 140 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 141 | 142 | def in_proj_q(self, query, **kwargs): 143 | return self._in_proj(query, end=self.embed_dim, **kwargs) 144 | 145 | def in_proj_k(self, key): 146 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 147 | 148 | def in_proj_v(self, value): 149 | return self._in_proj(value, start=2 * self.embed_dim) 150 | 151 | def _in_proj(self, input, start=0, end=None, **kwargs): 152 | weight = kwargs.get('weight', self.in_proj_weight) 153 | bias = kwargs.get('bias', self.in_proj_bias) 154 | weight = 
weight[start:end, :] 155 | if bias is not None: 156 | bias = bias[start:end] 157 | return F.linear(input, weight, bias) -------------------------------------------------------------------------------- /backbones/FusionNets/MULT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from ..SubNets.FeatureNets import BERTEncoder 4 | from ..SubNets.transformers_encoder.transformer import TransformerEncoder 5 | from torch import nn 6 | from ..SubNets import text_backbones_map 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | __all__ = ['MULT'] 11 | 12 | class MULT(nn.Module): 13 | 14 | def __init__(self, args): 15 | 16 | super(MULT, self).__init__() 17 | 18 | text_backbone = text_backbones_map[args.text_backbone] 19 | 20 | self.text_subnet = text_backbone(args) 21 | 22 | video_feat_dim = args.video_feat_dim 23 | text_feat_dim = args.text_feat_dim 24 | audio_feat_dim = args.audio_feat_dim 25 | 26 | dst_feature_dims = args.dst_feature_dims 27 | 28 | self.orig_d_l, self.orig_d_a, self.orig_d_v = text_feat_dim, audio_feat_dim, video_feat_dim 29 | self.d_l = self.d_a = self.d_v = dst_feature_dims 30 | 31 | self.num_heads = args.nheads 32 | self.layers = args.n_levels 33 | self.attn_dropout = args.attn_dropout 34 | self.attn_dropout_a = args.attn_dropout_a 35 | self.attn_dropout_v = args.attn_dropout_v 36 | 37 | self.relu_dropout = args.relu_dropout 38 | self.embed_dropout = args.embed_dropout 39 | self.res_dropout = args.res_dropout 40 | self.output_dropout = args.output_dropout 41 | self.text_dropout = args.text_dropout 42 | self.attn_mask = args.attn_mask 43 | 44 | self.combined_dim = combined_dim = 2 * (self.d_l + self.d_a + self.d_v) 45 | output_dim = args.num_labels 46 | 47 | self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=args.conv1d_kernel_size_l, padding=0, bias=False) 48 | self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=args.conv1d_kernel_size_a, padding=0, bias=False) 49 | self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=args.conv1d_kernel_size_v, padding=0, bias=False) 50 | 51 | self.trans_l_with_a = self._get_network(self_type='la') 52 | self.trans_l_with_v = self._get_network(self_type='lv') 53 | 54 | self.trans_a_with_l = self._get_network(self_type='al') 55 | self.trans_a_with_v = self._get_network(self_type='av') 56 | 57 | self.trans_v_with_l = self._get_network(self_type='vl') 58 | self.trans_v_with_a = self._get_network(self_type='va') 59 | 60 | self.trans_l_mem = self._get_network(self_type='l_mem', layers=3) 61 | self.trans_a_mem = self._get_network(self_type='a_mem', layers=3) 62 | self.trans_v_mem = self._get_network(self_type='v_mem', layers=3) 63 | 64 | self.proj1 = nn.Linear(combined_dim, combined_dim) 65 | self.proj2 = nn.Linear(combined_dim, combined_dim) 66 | self.out_layer = nn.Linear(combined_dim, output_dim) 67 | 68 | 69 | def _get_network(self, self_type='l', layers=-1): 70 | 71 | if self_type in ['l', 'vl', 'al']: 72 | embed_dim, attn_dropout = self.d_l, self.attn_dropout 73 | elif self_type in ['a', 'la', 'va']: 74 | embed_dim, attn_dropout = self.d_a, self.attn_dropout_a 75 | elif self_type in ['v', 'lv', 'av']: 76 | embed_dim, attn_dropout = self.d_v, self.attn_dropout_v 77 | elif self_type == 'l_mem': 78 | embed_dim, attn_dropout = 2 * self.d_l, self.attn_dropout 79 | elif self_type == 'a_mem': 80 | embed_dim, attn_dropout = 2 * self.d_a, self.attn_dropout 81 | elif self_type == 'v_mem': 82 | embed_dim, attn_dropout = 2 * self.d_v, 
self.attn_dropout 83 | else: 84 | raise ValueError("Unknown network type") 85 | 86 | return TransformerEncoder(embed_dim=embed_dim, 87 | num_heads=self.num_heads, 88 | layers=max(self.layers, layers), 89 | attn_dropout=attn_dropout, 90 | relu_dropout=self.relu_dropout, 91 | res_dropout=self.res_dropout, 92 | embed_dropout=self.embed_dropout, 93 | attn_mask=self.attn_mask) 94 | 95 | def forward(self, text_feats, video_feats, audio_feats, binary_inputs = None, feature_ext = False): 96 | 97 | text = self.text_subnet(text_feats) 98 | 99 | if feature_ext: 100 | return text, video_feats, audio_feats 101 | 102 | x_l = F.dropout(text.transpose(1, 2), p=self.text_dropout, training=self.training) 103 | x_a = audio_feats.transpose(1, 2).float() 104 | x_v = video_feats.transpose(1, 2).float() 105 | 106 | proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l) 107 | proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a) 108 | proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v) 109 | 110 | proj_x_a = proj_x_a.permute(2, 0, 1) 111 | proj_x_v = proj_x_v.permute(2, 0, 1) 112 | proj_x_l = proj_x_l.permute(2, 0, 1) 113 | 114 | # (V,A) --> L 115 | h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a) 116 | h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v) # Dimension (L, N, d_l) 117 | 118 | h_ls = torch.cat([h_l_with_as, h_l_with_vs], dim = 2) 119 | h_ls = self.trans_l_mem(h_ls) 120 | 121 | if type(h_ls) == tuple: 122 | h_ls = h_ls[0] 123 | 124 | last_h_l = last_hs = h_ls[-1] # Take the last output for prediction 125 | 126 | # (L,V) --> A 127 | h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l) 128 | h_a_with_vs = self.trans_a_with_v(proj_x_a, proj_x_v, proj_x_v) 129 | h_as = torch.cat([h_a_with_ls, h_a_with_vs], dim=2) 130 | h_as = self.trans_a_mem(h_as) 131 | if type(h_as) == tuple: 132 | h_as = h_as[0] 133 | last_h_a = last_hs = h_as[-1] 134 | 135 | # (L,A) --> V 136 | h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l) 137 | h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a) 138 | h_vs = torch.cat([h_v_with_ls, h_v_with_as], dim=2) 139 | h_vs = self.trans_v_mem(h_vs) 140 | if type(h_vs) == tuple: 141 | h_vs = h_vs[0] 142 | last_h_v = last_hs = h_vs[-1] 143 | 144 | last_hs = torch.cat([last_h_l, last_h_a, last_h_v], dim=1) 145 | 146 | last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs), inplace=True), p=self.output_dropout, training=self.training)) 147 | last_hs_proj += last_hs 148 | 149 | 150 | logits = self.out_layer(last_hs_proj) 151 | 152 | 153 | return logits, last_hs_proj 154 | 155 | def vim(self): 156 | 157 | return self.out_layer.weight, self.out_layer.bias -------------------------------------------------------------------------------- /backbones/FusionNets/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | from torch import nn 5 | from scipy.spatial.distance import cdist 6 | 7 | def mixup_data(alpha=1.0): 8 | '''Return lambda''' 9 | if alpha > 0.: 10 | lam = np.random.beta(alpha, alpha) 11 | else: 12 | lam = 1. 
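    # lam ~ Beta(alpha, alpha); with alpha = 1.0 this reduces to Uniform(0, 1). MIntOODSampler uses this
    # coefficient to convexly mix two in-distribution samples into a pseudo-OOD sample (sampler_beta ablation).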
13 | return lam 14 | 15 | class MIntOODSampler(nn.Module): 16 | 17 | def __init__(self, args): 18 | super(MIntOODSampler, self).__init__() 19 | self.ood_label_id = args.ood_label_id 20 | self.args = args 21 | 22 | def alternate_mixup(self, data1, data2): 23 | mixed_data = torch.zeros_like(data1) 24 | for i in range(data1.size(0)): 25 | if i % 2 == 0: 26 | mixed_data[i, :] = data1[i, :] 27 | else: 28 | mixed_data[i, :] = data2[i, :] 29 | return mixed_data 30 | 31 | def forward(self, ind_text_feats, ind_video_data, ind_audio_data, ind_label_ids, extended_attention_mask, attention_mask, device=None, binary = False, ood_elems = None): 32 | 33 | if binary: 34 | num_ood = int(len(ind_text_feats) * self.args.binary_multiple_ood) 35 | else: 36 | num_ood = int(len(ind_text_feats) * self.args.multiple_ood) 37 | 38 | ood_text_list, ood_video_list, ood_audio_list, ood_mask_list, ood_attention_mask_list = [], [], [], [], [] 39 | text_seq_length, video_seq_length, audio_seq_length = ind_text_feats.shape[1], ind_video_data.shape[1], ind_audio_data.shape[1] 40 | 41 | select_elems = [] 42 | 43 | if self.args.ablation_type == 'sampler_beta': 44 | 45 | while len(ood_text_list) < num_ood: 46 | 47 | cdt = np.random.choice(ind_label_ids.size(0), 2, replace=False) 48 | 49 | if len(set(ind_label_ids[cdt].tolist())) >= 2: 50 | s = mixup_data(self.args.alpha) 51 | 52 | ood_text = (s * ind_text_feats[cdt[0]] + (1 - s) * ind_text_feats[cdt[1]]) 53 | ood_video = (s * ind_video_data[cdt[0]] + (1 - s) * ind_video_data[cdt[1]]) 54 | ood_audio = (s * ind_audio_data[cdt[0]] + (1 - s) * ind_audio_data[cdt[1]]) 55 | 56 | 57 | lengths = [torch.sum(extended_attention_mask[cdt[i]]).item() for i in range(2)] 58 | idx = cdt[np.argmin(lengths)] 59 | ood_mask = extended_attention_mask[idx] 60 | ood_attention_mask = attention_mask[idx] 61 | 62 | ood_text_list.append(ood_text) 63 | ood_video_list.append(ood_video) 64 | ood_audio_list.append(ood_audio) 65 | ood_mask_list.append(ood_mask) 66 | ood_attention_mask_list.append(ood_attention_mask) 67 | 68 | 69 | select_elems.append([cdt[0], s, cdt[1]]) 70 | 71 | else: 72 | while len(ood_text_list) < num_ood: 73 | 74 | if self.args.select_number_min == self.args.select_number_max: 75 | select_number = self.args.select_number_min 76 | else: 77 | select_number = np.random.randint(self.args.select_number_min, self.args.select_number_max + 1) 78 | 79 | if ind_label_ids.size(0) >= select_number: 80 | cdt = np.random.choice(ind_label_ids.size(0), select_number, replace=False) 81 | 82 | if len(set(ind_label_ids[cdt].tolist())) >= 2: 83 | s = np.random.dirichlet(alpha=[self.args.alpha] * select_number) 84 | 85 | ood_text = sum(s[i] * ind_text_feats[cdt[i]] for i in range(select_number)) 86 | ood_video = sum(s[i] * ind_video_data[cdt[i]] for i in range(select_number)) 87 | ood_audio = sum(s[i] * ind_audio_data[cdt[i]] for i in range(select_number)) 88 | 89 | 90 | lengths = [torch.sum(extended_attention_mask[cdt[i]]).item() for i in range(select_number)] 91 | idx = cdt[np.argmin(lengths)] 92 | ood_mask = extended_attention_mask[idx] 93 | ood_attention_mask = attention_mask[idx] 94 | 95 | ood_text_list.append(ood_text) 96 | ood_video_list.append(ood_video) 97 | ood_audio_list.append(ood_audio) 98 | ood_mask_list.append(ood_mask) 99 | ood_attention_mask_list.append(ood_attention_mask) 100 | 101 | 102 | select_elems.append([cdt.tolist(), s.tolist()]) 103 | 104 | if ind_text_feats.ndim == 3: 105 | ood_text_feats = torch.cat(ood_text_list, dim = 0).view(num_ood, text_seq_length, -1) 106 | ood_mask_feats = 
torch.cat(ood_mask_list, dim = 0).view(num_ood, extended_attention_mask.shape[1], extended_attention_mask.shape[2], extended_attention_mask.shape[3]) 107 | ood_attention_mask_feats = torch.cat(ood_attention_mask_list, dim = 0).view(num_ood, -1) 108 | elif ind_text_feats.ndim == 2: 109 | ood_text_feats = torch.cat(ood_text_list, dim = 0).view(num_ood, -1) 110 | 111 | if ind_video_data.ndim == 3: 112 | ood_video_feats = torch.cat(ood_video_list, dim = 0).view(num_ood, video_seq_length, -1) 113 | elif ind_video_data.ndim == 2: 114 | ood_video_feats = torch.cat(ood_video_list, dim = 0).view(num_ood, -1) 115 | 116 | if ind_audio_data.ndim == 3: 117 | ood_audio_feats = torch.cat(ood_audio_list, dim = 0).view(num_ood, audio_seq_length, -1) 118 | elif ind_audio_data.ndim == 2: 119 | ood_audio_feats = torch.cat(ood_audio_list, dim = 0).view(num_ood, -1) 120 | 121 | mix_text = torch.cat((ind_text_feats, ood_text_feats), dim = 0) 122 | 123 | mix_video = torch.cat((ind_video_data, ood_video_feats), dim = 0) 124 | mix_audio = torch.cat((ind_audio_data, ood_audio_feats), dim = 0) 125 | mix_mask = torch.cat((extended_attention_mask, ood_mask_feats), dim = 0) 126 | mix_attention_mask = torch.cat((attention_mask, ood_attention_mask_feats), dim = 0) 127 | 128 | semi_label_ids = torch.cat((ind_label_ids.cpu(), torch.tensor([self.ood_label_id] * num_ood)), dim=0) 129 | binary_label_ids = torch.cat((torch.tensor([1] * len(ind_text_feats)) , torch.tensor([0] * num_ood)), dim=0) 130 | 131 | mix_data = {} 132 | mix_data['text'] = mix_text.to(device) 133 | mix_data['video'] = mix_video.to(device) 134 | mix_data['audio'] = mix_audio.to(device) 135 | mix_data['mask'] = mix_mask.to(device) 136 | mix_data['attention_mask'] = mix_attention_mask.to(device) 137 | 138 | mix_labels = { 139 | 'ind': ind_label_ids.to(device), 140 | 'semi': semi_label_ids.to(device), 141 | 'binary': binary_label_ids.to(device), 142 | 'select_elems': select_elems 143 | } 144 | 145 | return mix_data, mix_labels 146 | -------------------------------------------------------------------------------- /backbones/FusionNets/SDIF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..SubNets.FeatureNets import BERTEncoder, BertCrossEncoder, BERTEncoderSDIF 3 | from ..SubNets import text_backbones_map 4 | from torch import nn 5 | 6 | __all__ = ['SDIF'] 7 | 8 | 9 | class SDIF(nn.Module): 10 | 11 | def __init__(self, args): 12 | 13 | super(SDIF, self).__init__() 14 | self.args = args 15 | 16 | 17 | self.text_subnet = BERTEncoderSDIF.from_pretrained(args.text_pretrained_model) 18 | self.visual_size = args.video_feat_dim 19 | self.acoustic_size = args.audio_feat_dim 20 | self.text_size = args.text_feat_dim 21 | self.device = args.device 22 | self.dst_feature_dims = args.dst_feature_dims 23 | # self.capture_activation_output = capture_activation_output 24 | 25 | self.layers_cross = args.n_levels_cross 26 | self.layers_self = args.n_levels_self 27 | 28 | self.dropout_rate = args.dropout_rate 29 | self.cross_dp_rate = args.cross_dp_rate 30 | self.cross_num_heads = args.cross_num_heads 31 | self.self_num_heads = args.self_num_heads 32 | 33 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.dst_feature_dims, nhead=self.self_num_heads) 34 | self.self_att = nn.TransformerEncoder(encoder_layer, num_layers=self.layers_self) 35 | 36 | self.video2text_cross = BertCrossEncoder(self.cross_num_heads, self.dst_feature_dims, self.cross_dp_rate, n_layers=self.layers_cross) 37 | self.audio2text_cross = 
BertCrossEncoder(self.cross_num_heads, self.dst_feature_dims, self.cross_dp_rate, n_layers=self.layers_cross) 38 | 39 | self.v2t_project = nn.Linear(self.visual_size, self.text_size) 40 | 41 | self.mlp_project = nn.Sequential( 42 | nn.Linear(self.dst_feature_dims, self.dst_feature_dims), 43 | nn.Dropout(args.dropout_rate), 44 | nn.GELU() 45 | ) 46 | 47 | self.shallow_att_project = nn.Linear(self.dst_feature_dims, 1, bias=False) 48 | self.deep_att_project = nn.Linear(self.dst_feature_dims, 1, bias=False) 49 | 50 | self.activation = nn.ReLU() 51 | self.fusion = nn.Sequential() 52 | self.fusion.add_module('fusion_layer_1', nn.Linear(in_features=self.dst_feature_dims * 6, out_features=self.dst_feature_dims * 3)) 53 | self.fusion.add_module('fusion_layer_1_dropout', nn.Dropout(self.dropout_rate)) 54 | self.fusion.add_module('fusion_layer_1_activation', self.activation) 55 | self.fusion.add_module('fusion_layer_3', nn.Linear(in_features=self.dst_feature_dims * 3, out_features= args.num_labels)) 56 | 57 | if args.aug: 58 | self.out_layer = nn.Linear(self.dst_feature_dims, args.num_labels) 59 | self.aug_dp = nn.Dropout(args.aug_dp) 60 | 61 | def forward(self, text_feats, video_feats, audio_feats, pre_train = False): 62 | 63 | if pre_train: 64 | text_outputs = self.text_subnet(text_feats) 65 | text_seq, text_rep = text_outputs['last_hidden_state'], text_outputs['pooler_output'] 66 | text_rep = self.aug_dp(text_rep) 67 | logits = self.out_layer(text_rep) 68 | return logits 69 | 70 | # first layer : T,V,A 71 | bert_sent, bert_sent_mask, bert_sent_type = text_feats[:,0], text_feats[:,1], text_feats[:,2] 72 | text_outputs = self.text_subnet(text_feats) 73 | text_seq, text_rep = text_outputs['last_hidden_state'], text_outputs['pooler_output'] 74 | 75 | video_feats = video_feats.to(self.v2t_project.weight.dtype) 76 | video_seq = self.v2t_project(video_feats) 77 | audio_seq = audio_feats 78 | 79 | video_mask = torch.sum(video_feats.ne(torch.zeros(video_feats[0].shape[-1]).to(self.device)).int(), dim=-1)/video_feats[0].shape[-1] 80 | video_mask_len = torch.sum(video_mask, dim=1, keepdim=True) 81 | 82 | video_mask_len = torch.where(video_mask_len > 0.5, video_mask_len, torch.ones([1]).to(self.device)) 83 | video_masked_output = torch.mul(video_mask.unsqueeze(2), video_seq) 84 | video_rep = torch.sum(video_masked_output, dim=1, keepdim=False) / video_mask_len 85 | 86 | 87 | audio_mask = torch.sum(audio_feats.ne(torch.zeros(audio_feats[0].shape[-1]).to(self.device)).int(), dim=-1)/audio_feats[0].shape[-1] 88 | audio_mask_len = torch.sum(audio_mask, dim=1, keepdim=True) 89 | 90 | audio_masked_output = torch.mul(audio_mask.unsqueeze(2), audio_seq) 91 | audio_rep = torch.sum(audio_masked_output, dim=1, keepdim=False) / audio_mask_len 92 | 93 | 94 | # Second layer (V,A) --> T: V_T, A_T 95 | extended_video_mask = video_mask.unsqueeze(1).unsqueeze(2) 96 | extended_video_mask = extended_video_mask.to(dtype=next(self.parameters()).dtype) 97 | extended_video_mask = (1.0 - extended_video_mask) * -10000.0 98 | video2text_seq = self.video2text_cross(text_seq, video_seq, extended_video_mask) 99 | 100 | extended_audio_mask = audio_mask.unsqueeze(1).unsqueeze(2) 101 | extended_audio_mask = extended_audio_mask.to(dtype=next(self.parameters()).dtype) 102 | extended_audio_mask = (1.0 - extended_audio_mask) * -10000.0 103 | audio2text_seq = self.audio2text_cross(text_seq, audio_seq, extended_audio_mask) 104 | 105 | text_mask_len = torch.sum(bert_sent_mask, dim=1, keepdim=True) 106 | 107 | video2text_masked_output = 
torch.mul(bert_sent_mask.unsqueeze(2), video2text_seq) 108 | video2text_rep = torch.sum(video2text_masked_output, dim=1, keepdim=False) / text_mask_len 109 | 110 | 111 | audio2text_masked_output = torch.mul(bert_sent_mask.unsqueeze(2), audio2text_seq) 112 | audio2text_rep = torch.sum(audio2text_masked_output, dim=1, keepdim=False) / text_mask_len 113 | 114 | 115 | # Third layer: mlp->VAL 116 | shallow_seq = self.mlp_project(torch.cat([audio2text_seq, text_seq, video2text_seq], dim=1)) 117 | 118 | # Deep Interaction 119 | tri_cat_mask = torch.cat([bert_sent_mask, bert_sent_mask, bert_sent_mask], dim=-1) 120 | 121 | tri_mask_len = torch.sum(tri_cat_mask, dim=1, keepdim=True) 122 | shallow_masked_output = torch.mul(tri_cat_mask.unsqueeze(2), shallow_seq) 123 | shallow_rep = torch.sum(shallow_masked_output, dim=1, keepdim=False) / tri_mask_len 124 | 125 | text_rep = text_rep.to(video2text_rep.dtype) 126 | video_rep = video_rep.to(video2text_rep.dtype) 127 | audio_rep = audio_rep.to(video2text_rep.dtype) 128 | audio2text_rep = audio2text_rep.to(video2text_rep.dtype) 129 | 130 | all_reps = torch.stack((text_rep, video_rep, audio_rep, video2text_rep, audio2text_rep, shallow_rep), dim=0) 131 | all_hiddens = self.self_att(all_reps) 132 | deep_rep = torch.cat((all_hiddens[0], all_hiddens[1], all_hiddens[2], all_hiddens[3], all_hiddens[4], all_hiddens[5]), dim=1) 133 | 134 | self.text_rep = text_rep 135 | self.video_rep = video_rep 136 | self.audio_rep = audio_rep 137 | self.video2text_rep = video2text_rep 138 | self.audio2text_rep = audio2text_rep 139 | self.shallow_rep = shallow_rep 140 | 141 | 142 | logits = self.fusion(deep_rep) 143 | 144 | return logits 145 | 146 | def vim(self): 147 | 148 | return self.fusion.fusion_layer_3.weight, self.fusion.fusion_layer_3.bias 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | import random 7 | import logging 8 | import copy 9 | from .metrics import Metrics, OOD_Metrics, OID_Metrics 10 | 11 | # def softmax_cross_entropy_with_softtarget(input, num_labels, device): 12 | # """ 13 | # :param input: (batch, *) 14 | # :param target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1. 15 | # """ 16 | # ood_length = input.shape[0] 17 | 18 | # ood_targets = (1. / num_labels) * torch.ones(ood_length).to(device) 19 | 20 | # logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1) 21 | # batchloss = - torch.sum(ood_targets.view(ood_targets.shape[0], -1) * logprobs, dim=1) 22 | # # batchloss = - torch.sum(ood_targets.unsqueeze(1) * logprobs, dim=1) 23 | 24 | # return torch.mean(batchloss) 25 | 26 | class EarlyStopping: 27 | """Early stops the training if validation loss doesn't improve after a given patience.""" 28 | def __init__(self, args, delta=1e-6): 29 | """ 30 | Args: 31 | patience (int): How long to wait after last time validation loss improved. 32 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 
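        Note: when args.eval_monitor == 'loss', a lower score counts as an improvement; for metric monitors
        (e.g., 'f1'), a higher score does. The best model weights are kept in self.best_model.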
33 | """ 34 | self.patience = args.wait_patience 35 | self.logger = logging.getLogger(args.logger_name) 36 | self.monitor = args.eval_monitor 37 | self.counter = 0 38 | self.best_score = 1e8 if self.monitor == 'loss' else 1e-6 39 | self.early_stop = False 40 | self.delta = delta 41 | 42 | def __call__(self, score, model, multiclass_head=None, binary_head=None): 43 | 44 | better_flag = score <= (self.best_score - self.delta) if self.monitor == 'loss' else score >= (self.best_score + self.delta) 45 | 46 | if better_flag: 47 | self.counter = 0 48 | self.best_model = copy.deepcopy(model) 49 | self.best_score = score 50 | 51 | if multiclass_head is not None: 52 | self.best_multiclass_head = copy.deepcopy(multiclass_head) 53 | if binary_head is not None: 54 | self.best_binary_head = copy.deepcopy(binary_head) 55 | 56 | else: 57 | self.counter += 1 58 | self.logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}') 59 | # print(self.patience) 60 | # print(self.counter) 61 | if self.counter >= self.patience: 62 | self.early_stop = True 63 | 64 | def set_torch_seed(seed): 65 | random.seed(seed) 66 | np.random.seed(seed) 67 | torch.manual_seed(seed) 68 | torch.cuda.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 70 | torch.backends.cudnn.deterministic = True 71 | torch.backends.cudnn.benchmark = False 72 | os.environ['PYTHONHASHSEED'] = str(seed) 73 | 74 | def set_output_path(args, save_model_name): 75 | 76 | if not os.path.exists(args.output_path): 77 | os.makedirs(args.output_path) 78 | 79 | pred_output_path = os.path.join(args.output_path, save_model_name) 80 | if not os.path.exists(pred_output_path): 81 | os.makedirs(pred_output_path) 82 | 83 | model_path = os.path.join(pred_output_path, args.model_path) 84 | if not os.path.exists(model_path): 85 | os.makedirs(model_path) 86 | 87 | return pred_output_path, model_path 88 | 89 | def save_npy(npy_file, path, file_name): 90 | npy_path = os.path.join(path, file_name) 91 | np.save(npy_path, npy_file) 92 | 93 | def load_npy(path, file_name): 94 | npy_path = os.path.join(path, file_name) 95 | npy_file = np.load(npy_path) 96 | return npy_file 97 | 98 | def save_model(model, model_dir): 99 | 100 | save_model = model.module if hasattr(model, 'module') else model 101 | model_file = os.path.join(model_dir, 'pytorch_model.bin') 102 | 103 | torch.save(save_model.state_dict(), model_file) 104 | 105 | def restore_model(model, model_dir, device): 106 | output_model_file = os.path.join(model_dir, 'pytorch_model.bin') 107 | m = torch.load(output_model_file, map_location=device) 108 | model.load_state_dict(m, strict = False) 109 | return model 110 | 111 | def save_results(args, test_results, debug_args = None): 112 | 113 | save_keys = ['y_pred', 'y_true', 'features', 'scores'] 114 | for s_k in save_keys: 115 | if s_k in test_results.keys(): 116 | save_path = os.path.join(args.output_path, s_k + '.npy') 117 | np.save(save_path, test_results[s_k]) 118 | 119 | results = {} 120 | metrics = Metrics(args) 121 | ood_metrics = OOD_Metrics(args) 122 | oid_metrics = OID_Metrics(args) 123 | 124 | for key in metrics.eval_metrics: 125 | if key in test_results.keys(): 126 | results[key] = round(test_results[key] * 100, 2) 127 | 128 | for key in ood_metrics.eval_metrics: 129 | if key in test_results.keys(): 130 | results[key] = round(test_results[key] * 100, 2) 131 | 132 | for key in oid_metrics.eval_metrics: 133 | if key in test_results.keys(): 134 | results[key] = round(test_results[key] * 100, 2) 135 | 136 | if 
'best_eval_score' in test_results: 137 | eval_key = 'eval_' + args.eval_monitor 138 | results.update({eval_key: test_results['best_eval_score']}) 139 | 140 | _vars = [args.dataset, args.ood_dataset, args.method, args.ablation_type, args.ood_detection_method, args.text_backbone, args.seed, args.log_id] 141 | _names = ['dataset', 'ood_dataset', 'method', 'ablation_type', 'ood_detection_method', 'text_backbone', 'seed', 'log_id'] 142 | 143 | if debug_args is not None: 144 | _vars.extend([args[key] for key in debug_args.keys()]) 145 | _names.extend(debug_args.keys()) 146 | 147 | vars_dict = {k:v for k,v in zip(_names, _vars)} 148 | results = dict(results,**vars_dict) 149 | 150 | keys = list(results.keys()) 151 | values = list(results.values()) 152 | 153 | if not os.path.exists(args.results_path): 154 | os.makedirs(args.results_path) 155 | 156 | results_path = os.path.join(args.results_path, args.results_file_name) 157 | 158 | if not os.path.exists(results_path) or os.path.getsize(results_path) == 0: 159 | ori = [] 160 | ori.append(values) 161 | df1 = pd.DataFrame(ori,columns = keys) 162 | df1.to_csv(results_path,index=False) 163 | else: 164 | df1 = pd.read_csv(results_path) 165 | new = pd.DataFrame(results,index=[1]) 166 | df1 = df1.append(new,ignore_index=True) 167 | df1.to_csv(results_path,index=False) 168 | data_diagram = pd.read_csv(results_path) 169 | 170 | print('test_results', data_diagram) 171 | 172 | # def get_mixup_embedding(embs, select_elems): 173 | 174 | # idx_a = [e[0] for e in select_elems] 175 | # idx_b = [e[2] for e in select_elems] 176 | # alphas = [e[1] for e in select_elems] 177 | 178 | # ood_feats = [] 179 | # ood_len = embs.shape[0] 180 | 181 | # for i in range(len(idx_a)): 182 | # feat_a = embs[idx_a[i]] 183 | # feat_b = embs[idx_b[i]] 184 | # alpha = alphas[i] 185 | # ood_feat = feat_a * alpha + (1 - alpha) * feat_b 186 | # ood_feats.append(ood_feat) 187 | 188 | # if embs.ndim == 2: 189 | # ood_feats = torch.cat(ood_feats, dim = 0).view(ood_len, -1) 190 | # else: 191 | # ood_feats = torch.cat(ood_feats, dim = 0).view(ood_len, embs.shape[1], -1) 192 | 193 | # mixed_embs = torch.cat((embs, ood_feats), dim = 0) 194 | 195 | # return mixed_embs 196 | 197 | # def get_mixup_mask(masks, select_elems): 198 | 199 | # idx_a = [e[0] for e in select_elems] 200 | # idx_b = [e[2] for e in select_elems] 201 | 202 | # ood_masks = [] 203 | # ood_len = masks.shape[0] 204 | 205 | # for i in range(len(idx_a)): 206 | # mask_a = masks[idx_a[i]] 207 | # mask_b = masks[idx_b[i]] 208 | # ood_mask = torch.max(mask_a, mask_b) 209 | # ood_masks.append(ood_mask) 210 | 211 | # ood_masks = torch.stack(ood_masks, dim=0).view(ood_len, masks.shape[1], masks.shape[2], masks.shape[3]) 212 | # mixed_masks = torch.cat((masks, ood_masks), dim=0) 213 | 214 | # return mixed_masks 215 | -------------------------------------------------------------------------------- /backbones/FusionNets/AlignNets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['CTCModule', 'AlignSubNet', 'SimModule'] 7 | 8 | class CTCModule(nn.Module): 9 | def __init__(self, in_dim, out_seq_len, args): 10 | ''' 11 | This module is performing alignment from A (e.g., audio) to B (e.g., text). 
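        An LSTM predicts, for every input step, a distribution over output positions (index 0 is the blank);
        the forward pass drops the blank, transposes the result, and multiplies this soft alignment with the
        input (torch.bmm) to produce a pseudo-aligned sequence of length out_seq_len.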
12 | :param in_dim: Dimension for input modality A 13 | :param out_seq_len: Sequence length for output modality B 14 | From: https://github.com/yaohungt/Multimodal-Transformer 15 | ''' 16 | super(CTCModule, self).__init__() 17 | # Use LSTM for predicting the position from A to B 18 | self.pred_output_position_inclu_blank = nn.LSTM(in_dim, out_seq_len+1, num_layers=2, batch_first=True) # 1 denoting blank 19 | self.out_seq_len = out_seq_len 20 | 21 | self.softmax = nn.Softmax(dim=2) 22 | 23 | def forward(self, x): 24 | ''' 25 | :input x: Input with shape [batch_size x in_seq_len x in_dim] 26 | ''' 27 | # NOTE that the index 0 refers to blank. 28 | 29 | pred_output_position_inclu_blank, _ = self.pred_output_position_inclu_blank(x) 30 | 31 | prob_pred_output_position_inclu_blank = self.softmax(pred_output_position_inclu_blank) # batch_size x in_seq_len x out_seq_len+1 32 | prob_pred_output_position = prob_pred_output_position_inclu_blank[:, :, 1:] # batch_size x in_seq_len x out_seq_len 33 | prob_pred_output_position = prob_pred_output_position.transpose(1,2) # batch_size x out_seq_len x in_seq_len 34 | pseudo_aligned_out = torch.bmm(prob_pred_output_position, x) # batch_size x out_seq_len x in_dim 35 | 36 | # pseudo_aligned_out is regarded as the aligned A (w.r.t B) 37 | # return pseudo_aligned_out, (pred_output_position_inclu_blank) 38 | return pseudo_aligned_out 39 | 40 | # similarity-based modality alignment 41 | class SimModule(nn.Module): 42 | def __init__(self, in_dim_x, in_dim_y, shared_dim, out_seq_len, args): 43 | ''' 44 | This module is performing alignment from A (e.g., audio) to B (e.g., text). 45 | :param in_dim: Dimension for input modality A 46 | :param out_seq_len: Sequence length for output modality B 47 | ''' 48 | super(SimModule, self).__init__() 49 | # Use LSTM for predicting the position from A to B 50 | self.ctc = CTCModule(in_dim_x, out_seq_len, args) 51 | self.eps = args.eps 52 | 53 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 54 | self.proj_x = nn.Linear(in_features=in_dim_x, out_features=shared_dim) 55 | self.proj_y = nn.Linear(in_features=in_dim_y, out_features=shared_dim) 56 | 57 | self.fc1 = nn.Linear(in_features=out_seq_len, out_features=round(out_seq_len / 2)) 58 | self.fc2 = nn.Linear(in_features=round(out_seq_len / 2), out_features=out_seq_len) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.sigmoid = nn.Sigmoid() 61 | 62 | def forward(self, x, y): 63 | ''' 64 | :input x: Input with shape [batch_size x in_seq_len x in_dim] 65 | ''' 66 | 67 | pseudo_aligned_out = self.ctc(x) 68 | 69 | x_common = self.proj_x(pseudo_aligned_out) 70 | x_n = x_common.norm(dim=-1, keepdim=True) 71 | x_norm = x_common / torch.max(x_n, self.eps * torch.ones_like(x_n)) 72 | 73 | y_common = self.proj_y(y) 74 | y_n = y_common.norm(dim=-1, keepdim=True) 75 | y_norm = y_common / torch.max(y_n, self.eps * torch.ones_like(y_n)) 76 | 77 | # cosine similarity as logits 78 | logit_scale = self.logit_scale.exp() 79 | similarity_matrix = logit_scale * torch.bmm(y_norm, x_norm.permute(0, 2, 1)) 80 | 81 | logits = similarity_matrix.softmax(dim=-1) 82 | logits = self.fc1(logits) 83 | logits = self.relu(logits) 84 | logits = self.fc2(logits) 85 | logits = self.sigmoid(logits) 86 | 87 | aligned_out = torch.bmm(logits, pseudo_aligned_out) 88 | 89 | return aligned_out 90 | 91 | 92 | 93 | class AlignSubNet(nn.Module): 94 | def __init__(self, args, mode): 95 | """ 96 | mode: the way of aligning 97 | avg_pool, ctc, conv1d 98 | """ 99 | super(AlignSubNet, self).__init__() 100 | 
assert mode in ['avg_pool', 'ctc', 'conv1d', 'sim'] 101 | 102 | in_dim_t, in_dim_v, in_dim_a = args.text_feat_dim, args.video_feat_dim, args.audio_feat_dim 103 | 104 | seq_len_t, seq_len_v, seq_len_a = args.max_cons_seq_length, args.video_seq_len, args.audio_seq_len 105 | self.dst_len = seq_len_t 106 | self.dst_dim = in_dim_t 107 | self.mode = mode 108 | 109 | self.ALIGN_WAY = { 110 | 'avg_pool': self.__avg_pool, 111 | 'ctc': self.__ctc, 112 | 'conv1d': self.__conv1d, 113 | 'sim': self.__sim, 114 | } 115 | 116 | if mode == 'conv1d': 117 | self.conv1d_t = nn.Conv1d(seq_len_t, self.dst_len, kernel_size=1, bias=False) 118 | self.conv1d_v = nn.Conv1d(seq_len_v, self.dst_len, kernel_size=1, bias=False) 119 | self.conv1d_a = nn.Conv1d(seq_len_a, self.dst_len, kernel_size=1, bias=False) 120 | elif mode == 'ctc': 121 | self.ctc_t = CTCModule(in_dim_t, self.dst_len, args) 122 | self.ctc_v = CTCModule(in_dim_v, self.dst_len, args) 123 | self.ctc_a = CTCModule(in_dim_a, self.dst_len, args) 124 | elif mode == 'sim': 125 | self.shared_dim = args.shared_dim 126 | self.sim_t = SimModule(in_dim_t, self.dst_dim, self.shared_dim, self.dst_len, args) 127 | self.sim_v = SimModule(in_dim_v, self.dst_dim, self.shared_dim, self.dst_len, args) 128 | self.sim_a = SimModule(in_dim_a, self.dst_dim, self.shared_dim, self.dst_len, args) 129 | 130 | def get_seq_len(self): 131 | return self.dst_len 132 | 133 | def __ctc(self, text_x, video_x, audio_x): 134 | text_x = self.ctc_t(text_x) if text_x.size(1) != self.dst_len else text_x 135 | video_x = self.ctc_v(video_x) if video_x.size(1) != self.dst_len else video_x 136 | audio_x = self.ctc_a(audio_x) if audio_x.size(1) != self.dst_len else audio_x 137 | return text_x, video_x, audio_x 138 | 139 | def __avg_pool(self, text_x, video_x, audio_x): 140 | def align(x): 141 | raw_seq_len = x.size(1) 142 | if raw_seq_len == self.dst_len: 143 | return x 144 | if raw_seq_len // self.dst_len == raw_seq_len / self.dst_len: 145 | pad_len = 0 146 | pool_size = raw_seq_len // self.dst_len 147 | else: 148 | pad_len = self.dst_len - raw_seq_len % self.dst_len 149 | pool_size = raw_seq_len // self.dst_len + 1 150 | pad_x = x[:, -1, :].unsqueeze(1).expand([x.size(0), pad_len, x.size(-1)]) 151 | x = torch.cat([x, pad_x], dim=1).view(x.size(0), pool_size, self.dst_len, -1) 152 | x = x.mean(dim=1) 153 | return x 154 | text_x = align(text_x) 155 | video_x = align(video_x) 156 | audio_x = align(audio_x) 157 | return text_x, video_x, audio_x 158 | 159 | def __conv1d(self, text_x, video_x, audio_x): 160 | text_x = self.conv1d_t(text_x) if text_x.size(1) != self.dst_len else text_x 161 | video_x = self.conv1d_v(video_x) if video_x.size(1) != self.dst_len else video_x 162 | audio_x = self.conv1d_a(audio_x) if audio_x.size(1) != self.dst_len else audio_x 163 | return text_x, video_x, audio_x 164 | 165 | def __sim(self, text_x, video_x, audio_x): 166 | 167 | text_x = self.sim_t(text_x, text_x) if text_x.size(1) != self.dst_len else text_x 168 | video_x = self.sim_v(video_x, text_x) if video_x.size(1) != self.dst_len else video_x 169 | audio_x = self.sim_a(audio_x, text_x) if audio_x.size(1) != self.dst_len else audio_x 170 | return text_x, video_x, audio_x 171 | 172 | def forward(self, text_x, video_x, audio_x): 173 | # already aligned 174 | if text_x.size(1) == video_x.size(1) and text_x.size(1) == audio_x.size(1): 175 | return text_x, video_x, audio_x 176 | return self.ALIGN_WAY[self.mode](text_x, video_x, audio_x) 177 | 
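# A minimal usage sketch, kept as comments so the module's behavior is unchanged. The args fields and the
# shapes named here are assumptions for illustration and must match the dataset config actually in use.
#
# align_net = AlignSubNet(args, mode = 'ctc')
# # text_x:  (batch, max_cons_seq_length, text_feat_dim)
# # video_x: (batch, video_seq_len, video_feat_dim)
# # audio_x: (batch, audio_seq_len, audio_feat_dim)
# text_x, video_x, audio_x = align_net(text_x, video_x, audio_x)
# # After alignment, all three modalities share the text sequence length (align_net.get_seq_len()).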
-------------------------------------------------------------------------------- /methods/MULT/manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | from torch import nn 5 | from utils.functions import restore_model, save_model, EarlyStopping 6 | from tqdm import trange, tqdm 7 | from utils.metrics import AverageMeter, Metrics, OOD_Metrics, OID_Metrics 8 | from data.utils import get_dataloader 9 | from ood_detection import ood_detection_map 10 | import numpy as np 11 | from torch import optim 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | import pandas as pd 14 | 15 | 16 | class MULT: 17 | 18 | def __init__(self, args, data, model): 19 | 20 | self.logger = logging.getLogger(args.logger_name) 21 | 22 | self.device = model.device 23 | self.model = model._set_model(args) 24 | 25 | self.optimizer = optim.Adam(self.model.parameters(), lr = args.lr) 26 | self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.1, verbose=True, patience=args.wait_patience) 27 | 28 | mm_data = data.data 29 | mm_dataloader = get_dataloader(args, mm_data) 30 | 31 | self.train_dataloader, self.eval_dataloader, self.test_dataloader = \ 32 | mm_dataloader['train'], mm_dataloader['dev'], mm_dataloader['test'] 33 | 34 | self.args = args 35 | self.criterion = nn.CrossEntropyLoss() 36 | self.metrics = Metrics(args) 37 | self.oid_metrics = OID_Metrics(args) 38 | self.ood_metrics = OOD_Metrics(args) 39 | self.ood_detection_func = ood_detection_map[args.ood_detection_method] 40 | 41 | if args.train: 42 | self.best_eval_score = 0 43 | else: 44 | self.model = restore_model(self.model, args.model_output_path, self.device) 45 | 46 | def _train(self, args): 47 | 48 | early_stopping = EarlyStopping(args) 49 | 50 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 51 | self.model.train() 52 | loss_record = AverageMeter() 53 | 54 | for step, batch in enumerate(tqdm(self.train_dataloader, desc="Iteration")): 55 | 56 | text_feats = batch['text_feats'].to(self.device) 57 | video_feats = batch['video_feats'].to(self.device) 58 | audio_feats = batch['audio_feats'].to(self.device) 59 | label_ids = batch['label_ids'].to(self.device) 60 | 61 | with torch.set_grad_enabled(True): 62 | 63 | preds, last_hiddens = self.model(text_feats, video_feats, audio_feats) 64 | 65 | loss = self.criterion(preds, label_ids) 66 | 67 | self.optimizer.zero_grad() 68 | 69 | loss.backward() 70 | loss_record.update(loss.item(), label_ids.size(0)) 71 | 72 | if args.grad_clip != -1.0: 73 | nn.utils.clip_grad_value_([param for param in self.model.parameters() if param.requires_grad], args.grad_clip) 74 | 75 | self.optimizer.step() 76 | 77 | outputs = self._get_outputs(args, mode = 'eval') 78 | self.scheduler.step(outputs['loss']) 79 | eval_score = outputs[args.eval_monitor] 80 | 81 | eval_results = { 82 | 'train_loss': round(loss_record.avg, 4), 83 | 'best_eval_score': round(early_stopping.best_score, 4), 84 | 'eval_score': round(eval_score, 4), 85 | } 86 | 87 | self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1)) 88 | for key in sorted(eval_results.keys()): 89 | self.logger.info(" %s = %s", key, str(eval_results[key])) 90 | 91 | early_stopping(eval_score, self.model) 92 | 93 | if early_stopping.early_stop: 94 | self.logger.info(f'EarlyStopping at epoch {epoch + 1}') 95 | break 96 | 97 | self.best_eval_score = early_stopping.best_score 98 | self.model = early_stopping.best_model 99 | 100 | 
if args.save_model: 101 | self.logger.info('Trained models are saved in %s', args.model_output_path) 102 | save_model(self.model, args.model_output_path) 103 | 104 | def _get_outputs(self, args, mode = 'eval', show_results = False ,test_ind = False): 105 | 106 | if mode == 'eval': 107 | dataloader = self.eval_dataloader 108 | elif mode == 'test': 109 | dataloader = self.test_dataloader 110 | elif mode == 'train': 111 | dataloader = self.train_dataloader 112 | 113 | self.model.eval() 114 | 115 | total_labels = torch.empty(0,dtype=torch.long).to(self.device) 116 | total_preds = torch.empty(0,dtype=torch.long).to(self.device) 117 | total_features = torch.empty((0, self.model.model.combined_dim)).to(self.device) 118 | total_logits = torch.empty((0, args.num_labels)).to(self.device) 119 | 120 | loss_record = AverageMeter() 121 | 122 | for batch in tqdm(dataloader, desc="Iteration"): 123 | 124 | text_feats = batch['text_feats'].to(self.device) 125 | video_feats = batch['video_feats'].to(self.device) 126 | audio_feats = batch['audio_feats'].to(self.device) 127 | label_ids = batch['label_ids'].to(self.device) 128 | 129 | with torch.set_grad_enabled(False): 130 | 131 | logits, last_hiddens = self.model(text_feats, video_feats, audio_feats) 132 | 133 | total_logits = torch.cat((total_logits, logits)) 134 | total_features = torch.cat((total_features, last_hiddens)) 135 | total_labels = torch.cat((total_labels, label_ids)) 136 | 137 | if mode == 'eval': 138 | loss = self.criterion(logits, label_ids) 139 | loss_record.update(loss.item(), label_ids.size(0)) 140 | 141 | total_probs = F.softmax(total_logits.detach(), dim=1) 142 | total_maxprobs, total_preds = total_probs.max(dim = 1) 143 | 144 | y_logit = total_logits.cpu().numpy() 145 | y_pred = total_preds.cpu().numpy() 146 | y_true = total_labels.cpu().numpy() 147 | y_feat = total_features.cpu().numpy() 148 | y_prob = total_maxprobs.cpu().numpy() 149 | 150 | outputs = self.metrics(y_true, y_pred, show_results=show_results) 151 | 152 | if test_ind: 153 | outputs = self.metrics(y_true[y_true != args.ood_label_id], y_pred[y_true != args.ood_label_id]) 154 | else: 155 | outputs = self.metrics(y_true, y_pred, show_results = show_results) 156 | 157 | if mode == 'eval': 158 | outputs.update({'loss': loss_record.avg}) 159 | 160 | outputs.update( 161 | { 162 | 'y_prob': y_prob, 163 | 'y_logit': y_logit, 164 | 'y_true': y_true, 165 | 'y_pred': y_pred, 166 | 'y_feat': y_feat 167 | } 168 | ) 169 | 170 | return outputs 171 | 172 | 173 | def _test(self, args): 174 | 175 | test_results = {} 176 | 177 | ind_test_results = self._get_outputs(args, mode = 'test', show_results = True, test_ind = True) 178 | if args.train: 179 | test_results['best_eval_score'] = round(self.best_eval_score, 4) 180 | test_results.update(ind_test_results) 181 | 182 | if args.ood: 183 | 184 | tmp_outputs = self._get_outputs(args, mode = 'test') 185 | if args.ood_detection_method in ['residual', 'ma', 'vim']: 186 | ind_train_outputs = self._get_outputs(args, mode = 'train') 187 | 188 | tmp_outputs['train_feats'] = ind_train_outputs['y_feat'] 189 | tmp_outputs['train_labels'] = ind_train_outputs['y_true'] 190 | 191 | w, b = self.model.vim() 192 | tmp_outputs['w'] = w 193 | tmp_outputs['b'] = b 194 | 195 | scores = self.ood_detection_func(args, tmp_outputs) 196 | binary_labels = np.array([1 if x != args.ood_label_id else 0 for x in tmp_outputs['y_true']]) 197 | 198 | ood_test_scores = self.ood_metrics(scores, binary_labels, show_results = True) 199 | test_results.update(ood_test_scores) 200 | 
return test_results -------------------------------------------------------------------------------- /methods/SDIF/manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | import numpy as np 5 | from torch import nn 6 | from utils.functions import restore_model, save_model, EarlyStopping 7 | from tqdm import trange, tqdm 8 | from data.utils import get_dataloader 9 | from utils.metrics import AverageMeter, Metrics, OOD_Metrics 10 | from transformers import AdamW, get_linear_schedule_with_warmup 11 | from torch import optim 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from ood_detection import ood_detection_map 14 | 15 | 16 | class SDIF: 17 | 18 | def __init__(self, args, data, model): 19 | 20 | self.logger = logging.getLogger(args.logger_name) 21 | 22 | self.device = model.device 23 | self.model = model._set_model(args) 24 | 25 | mm_data = data.data 26 | mm_dataloader = get_dataloader(args, mm_data) 27 | 28 | self.train_dataloader, self.eval_dataloader, self.test_dataloader = \ 29 | mm_dataloader['train'], mm_dataloader['dev'], mm_dataloader['test'] 30 | 31 | self.args = args 32 | self.criterion = nn.CrossEntropyLoss() 33 | self.metrics = Metrics(args) 34 | self.ood_metrics = OOD_Metrics(args) 35 | self.ood_detection_func = ood_detection_map[args.ood_detection_method] 36 | 37 | if args.train: 38 | self.best_eval_score = 0 39 | else: 40 | self.model = restore_model(self.model, args.model_output_path, self.device) 41 | 42 | def _train(self, args): 43 | 44 | early_stopping = EarlyStopping(args) 45 | self.optimizer = AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 46 | self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=args.factor, verbose=True, patience=args.opt_patience) 47 | 48 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 49 | self.model.train() 50 | loss_record = AverageMeter() 51 | 52 | for step, batch in enumerate(tqdm(self.train_dataloader, desc="Iteration")): 53 | 54 | text_feats = batch['text_feats'].to(self.device) 55 | video_feats = batch['video_feats'].to(self.device) 56 | audio_feats = batch['audio_feats'].to(self.device) 57 | label_ids = batch['label_ids'].to(self.device) 58 | 59 | with torch.set_grad_enabled(True): 60 | 61 | logits = self.model(text_feats, video_feats, audio_feats) 62 | 63 | loss = self.criterion(logits, label_ids) 64 | 65 | self.optimizer.zero_grad() 66 | 67 | loss.backward() 68 | loss_record.update(loss.item(), label_ids.size(0)) 69 | if args.grad_clip != -1.0: 70 | nn.utils.clip_grad_value_([param for param in self.model.parameters() if param.requires_grad], args.grad_clip) 71 | 72 | 73 | self.optimizer.step() 74 | # self.scheduler.step() 75 | 76 | outputs = self._get_outputs(args, self.eval_dataloader) 77 | eval_score = outputs[args.eval_monitor] 78 | self.scheduler.step(outputs['loss']) 79 | 80 | eval_results = { 81 | 'train_loss': round(loss_record.avg, 4), 82 | 'eval_score': round(eval_score, 4), 83 | 'best_eval_score': round(early_stopping.best_score, 4), 84 | } 85 | 86 | self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1)) 87 | for key in eval_results.keys(): 88 | self.logger.info(" %s = %s", key, str(eval_results[key])) 89 | 90 | early_stopping(eval_score, self.model) 91 | 92 | if early_stopping.early_stop: 93 | self.logger.info(f'EarlyStopping at epoch {epoch + 1}') 94 | break 95 | 96 | self.best_eval_score = early_stopping.best_score 97 | 
self.model = early_stopping.best_model 98 | 99 | if args.save_model: 100 | self.logger.info('Trained models are saved in %s', args.model_output_path) 101 | save_model(self.model, args.model_output_path) 102 | 103 | def _get_outputs(self, args, dataloader, show_results = False, test_ind = False, flag_test=False): 104 | 105 | 106 | def capture_activation_output(module, input, output): 107 | global activation_output 108 | activation_output = output 109 | 110 | self.model.eval() 111 | 112 | total_labels = torch.empty(0,dtype=torch.long).to(self.device) 113 | total_preds = torch.empty(0,dtype=torch.long).to(self.device) 114 | total_logits = torch.empty((0, args.num_labels)).to(self.device) 115 | total_features = torch.empty((0, args.dst_feature_dims * 3)).to(self.device) 116 | 117 | loss_record = AverageMeter() 118 | hook_handle = self.model.model.fusion._modules['fusion_layer_1_activation'].register_forward_hook(capture_activation_output) 119 | 120 | for batch in tqdm(dataloader, desc="Iteration"): 121 | 122 | text_feats = batch['text_feats'].to(self.device) 123 | video_feats = batch['video_feats'].to(self.device) 124 | audio_feats = batch['audio_feats'].to(self.device) 125 | label_ids = batch['label_ids'].to(self.device) 126 | 127 | with torch.set_grad_enabled(False): 128 | 129 | logits = self.model(text_feats, video_feats, audio_feats) 130 | 131 | total_logits = torch.cat((total_logits, logits)) 132 | total_labels = torch.cat((total_labels, label_ids)) 133 | total_features = torch.cat((total_features, activation_output)) 134 | if not flag_test: 135 | loss = self.criterion(logits, label_ids) 136 | loss_record.update(loss.item(), label_ids.size(0)) 137 | 138 | hook_handle.remove()  # detach the capture hook so repeated _get_outputs calls do not stack hooks 139 | total_probs = F.softmax(total_logits.detach(), dim=1) 140 | total_maxprobs, total_preds = total_probs.max(dim = 1) 141 | 142 | y_logit = total_logits.cpu().numpy() 143 | y_pred = total_preds.cpu().numpy() 144 | y_true = total_labels.cpu().numpy() 145 | y_prob = total_maxprobs.cpu().numpy() 146 | y_feat = total_features.cpu().numpy() 147 | 148 | if test_ind: 149 | outputs = self.metrics(y_true[y_true != args.ood_label_id], y_pred[y_true != args.ood_label_id]) 150 | else: 151 | outputs = self.metrics(y_true, y_pred, show_results = show_results) 152 | 153 | outputs.update({'loss': loss_record.avg}) 154 | outputs.update( 155 | { 156 | 'y_prob': y_prob, 157 | 'y_logit': y_logit, 158 | 'y_true': y_true, 159 | 'y_pred': y_pred, 160 | 'y_feat': y_feat 161 | } 162 | ) 163 | 164 | return outputs 165 | 166 | def _test(self, args): 167 | 168 | test_results = {} 169 | 170 | ind_test_results = self._get_outputs(args, self.test_dataloader, show_results = True, test_ind = True, flag_test=True) 171 | if args.train: 172 | test_results['best_eval_score'] = round(self.best_eval_score, 4) 173 | test_results.update(ind_test_results) 174 | 175 | if args.ood: 176 | 177 | tmp_outputs = self._get_outputs(args, self.test_dataloader, flag_test=True) 178 | if args.ood_detection_method in ['residual', 'ma', 'vim']: 179 | ind_train_outputs = self._get_outputs(args, self.train_dataloader) 180 | 181 | tmp_outputs['train_feats'] = ind_train_outputs['y_feat'] 182 | tmp_outputs['train_labels'] = ind_train_outputs['y_true'] 183 | 184 | w, b = self.model.vim() 185 | tmp_outputs['w'] = w 186 | tmp_outputs['b'] = b 187 | 188 | scores = self.ood_detection_func(args, tmp_outputs) 189 | binary_labels = np.array([1 if x != args.ood_label_id else 0 for x in tmp_outputs['y_true']]) 190 | 191 | ood_test_scores = self.ood_metrics(scores, binary_labels, show_results = True) 192 | 
test_results.update(ood_test_scores) 193 | 194 | return test_results -------------------------------------------------------------------------------- /methods/MAG_BERT/manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | import numpy as np 5 | from torch import nn 6 | from utils.functions import restore_model, save_model, EarlyStopping 7 | from tqdm import trange, tqdm 8 | from data.utils import get_dataloader 9 | from utils.metrics import AverageMeter, Metrics, OOD_Metrics 10 | from transformers import AdamW, get_linear_schedule_with_warmup 11 | from ood_detection import ood_detection_map 12 | 13 | 14 | class MAG_BERT: 15 | 16 | def __init__(self, args, data, model): 17 | 18 | self.logger = logging.getLogger(args.logger_name) 19 | 20 | self.device = model.device 21 | self.model = model._set_model(args) 22 | self.optimizer, self.scheduler = self._set_optimizer(args, self.model) 23 | 24 | mm_data = data.data 25 | mm_dataloader = get_dataloader(args, mm_data) 26 | 27 | self.train_dataloader, self.eval_dataloader, self.test_dataloader = \ 28 | mm_dataloader['train'], mm_dataloader['dev'], mm_dataloader['test'] 29 | 30 | self.args = args 31 | self.criterion = nn.CrossEntropyLoss() 32 | self.metrics = Metrics(args) 33 | self.ood_metrics = OOD_Metrics(args) 34 | self.ood_detection_func = ood_detection_map[args.ood_detection_method] 35 | 36 | if args.train: 37 | self.best_eval_score = 0 38 | else: 39 | self.model = restore_model(self.model, args.model_output_path, self.device) 40 | 41 | def _set_optimizer(self, args, model): 42 | 43 | param_optimizer = list(model.named_parameters()) 44 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 45 | optimizer_grouped_parameters = [ 46 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 47 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 48 | ] 49 | 50 | optimizer = AdamW(optimizer_grouped_parameters, lr = args.lr, correct_bias=False) 51 | 52 | num_train_optimization_steps = int(args.num_train_examples / args.train_batch_size) * args.num_train_epochs 53 | num_warmup_steps= int(args.num_train_examples * args.num_train_epochs * args.warmup_proportion / args.train_batch_size) 54 | 55 | scheduler = get_linear_schedule_with_warmup(optimizer, 56 | num_warmup_steps=num_warmup_steps, 57 | num_training_steps=num_train_optimization_steps) 58 | 59 | return optimizer, scheduler 60 | 61 | def _train(self, args): 62 | 63 | early_stopping = EarlyStopping(args) 64 | 65 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 66 | self.model.train() 67 | loss_record = AverageMeter() 68 | 69 | for step, batch in enumerate(tqdm(self.train_dataloader, desc="Iteration")): 70 | 71 | text_feats = batch['text_feats'].to(self.device) 72 | video_feats = batch['video_feats'].to(self.device) 73 | audio_feats = batch['audio_feats'].to(self.device) 74 | label_ids = batch['label_ids'].to(self.device) 75 | 76 | with torch.set_grad_enabled(True): 77 | 78 | outputs = self.model(text_feats, video_feats, audio_feats) 79 | 80 | loss = self.criterion(outputs['mm'], label_ids) 81 | 82 | self.optimizer.zero_grad() 83 | 84 | loss.backward() 85 | loss_record.update(loss.item(), label_ids.size(0)) 86 | 87 | self.optimizer.step() 88 | self.scheduler.step() 89 | 90 | outputs = self._get_outputs(args, self.eval_dataloader) 91 | eval_score = 
outputs[args.eval_monitor] 92 | 93 | eval_results = { 94 | 'train_loss': round(loss_record.avg, 4), 95 | 'eval_score': round(eval_score, 4), 96 | 'best_eval_score': round(early_stopping.best_score, 4), 97 | } 98 | 99 | self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1)) 100 | for key in eval_results.keys(): 101 | self.logger.info(" %s = %s", key, str(eval_results[key])) 102 | 103 | early_stopping(eval_score, self.model) 104 | 105 | if early_stopping.early_stop: 106 | self.logger.info(f'EarlyStopping at epoch {epoch + 1}') 107 | break 108 | 109 | self.best_eval_score = early_stopping.best_score 110 | self.model = early_stopping.best_model 111 | 112 | if args.save_model: 113 | self.logger.info('Trained models are saved in %s', args.model_output_path) 114 | save_model(self.model, args.model_output_path) 115 | 116 | def _get_outputs(self, args, dataloader, show_results = False, test_ind = False): 117 | 118 | self.model.eval() 119 | 120 | total_labels = torch.empty(0,dtype=torch.long).to(self.device) 121 | total_preds = torch.empty(0,dtype=torch.long).to(self.device) 122 | total_logits = torch.empty((0, args.num_labels)).to(self.device) 123 | total_features = torch.empty((0, args.feat_size)).to(self.device) 124 | 125 | loss_record = AverageMeter() 126 | 127 | for batch in tqdm(dataloader, desc="Iteration"): 128 | 129 | text_feats = batch['text_feats'].to(self.device) 130 | video_feats = batch['video_feats'].to(self.device) 131 | audio_feats = batch['audio_feats'].to(self.device) 132 | label_ids = batch['label_ids'].to(self.device) 133 | 134 | with torch.set_grad_enabled(False): 135 | 136 | outputs = self.model(text_feats, video_feats, audio_feats) 137 | logits, features = outputs['mm'], outputs['h'] 138 | 139 | total_logits = torch.cat((total_logits, logits)) 140 | total_labels = torch.cat((total_labels, label_ids)) 141 | total_features = torch.cat((total_features, features)) 142 | 143 | 144 | total_probs = F.softmax(total_logits.detach(), dim=1) 145 | total_maxprobs, total_preds = total_probs.max(dim = 1) 146 | 147 | y_logit = total_logits.cpu().numpy() 148 | y_pred = total_preds.cpu().numpy() 149 | y_true = total_labels.cpu().numpy() 150 | y_prob = total_maxprobs.cpu().numpy() 151 | y_feat = total_features.cpu().numpy() 152 | 153 | if test_ind: 154 | outputs = self.metrics(y_true[y_true != args.ood_label_id], y_pred[y_true != args.ood_label_id]) 155 | else: 156 | outputs = self.metrics(y_true, y_pred, show_results = show_results) 157 | 158 | outputs.update( 159 | { 160 | 'y_prob': y_prob, 161 | 'y_logit': y_logit, 162 | 'y_true': y_true, 163 | 'y_pred': y_pred, 164 | 'y_feat': y_feat 165 | } 166 | ) 167 | 168 | return outputs 169 | 170 | def _test(self, args): 171 | 172 | test_results = {} 173 | 174 | ind_test_results = self._get_outputs(args, self.test_dataloader, show_results = True, test_ind = True) 175 | if args.train: 176 | test_results['best_eval_score'] = round(self.best_eval_score, 4) 177 | test_results.update(ind_test_results) 178 | 179 | if args.ood: 180 | 181 | tmp_outputs = self._get_outputs(args, self.test_dataloader) 182 | if args.ood_detection_method in ['residual', 'ma', 'vim']: 183 | ind_train_outputs = self._get_outputs(args, self.train_dataloader) 184 | 185 | tmp_outputs['train_feats'] = ind_train_outputs['y_feat'] 186 | tmp_outputs['train_labels'] = ind_train_outputs['y_true'] 187 | 188 | w, b = self.model.vim() 189 | tmp_outputs['w'] = w 190 | tmp_outputs['b'] = b 191 | 192 | scores = self.ood_detection_func(args, tmp_outputs) 193 | binary_labels = 
np.array([1 if x != args.ood_label_id else 0 for x in tmp_outputs['y_true']]) 194 | 195 | ood_test_scores = self.ood_metrics(scores, binary_labels, show_results = True) 196 | test_results.update(ood_test_scores) 197 | 198 | return test_results -------------------------------------------------------------------------------- /methods/Spectra/manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import logging 4 | import numpy as np 5 | from torch import nn 6 | from utils.functions import restore_model, save_model, EarlyStopping 7 | from tqdm import trange, tqdm 8 | from data.utils import get_dataloader 9 | from utils.metrics import AverageMeter, Metrics, OOD_Metrics 10 | from transformers import AdamW, get_linear_schedule_with_warmup 11 | from ood_detection import ood_detection_map 12 | 13 | 14 | class Spectra: 15 | 16 | def __init__(self, args, data, model): 17 | 18 | self.logger = logging.getLogger(args.logger_name) 19 | 20 | self.device = model.device 21 | self.model = model._set_model(args) 22 | self.optimizer, self.scheduler = self._set_optimizer(args, self.model) 23 | 24 | mm_data = data.data 25 | mm_dataloader = get_dataloader(args, mm_data) 26 | 27 | self.train_dataloader, self.eval_dataloader, self.test_dataloader = \ 28 | mm_dataloader['train'], mm_dataloader['dev'], mm_dataloader['test'] 29 | 30 | self.args = args 31 | self.criterion = nn.CrossEntropyLoss() 32 | self.metrics = Metrics(args) 33 | self.ood_metrics = OOD_Metrics(args) 34 | self.ood_detection_func = ood_detection_map[args.ood_detection_method] 35 | 36 | if args.train: 37 | self.best_eval_score = 0 38 | else: 39 | self.model = restore_model(self.model, args.model_output_path, self.device) 40 | 41 | def _set_optimizer(self, args, model): 42 | 43 | param_optimizer = list(model.named_parameters()) 44 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 45 | optimizer_grouped_parameters = [ 46 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 47 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 48 | ] 49 | 50 | optimizer = AdamW(optimizer_grouped_parameters, lr = args.lr, correct_bias=False) 51 | 52 | num_train_optimization_steps = int(args.num_train_examples / args.train_batch_size) * args.num_train_epochs 53 | num_warmup_steps= int(args.num_train_examples * args.num_train_epochs * args.warmup_proportion / args.train_batch_size) 54 | 55 | scheduler = get_linear_schedule_with_warmup(optimizer, 56 | num_warmup_steps=num_warmup_steps, 57 | num_training_steps=num_train_optimization_steps) 58 | 59 | return optimizer, scheduler 60 | 61 | def _train(self, args): 62 | 63 | early_stopping = EarlyStopping(args) 64 | 65 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 66 | self.model.train() 67 | loss_record = AverageMeter() 68 | 69 | for step, batch in enumerate(tqdm(self.train_dataloader, desc="Iteration")): 70 | 71 | text_feats = batch['text_feats'].to(self.device) 72 | video_feats = batch['video_feats'] 73 | audio_feats = batch['audio_feats'].to(self.device).squeeze(2) 74 | 75 | label_ids = batch['label_ids'].to(self.device) 76 | 77 | with torch.set_grad_enabled(True): 78 | 79 | logits, feats = self.model(text_feats, video_feats, audio_feats) 80 | 81 | loss = self.criterion(logits, label_ids) 82 | 83 | self.optimizer.zero_grad() 84 | 85 | loss.backward() 86 | loss_record.update(loss.item(), 
label_ids.size(0)) 87 | 88 | self.optimizer.step() 89 | self.scheduler.step() 90 | 91 | 92 | outputs = self._get_outputs(args, self.eval_dataloader) 93 | eval_score = outputs[args.eval_monitor] 94 | 95 | eval_results = { 96 | 'train_loss': round(loss_record.avg, 4), 97 | 'eval_score': round(eval_score, 4), 98 | 'best_eval_score': round(early_stopping.best_score, 4), 99 | } 100 | 101 | self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1)) 102 | for key in eval_results.keys(): 103 | self.logger.info(" %s = %s", key, str(eval_results[key])) 104 | 105 | early_stopping(eval_score, self.model) 106 | 107 | if early_stopping.early_stop: 108 | self.logger.info(f'EarlyStopping at epoch {epoch + 1}') 109 | break 110 | 111 | self.best_eval_score = early_stopping.best_score 112 | self.model = early_stopping.best_model 113 | 114 | if args.save_model: 115 | self.logger.info('Trained models are saved in %s', args.model_output_path) 116 | save_model(self.model, args.model_output_path) 117 | 118 | def _get_outputs(self, args, dataloader, show_results = False, test_ind = False): 119 | 120 | self.model.eval() 121 | 122 | args.feat_size = 768 123 | total_labels = torch.empty(0,dtype=torch.long).to(self.device) 124 | total_preds = torch.empty(0,dtype=torch.long).to(self.device) 125 | total_logits = torch.empty((0, args.num_labels)).to(self.device) 126 | total_features = torch.empty((0, args.feat_size)).to(self.device) 127 | 128 | loss_record = AverageMeter() 129 | 130 | for batch in tqdm(dataloader, desc="Iteration"): 131 | 132 | text_feats = batch['text_feats'].to(self.device) 133 | video_feats = batch['video_feats'] 134 | audio_feats = batch['audio_feats'].to(self.device).squeeze(2) 135 | label_ids = batch['label_ids'].to(self.device) 136 | 137 | with torch.set_grad_enabled(False): 138 | 139 | logits, features = self.model(text_feats, video_feats, audio_feats) 140 | 141 | total_logits = torch.cat((total_logits, logits)) 142 | total_labels = torch.cat((total_labels, label_ids)) 143 | total_features = torch.cat((total_features, features)) 144 | 145 | 146 | total_probs = F.softmax(total_logits.detach(), dim=1) 147 | total_maxprobs, total_preds = total_probs.max(dim = 1) 148 | 149 | y_logit = total_logits.cpu().numpy() 150 | y_pred = total_preds.cpu().numpy() 151 | y_true = total_labels.cpu().numpy() 152 | y_prob = total_maxprobs.cpu().numpy() 153 | y_feat = total_features.cpu().numpy() 154 | 155 | if test_ind: 156 | outputs = self.metrics(y_true[y_true != args.ood_label_id], y_pred[y_true != args.ood_label_id]) 157 | else: 158 | outputs = self.metrics(y_true, y_pred, show_results = show_results) 159 | 160 | outputs.update( 161 | { 162 | 'y_prob': y_prob, 163 | 'y_logit': y_logit, 164 | 'y_true': y_true, 165 | 'y_pred': y_pred, 166 | 'y_feat': y_feat 167 | } 168 | ) 169 | 170 | return outputs 171 | 172 | def _test(self, args): 173 | 174 | test_results = {} 175 | 176 | ind_test_results = self._get_outputs(args, self.test_dataloader, show_results = True, test_ind = True) 177 | if args.train: 178 | test_results['best_eval_score'] = round(self.best_eval_score, 4) 179 | test_results.update(ind_test_results) 180 | 181 | if args.ood: 182 | 183 | tmp_outputs = self._get_outputs(args, self.test_dataloader) 184 | if args.ood_detection_method in ['residual', 'ma', 'vim']: 185 | ind_train_outputs = self._get_outputs(args, self.train_dataloader) 186 | 187 | tmp_outputs['train_feats'] = ind_train_outputs['y_feat'] 188 | tmp_outputs['train_labels'] = ind_train_outputs['y_true'] 189 | 190 | w, b = 
self.model.vim() 191 | tmp_outputs['w'] = w 192 | tmp_outputs['b'] = b 193 | 194 | scores = self.ood_detection_func(args, tmp_outputs) 195 | binary_labels = np.array([1 if x != args.ood_label_id else 0 for x in tmp_outputs['y_true']]) 196 | 197 | ood_test_scores = self.ood_metrics(scores, binary_labels, show_results = True) 198 | test_results.update(ood_test_scores) 199 | 200 | return test_results --------------------------------------------------------------------------------
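
Note (illustrative, not part of the repository): every manager above ends its _test method the same way. It collects y_logit / y_prob / y_feat from _get_outputs, adds train_feats, train_labels and the (w, b) pair returned by self.model.vim() when a distance-based detector is selected, and hands the dictionary to the function looked up in ood_detection_map[args.ood_detection_method]. The repository's own ma.py, residual.py and vim.py hold the authoritative scorers; the sketch below only illustrates the func(args, inputs) contract the managers rely on, using a Mahalanobis-style scorer as a hypothetical stand-in, and every name beyond the dictionary keys already shown above is an assumption.

import numpy as np
from scipy.linalg import pinv


def func(args, inputs):

    # Keys below are exactly the ones populated by the managers' _test methods.
    train_feats = inputs['train_feats']    # IND training features (y_feat of the train split)
    train_labels = inputs['train_labels']  # IND training labels (y_true of the train split)
    test_feats = inputs['y_feat']          # features of the samples to score

    classes = np.unique(train_labels)

    # Class-conditional means and a shared (tied) covariance estimated on IND data.
    means = np.stack([train_feats[train_labels == c].mean(axis=0) for c in classes])
    centered = np.concatenate([train_feats[train_labels == c] - means[i]
                               for i, c in enumerate(classes)])
    precision = pinv(np.cov(centered, rowvar=False))

    # Negative minimum squared Mahalanobis distance to any class mean:
    # larger scores mean closer to the IND training distribution.
    diffs = test_feats[:, None, :] - means[None, :, :]            # (N, C, D)
    dists = np.einsum('ncd,de,nce->nc', diffs, precision, diffs)  # (N, C)
    scores = -dists.min(axis=1)

    return scores

Higher scores correspond to "more in-distribution", which matches the binary_labels convention (1 for IND, 0 for OOD) passed to OOD_Metrics in each _test method above.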