├── src
│   ├── utils.py
│   ├── eval_metrics.py
│   ├── models.py
│   ├── dataset.py
│   ├── train_urfunny.py
│   └── train.py
├── README.md
├── modules
│   ├── position_embedding.py
│   ├── multihead_attention.py
│   ├── fast_attention.py
│   └── sp_transformer.py
├── humor_dataloader.py
└── run.py
/src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from src.dataset import Multimodal_Datasets 4 | 5 | 6 | def get_data(args, dataset, split='train'): 7 | alignment = 'a' if args.aligned else 'na' 8 | data_path = os.path.join(args.data_path, dataset) + f'_{split}_{alignment}.dt' 9 | if not os.path.exists(data_path): 10 | print(f" - Creating new {split} data") 11 | data = Multimodal_Datasets(args.data_path, dataset, split, args.aligned) 12 | torch.save(data, data_path) 13 | else: 14 | print(f" - Found cached {split} data") 15 | data = torch.load(data_path) 16 | return data 17 | 18 | 19 | def save_load_name(args, name=''): 20 | if args.aligned: 21 | name = name if len(name) > 0 else 'aligned_model' 22 | elif not args.aligned: 23 | name = name if len(name) > 0 else 'nonaligned_model' 24 | 25 | return name + '_' + args.model 26 | 27 | 28 | def save_model(args, model, name=''): 29 | name = save_load_name(args, name) 30 | torch.save(model, f'pre_trained_models/{name}.pt') 31 | 32 | 33 | def load_model(args, name=''): 34 | name = save_load_name(args, name) 35 | model = torch.load(f'pre_trained_models/{name}.pt') 36 | return model 37 | 38 | def countparams(module): 39 | params = list(module.parameters()) 40 | return sum(p.numel() for p in params if p.requires_grad) 41 | -------------------------------------------------------------------------------- /src/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import accuracy_score, f1_score 4 | 5 | 6 | def eval_mosei_senti(results, truths, exclude_zero=False): 7 | test_preds = results.view(-1).cpu().detach().numpy() 8 | test_truth = truths.view(-1).cpu().detach().numpy() 9 | 10 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0 or (not exclude_zero)]) 11 | 12 | f_score = f1_score((test_preds[non_zeros] > 0), (test_truth[non_zeros] > 0), average='weighted') 13 | binary_truth = (test_truth[non_zeros] > 0) 14 | binary_preds = (test_preds[non_zeros] > 0) 15 | 16 | loginfo='' 17 | loginfo+="F1 score: "+str(f_score)+'\n' 18 | loginfo+="Accuracy: "+str(accuracy_score(binary_truth, binary_preds))+'\n' 19 | loginfo+="-" * 50+'\n' 20 | return f_score,loginfo 21 | 22 | def eval_mosi(results, truths, exclude_zero=False): 23 | return eval_mosei_senti(results, truths, exclude_zero) 24 | 25 | def eval_ur_funny(results, truths, exclude_zero=False): 26 | test_preds = results.view(-1).cpu().detach().numpy() 27 | test_truth = truths.view(-1).cpu().detach().numpy() 28 | 29 | binary_truth = (test_truth > 0.5) 30 | binary_preds = (test_preds > 0) 31 | accu=accuracy_score(binary_truth, binary_preds) 32 | 33 | loginfo='' 34 | loginfo+="Accuracy: "+str(accu)+'\n' 35 | loginfo+="-" * 50+'\n' 36 | return accu,loginfo 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse Phased Transformer (SPT) 2 | 3 | This is the official repo for the EMNLP 2021 paper "[Multimodal Phased Transformer for Sentiment Analysis](https://aclanthology.org/2021.emnlp-main.189.pdf)" 4 | 5 | Preprocessed 
MOSI and MOSEI datasets (from the MulT repo) can be downloaded from https://github.com/yaohungt/Multimodal-Transformer 6 | The UR-FUNNY dataset can be downloaded from https://github.com/ROC-HCI/UR-FUNNY 7 | 8 | Use run.py to run the model, and use Optuna to search hyper-parameters. 9 | 10 | 11 | ### Citation 12 | If you use this code in your research, please cite the following paper: 13 | 14 | 15 | ``` bibtex 16 | @inproceedings{cheng-etal-2021-multimodal, 17 | title = "Multimodal Phased Transformer for Sentiment Analysis", 18 | author = "Cheng, Junyan and 19 | Fostiropoulos, Iordanis and 20 | Boehm, Barry and 21 | Soleymani, Mohammad", 22 | editor = "Moens, Marie-Francine and 23 | Huang, Xuanjing and 24 | Specia, Lucia and 25 | Yih, Scott Wen-tau", 26 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing", 27 | month = nov, 28 | year = "2021", 29 | address = "Online and Punta Cana, Dominican Republic", 30 | publisher = "Association for Computational Linguistics", 31 | url = "https://aclanthology.org/2021.emnlp-main.189/", 32 | doi = "10.18653/v1/2021.emnlp-main.189", 33 | pages = "2447--2458", 34 | abstract = "Multimodal Transformers achieve superior performance in multimodal learning tasks. However, the quadratic complexity of the self-attention mechanism in Transformers limits their deployment in low-resource devices and makes their inference and training computationally expensive. We propose multimodal Sparse Phased Transformer (SPT) to alleviate the problem of self-attention complexity and memory footprint. SPT uses a sampling function to generate a sparse attention matrix and compress a long sequence to a shorter sequence of hidden states. SPT concurrently captures interactions between the hidden states of different modalities at every layer. To further improve the efficiency of our method, we use Layer-wise parameter sharing and Factorized Co-Attention that share parameters between Cross Attention Blocks, with minimal impact on task performance. We evaluate our model with three sentiment analysis datasets and achieve comparable or superior performance compared with the existing methods, with a 90{\%} reduction in the number of parameters. We conclude that (SPT) along with parameter sharing can capture multimodal interactions with reduced model size and improved sample efficiency." 35 | } 36 | ``` 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/yaohungt/Multimodal-Transformer/blob/master/src/models.py 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from modules.sp_transformer import SPEncoder 7 | 8 | 9 | class Superformer(nn.Module): 10 | def __init__(self, hyp_params): 11 | """ 12 | Construct a Superformer model. 
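        Added for clarity: `hyp_params` is expected to provide d_model, orig_d_l / orig_d_a / orig_d_v
        (input feature sizes for text / audio / vision), num_heads, layers, the dropout rates read below,
        S, r, shift_mode, use_fast, use_dense, output_dim and device. forward(t, a, v) takes
        [batch, seq_len, dim] tensors for the three modalities and returns (output, last_hs).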
13 | """ 14 | super(Superformer, self).__init__() 15 | self.d_model = hyp_params.d_model 16 | self.input_dims = dict(t=hyp_params.orig_d_l, 17 | a=hyp_params.orig_d_a, 18 | v=hyp_params.orig_d_v) 19 | self.num_heads = hyp_params.num_heads 20 | self.layers = hyp_params.layers 21 | self.attn_dropout = hyp_params.attn_dropout 22 | self.relu_dropout = hyp_params.relu_dropout 23 | self.res_dropout = hyp_params.res_dropout 24 | self.out_dropout = hyp_params.out_dropout 25 | self.embed_dropout = hyp_params.embed_dropout 26 | self.S, self.r = hyp_params.S, hyp_params.r 27 | self.shift_mode=hyp_params.shift_mode 28 | self.use_fast=hyp_params.use_fast 29 | self.use_dense=hyp_params.use_dense 30 | combined_dim = 3*self.d_model 31 | 32 | self.spe=SPEncoder(embed_dim=self.d_model, 33 | input_dims=self.input_dims, 34 | num_heads=self.num_heads, 35 | layers=self.layers, 36 | attn_dropout=self.attn_dropout, 37 | relu_dropout=self.relu_dropout, 38 | res_dropout=self.res_dropout, 39 | embed_dropout=self.embed_dropout, 40 | S=self.S, r=self.r, 41 | shift_mode=self.shift_mode, 42 | use_fast=self.use_fast, 43 | use_dense=self.use_dense, 44 | device=hyp_params.device) 45 | 46 | # Projection layers 47 | self.proj1 = nn.Linear(combined_dim, combined_dim) 48 | self.proj2 = nn.Linear(combined_dim, combined_dim) 49 | self.out_layer = nn.Linear(combined_dim, hyp_params.output_dim) 50 | 51 | 52 | def forward(self, t, a, v): #[BS,SL,D] 53 | h_a, h_t, h_v=self.spe(a,t,v) 54 | last_hs = torch.cat([h_t[-1], h_a[-1], h_v[-1]], dim=1) 55 | # last_hs = torch.cat([torch.mean(h_t,0), torch.mean(h_a,0), torch.mean(h_v,0)], dim=1) 56 | # last_hs = self.spe(a,t,v)[-1] 57 | 58 | # A residual block 59 | last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training)) 60 | last_hs_proj += last_hs 61 | 62 | output = self.out_layer(last_hs_proj) 63 | return output, last_hs -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.dataset import Dataset 3 | import pickle 4 | import os 5 | from scipy import signal 6 | import torch 7 | 8 | if torch.cuda.is_available(): 9 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 10 | else: 11 | torch.set_default_tensor_type('torch.FloatTensor') 12 | 13 | ############################################################################################ 14 | # This file provides basic processing script for the multimodal datasets we use. For other 15 | # datasets, small modifications may be needed (depending on the type of the data, etc.) 
16 | ############################################################################################ 17 | 18 | 19 | class Multimodal_Datasets(Dataset): 20 | def __init__(self, dataset_path, data='mosei_senti', split_type='train', if_align=False): 21 | super(Multimodal_Datasets, self).__init__() 22 | dataset_path = os.path.join(dataset_path, data+'_data.pkl' if if_align else data+'_data_noalign.pkl' ) 23 | dataset = pickle.load(open(dataset_path, 'rb')) 24 | 25 | # These are torch tensors 26 | self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach() 27 | self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach() 28 | self.audio = dataset[split_type]['audio'].astype(np.float32) 29 | self.audio[self.audio == -np.inf] = 0 30 | self.audio = torch.tensor(self.audio).cpu().detach() 31 | self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach() 32 | 33 | # Note: this is STILL an numpy array 34 | self.meta = dataset[split_type]['id'] if 'id' in dataset[split_type].keys() else None 35 | 36 | self.data = data 37 | 38 | self.n_modalities = 3 # vision/ text/ audio 39 | def get_n_modalities(self): 40 | return self.n_modalities 41 | def get_seq_len(self): 42 | return self.text.shape[1], self.audio.shape[1], self.vision.shape[1] 43 | def get_dim(self): 44 | return self.text.shape[2], self.audio.shape[2], self.vision.shape[2] 45 | def get_lbl_info(self): 46 | # return number_of_labels, label_dim 47 | return self.labels.shape[1], self.labels.shape[2] 48 | def __len__(self): 49 | return len(self.labels) 50 | def __getitem__(self, index): 51 | X = (index, self.text[index], self.audio[index], self.vision[index]) 52 | Y = self.labels[index] 53 | META = (0,0,0) if self.meta is None else (self.meta[index][0], self.meta[index][1], self.meta[index][2]) 54 | if self.data == 'mosi': 55 | META = (self.meta[index][0].decode('UTF-8'), self.meta[index][1].decode('UTF-8'), self.meta[index][2].decode('UTF-8')) 56 | if self.data == 'iemocap': 57 | Y = torch.argmax(Y, dim=-1) 58 | return X, Y, META 59 | 60 | -------------------------------------------------------------------------------- /modules/position_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Code adapted from the fairseq repo. 7 | 8 | def make_positions(tensor, padding_idx, left_pad): 9 | """Replace non-padding symbols with their position numbers. 10 | Position numbers begin at padding_idx+1. 11 | Padding symbols are ignored, but it is necessary to specify whether padding 12 | is added on the left side (left_pad=True) or right side (left_pad=False). 
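    For illustration, with padding_idx=0 and left_pad=False the input [[7, 7, 0]]
    is mapped to positions [[1, 2, 0]].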
13 | """ 14 | max_pos = padding_idx + 1 + tensor.size(1) 15 | device = tensor.get_device() 16 | buf_name = f'range_buf_{device}' 17 | if not hasattr(make_positions, buf_name): 18 | setattr(make_positions, buf_name, tensor.new()) 19 | setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor)) 20 | if getattr(make_positions, buf_name).numel() < max_pos: 21 | torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name)) 22 | mask = tensor.ne(padding_idx) 23 | positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor) 24 | if left_pad: 25 | positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1) 26 | new_tensor = tensor.clone() 27 | return new_tensor.masked_scatter_(mask, positions[mask]).long() 28 | 29 | 30 | class SinusoidalPositionalEmbedding(nn.Module): 31 | """This module produces sinusoidal positional embeddings of any length. 32 | Padding symbols are ignored, but it is necessary to specify whether padding 33 | is added on the left side (left_pad=True) or right side (left_pad=False). 34 | """ 35 | 36 | def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128): 37 | super().__init__() 38 | self.embedding_dim = embedding_dim 39 | self.padding_idx = padding_idx 40 | self.left_pad = left_pad 41 | self.weights = dict() # device --> actual weight; due to nn.DataParallel :-( 42 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 43 | 44 | @staticmethod 45 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 46 | """Build sinusoidal embeddings. 47 | This matches the implementation in tensor2tensor, but differs slightly 48 | from the description in Section 3.5 of "Attention Is All You Need". 49 | """ 50 | half_dim = embedding_dim // 2 51 | emb = math.log(10000) / (half_dim - 1) 52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 53 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 55 | if embedding_dim % 2 == 1: 56 | # zero pad 57 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 58 | if padding_idx is not None: 59 | emb[padding_idx, :] = 0 60 | return emb 61 | 62 | def forward(self, input, dim=None): 63 | """Input is expected to be of size [bsz x seqlen].""" 64 | bsz, seq_len = input.size() 65 | max_pos = self.padding_idx + 1 + seq_len 66 | device = input.get_device() 67 | # if device not in self.weights or max_pos > self.weights[device].size(0): 68 | # # recompute/expand embeddings if needed 69 | self.weights[device] = SinusoidalPositionalEmbedding.get_embedding( 70 | max_pos, 71 | self.embedding_dim if dim is None else dim, 72 | self.padding_idx, 73 | ) 74 | self.weights[device] = self.weights[device].type_as(self._float_tensor) 75 | positions = make_positions(input, self.padding_idx, self.left_pad) 76 | return self.weights[device].index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 77 | 78 | def max_positions(self): 79 | """Maximum number of supported positions.""" 80 | return int(1e5) # an arbitrary large number -------------------------------------------------------------------------------- /modules/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | # Code adapted from the fairseq repo. 
8 | 9 | class MultiheadAttention(nn.Module): 10 | """Multi-headed attention. 11 | See "Attention Is All You Need" for more details. 12 | """ 13 | 14 | def __init__(self, embed_dim, num_heads, attn_dropout=0., input_dim=None): 15 | super().__init__() 16 | self.embed_dim = embed_dim 17 | self.input_dim = embed_dim if input_dim is None else input_dim 18 | self.num_heads = num_heads 19 | self.attn_dropout = attn_dropout 20 | self.head_dim = embed_dim // num_heads 21 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 22 | self.scaling = self.head_dim ** -0.5 23 | 24 | self.in_proj_weight = Parameter(torch.Tensor(embed_dim+2*self.input_dim, embed_dim)) 25 | self.in_proj_bias = Parameter(torch.Tensor(embed_dim*3)) 26 | self.out_proj = nn.Linear(embed_dim, embed_dim) 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.xavier_uniform_(self.in_proj_weight) 32 | nn.init.xavier_uniform_(self.out_proj.weight) 33 | nn.init.constant_(self.in_proj_bias, 0.) 34 | nn.init.constant_(self.out_proj.bias, 0.) 35 | 36 | def forward(self, query, key, value, attn_mask=None, reverse=False): 37 | """Input shape: Time x Batch x Channel 38 | Self-attention can be implemented by passing in the same arguments for 39 | query, key and value. Timesteps can be masked by supplying a T x T mask in the 40 | `attn_mask` argument. Padding elements can be excluded from 41 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 42 | batch x src_len, where padding elements are indicated by 1s. 43 | """ 44 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 45 | kv_same = key.data_ptr() == value.data_ptr() 46 | 47 | tgt_len, bsz, embed_dim = query.size() 48 | assert embed_dim == self.embed_dim 49 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 50 | assert key.size() == value.size() 51 | 52 | if reverse: 53 | q = self.in_proj_k(query) 54 | k = self.in_proj_q(key) 55 | v = self.in_proj_v(key) 56 | else: 57 | q = self.in_proj_q(query) 58 | k, v = self.in_proj_kv(key) 59 | q = q * self.scaling 60 | 61 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 62 | if k is not None: 63 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 64 | if v is not None: 65 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 66 | 67 | src_len = k.size(1) 68 | 69 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 70 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 71 | 72 | if attn_mask is not None: 73 | try: 74 | attn_weights += attn_mask.unsqueeze(0) 75 | except: 76 | print(attn_weights.shape) 77 | print(attn_mask.unsqueeze(0).shape) 78 | assert False 79 | 80 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) 81 | # attn_weights = F.relu(attn_weights) 82 | # attn_weights = attn_weights / torch.max(attn_weights) 83 | attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training) 84 | 85 | attn = torch.bmm(attn_weights, v) 86 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 87 | 88 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 89 | attn = self.out_proj(attn) 90 | 91 | # average attention weights over heads 92 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 93 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 94 | return attn, attn_weights 95 | 96 | def in_proj_kv(self, key): 97 
| return self._in_proj(key, start=self.embed_dim, startb=self.embed_dim).chunk(2, dim=-1) 98 | 99 | def in_proj_q(self, query, **kwargs): 100 | return self._in_proj(query, end=self.embed_dim, endb=self.embed_dim, **kwargs) 101 | 102 | def in_proj_k(self, key): 103 | return self._in_proj(key, start=self.embed_dim, end=self.embed_dim+self.input_dim, 104 | startb=self.embed_dim, endb=self.embed_dim*2) 105 | 106 | def in_proj_v(self, key): 107 | return self._in_proj(key, start=self.embed_dim+self.input_dim, end=self.embed_dim+self.input_dim*2, 108 | startb=self.embed_dim*2, endb=self.embed_dim*3) 109 | 110 | def _in_proj(self, input, start=0, end=None, startb=0, endb=None, **kwargs): 111 | weight = kwargs.get('weight', self.in_proj_weight) 112 | bias = kwargs.get('bias', self.in_proj_bias) 113 | weight = weight[start:end, :] 114 | bias = bias[startb:endb] 115 | if bias.shape[0]!=weight.shape[0]: weight=weight.reshape(-1,self.embed_dim*2).T # KV 116 | return F.linear(input, weight, bias) 117 | -------------------------------------------------------------------------------- /humor_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import pickle 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import numpy as np 11 | import torch.nn as nn 12 | 13 | def load_pickle(pickle_file): 14 | try: 15 | with open(pickle_file, 'rb') as f: 16 | pickle_data = pickle.load(f) 17 | except UnicodeDecodeError as e: 18 | with open(pickle_file, 'rb') as f: 19 | pickle_data = pickle.load(f, encoding='latin1') 20 | except Exception as e: 21 | print('Unable to load data ', pickle_file, ':', e) 22 | raise 23 | return pickle_data 24 | 25 | 26 | # In[2]: 27 | 28 | 29 | ''' 30 | You can assign the maximum number of sentences in the context and the maximum number of words in any sentence. 31 | 32 | It will do left padding. It will concatenate the word embedding + covarep features + openface features. 33 | 34 | example: 35 | 36 | if max_sen_len = 20 then the punchline sentence dimension = 20 * 456. 37 | where 456 = word embedding (300) + covarep (81) + openface (75) 38 | 39 | if max_sen_len = 20 and max_context_len = 5 that means the context can have a maximum of 5 sentences 40 | and each sentence will have a maximum of 20 words. The context dimension will be 5 * 20 * 456. 41 | 42 | We will do left padding with zeros to maintain the same dimension. 
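As a rough guide to the shapes returned by HumorDataset.__getitem__ below: x_c (context) is
[max_context_len, max_sen_len, 456], x_p (punchline) is [max_sen_len, 456], and y is a single
binary humor label.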
43 | 44 | In our experiments we set max_sen_len = 20 & max_context_len = 5 45 | ''' 46 | 47 | 48 | class HumorDataset(Dataset): 49 | 50 | def __init__(self, id_list,path,max_context_len=5,max_sen_len=20): 51 | self.id_list = id_list 52 | openface_file=path+"openface_features_sdk.pkl" 53 | covarep_file=path+"covarep_features_sdk.pkl" 54 | word_idx_file=path+"word_embedding_indexes_sdk.pkl" 55 | word_embedding_list_file=path+"word_embedding_list.pkl" 56 | humor_label_file=path+"humor_label_sdk.pkl" 57 | 58 | self.word_aligned_openface_sdk=load_pickle(openface_file) 59 | self.word_aligned_covarep_sdk=load_pickle(covarep_file) 60 | self.word_embedding_idx_sdk=load_pickle(word_idx_file) 61 | self.word_embedding_list_sdk=load_pickle(word_embedding_list_file) 62 | self.humor_label_sdk = load_pickle(humor_label_file) 63 | self.of_d=75 64 | self.cvp_d=81 65 | self.max_context_len=max_context_len 66 | self.max_sen_len=max_sen_len 67 | 68 | #left padding with zero vector upto maximum number of words in a sentence * glove embedding dimension 69 | def paded_word_idx(self,seq,max_sen_len=20,left_pad=1): 70 | seq=seq[0:max_sen_len] 71 | pad_w=np.concatenate((np.zeros(max_sen_len-len(seq)),seq),axis=0) 72 | pad_w=np.array([self.word_embedding_list_sdk[int(w_id)] for w_id in pad_w]) 73 | return pad_w 74 | 75 | #left padding with zero vector upto maximum number of words in a sentence * covarep dimension 76 | def padded_covarep_features(self,seq,max_sen_len=20,left_pad=1): 77 | seq=seq[0:max_sen_len] 78 | return np.concatenate((np.zeros((max_sen_len-len(seq),self.cvp_d)),seq),axis=0) 79 | 80 | #left padding with zero vector upto maximum number of words in a sentence * openface dimension 81 | def padded_openface_features(self,seq,max_sen_len=20,left_pad=1): 82 | seq=seq[0:max_sen_len] 83 | return np.concatenate((np.zeros(((max_sen_len-len(seq)),self.of_d)),seq),axis=0) 84 | 85 | #left padding with zero vectors upto maximum number of sentences in context * maximum num of words in a sentence * 456 86 | def padded_context_features(self,context_w,context_of,context_cvp,max_context_len=5,max_sen_len=20): 87 | context_w=context_w[-max_context_len:] 88 | context_of=context_of[-max_context_len:] 89 | context_cvp=context_cvp[-max_context_len:] 90 | 91 | padded_context=[] 92 | for i in range(len(context_w)): 93 | p_seq_w=self.paded_word_idx(context_w[i],max_sen_len) 94 | p_seq_cvp=self.padded_covarep_features(context_cvp[i],max_sen_len) 95 | p_seq_of=self. 
padded_openface_features(context_of[i],max_sen_len) 96 | padded_context.append(np.concatenate((p_seq_w,p_seq_cvp,p_seq_of),axis=1)) 97 | 98 | pad_c_len=max_context_len-len(padded_context) 99 | padded_context=np.array(padded_context) 100 | 101 | #if there is no context 102 | if not padded_context.any(): 103 | return np.zeros((max_context_len,max_sen_len,456)) 104 | 105 | return np.concatenate((np.zeros((pad_c_len,max_sen_len,456)),padded_context),axis=0) 106 | 107 | def padded_punchline_features(self,punchline_w,punchline_of,punchline_cvp,max_sen_len=20,left_pad=1): 108 | 109 | p_seq_w=self.paded_word_idx(punchline_w,max_sen_len) 110 | p_seq_cvp=self.padded_covarep_features(punchline_cvp,max_sen_len) 111 | p_seq_of=self.padded_openface_features(punchline_of,max_sen_len) 112 | return np.concatenate((p_seq_w,p_seq_cvp,p_seq_of),axis=1) 113 | 114 | 115 | def __len__(self): 116 | return len(self.id_list) 117 | 118 | def __getitem__(self,index): 119 | 120 | hid=self.id_list[index] 121 | punchline_w=np.array(self.word_embedding_idx_sdk[hid]['punchline_embedding_indexes']) 122 | punchline_of=np.array(self.word_aligned_openface_sdk[hid]['punchline_features']) 123 | punchline_cvp=np.array(self.word_aligned_covarep_sdk[hid]['punchline_features']) 124 | 125 | context_w=np.array(self.word_embedding_idx_sdk[hid]['context_embedding_indexes']) 126 | context_of=np.array(self.word_aligned_openface_sdk[hid]['context_features']) 127 | context_cvp=np.array(self.word_aligned_covarep_sdk[hid]['context_features']) 128 | 129 | #punchline feature 130 | x_p=torch.FloatTensor(self.padded_punchline_features(punchline_w,punchline_of,punchline_cvp,self.max_sen_len)) 131 | #context feature 132 | x_c=torch.FloatTensor(self.padded_context_features(context_w,context_of,context_cvp,self.max_context_len,self.max_sen_len)) 133 | 134 | y=torch.FloatTensor([self.humor_label_sdk[hid]]) 135 | return x_c, x_p,y 136 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/yaohungt/Multimodal-Transformer 2 | import torch 3 | import argparse 4 | from src.utils import * 5 | from torch.utils.data import DataLoader 6 | from src import train,train_urfunny 7 | from humor_dataloader import HumorDataset,load_pickle 8 | 9 | 10 | parser = argparse.ArgumentParser(description='Multimodal Sentiment Analysis') 11 | parser.add_argument('-f', default='', type=str) 12 | 13 | # Fixed 14 | parser.add_argument('--model', type=str, default='Superformer', 15 | help='name of the model to use (Transformer, etc.)') 16 | parser.add_argument('--device', type=str, default='cpu', help='device') 17 | 18 | # Tasks 19 | parser.add_argument('--aligned', action='store_true', 20 | help='consider aligned experiment or not (default: False)') 21 | parser.add_argument('--dataset', type=str, default='mosei_senti', 22 | help='dataset to use (default: mosei_senti)') 23 | parser.add_argument('--data_path', type=str, default='data', 24 | help='path for storing the dataset') 25 | 26 | # Dropouts 27 | parser.add_argument('--attn_dropout', type=float, default=0.2, 28 | help='attention dropout') 29 | parser.add_argument('--relu_dropout', type=float, default=0.1, 30 | help='relu dropout') 31 | parser.add_argument('--embed_dropout', type=float, default=0.3, 32 | help='embedding dropout') 33 | parser.add_argument('--res_dropout', type=float, default=0.1, 34 | help='residual block dropout') 35 | 
parser.add_argument('--out_dropout', type=float, default=0.1, 36 | help='output layer dropout') 37 | 38 | 39 | # Architecture 40 | parser.add_argument('--layers', type=int, default=4, 41 | help='number of layers in the network (default: 5)') 42 | parser.add_argument('--num_heads', type=int, default=8, 43 | help='number of heads for the transformer network (default: 5)') 44 | parser.add_argument('--d_model', type=int, default=32, 45 | help='d_model') 46 | parser.add_argument('--S', type=float, default=5, help='S') 47 | parser.add_argument('--r', type=list, default=[8,4,3], help='r') 48 | parser.add_argument('--shift_mode', type=dict, 49 | default=dict(I=['S,P,R'],X=['S'],S=['S'],C=[1,0.25,0.05]), 50 | help='shift mode') 51 | parser.add_argument('--use_fast', type=bool, default=False, help='use fast attention') 52 | parser.add_argument('--use_dense', type=bool, default=False, help='use dense attention') 53 | 54 | 55 | # Tuning 56 | parser.add_argument('--batch_size', type=int, default=250, metavar='N', 57 | help='batch size (default: 24)') 58 | parser.add_argument('--clip', type=float, default=1.0, 59 | help='gradient clip value (default: 0.8)') 60 | parser.add_argument('--lr', type=float, default=5e-4, 61 | help='initial learning rate (default: 1e-3)') 62 | parser.add_argument('--optim', type=str, default='Adam', 63 | help='optimizer to use (default: Adam)') 64 | parser.add_argument('--num_epochs', type=int, default=50, 65 | help='number of epochs (default: 40)') 66 | parser.add_argument('--when', type=int, default=20, 67 | help='when to decay learning rate (default: 20)') 68 | parser.add_argument('--batch_chunk', type=int, default=1, 69 | help='number of chunks per batch (default: 1)') 70 | 71 | # Logistics 72 | parser.add_argument('--log_interval', type=int, default=30, 73 | help='frequency of result logging (default: 30)') 74 | parser.add_argument('--seed', type=int, default=777, 75 | help='random seed') 76 | parser.add_argument('--no_cuda', action='store_true', 77 | help='do not use cuda') 78 | parser.add_argument('--name', type=str, default='test2', 79 | help='name of the trial (default: "mult")') 80 | args = parser.parse_args() 81 | 82 | 83 | if not os.path.exists('./pre_trained_models'): os.makedirs('./pre_trained_models') 84 | args.data_path='../datasets/Archive' 85 | args.dataset='mosei_senti' 86 | args.aligned=False 87 | 88 | 89 | torch.manual_seed(args.seed) 90 | dataset = str.lower(args.dataset.strip()) 91 | 92 | use_cuda = False 93 | torch.set_default_tensor_type('torch.FloatTensor') 94 | if torch.cuda.is_available(): 95 | if args.no_cuda: 96 | print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda") 97 | else: 98 | torch.cuda.manual_seed(args.seed) 99 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 100 | use_cuda = True 101 | 102 | #################### 103 | # # 104 | # Load the dataset # 105 | # # 106 | #################### 107 | 108 | print("Start loading the data....") 109 | 110 | if args.dataset=='urfunny': 111 | path='../datasets/UR-FUNNY/' 112 | data_folds_file= path+"data_folds.pkl" 113 | data_folds= load_pickle(data_folds_file) 114 | 115 | max_context_len=8 116 | max_sen_len=40 117 | train_set = HumorDataset(data_folds['train'],path,max_context_len,max_sen_len) 118 | dev_set = HumorDataset(data_folds['dev'],path,max_context_len,max_sen_len) 119 | test_set = HumorDataset(data_folds['test'],path,max_context_len,max_sen_len) 120 | else: 121 | train_set = get_data(args, dataset, 'train') 122 | dev_set = get_data(args, dataset, 
'valid') 123 | test_set = get_data(args, dataset, 'test') 124 | 125 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) 126 | dev_loader = DataLoader(dev_set, batch_size=args.batch_size, shuffle=True) 127 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True) 128 | 129 | print('Finish loading the data....') 130 | 131 | ################### 132 | # # 133 | # Hyperparameters # 134 | # # 135 | ################### 136 | 137 | hyp_params = args 138 | if args.dataset=='urfunny': hyp_params.orig_d_l, hyp_params.orig_d_a, hyp_params.orig_d_v = 300, 81, 75 139 | else: 140 | hyp_params.orig_d_l, hyp_params.orig_d_a, hyp_params.orig_d_v = train_set.get_dim() 141 | hyp_params.l_len, hyp_params.a_len, hyp_params.v_len = train_set.get_seq_len() 142 | hyp_params.use_cuda = use_cuda 143 | hyp_params.dataset = dataset 144 | hyp_params.when = args.when 145 | hyp_params.batch_chunk = args.batch_chunk 146 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_set), len(dev_set), len(test_set) 147 | hyp_params.model = str.upper(args.model.strip()) 148 | hyp_params.output_dim = 1 149 | hyp_params.criterion = 'L1Loss' 150 | 151 | if args.dataset=='urfunny': trainer=train_urfunny.initiate 152 | else: trainer=train.initiate 153 | 154 | if __name__ == '__main__': 155 | metric = trainer(hyp_params, train_loader, dev_loader, test_loader, test_only=False) 156 | -------------------------------------------------------------------------------- /modules/fast_attention.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.cuda.amp import autocast 7 | from einops import rearrange, repeat 8 | 9 | from functools import partial 10 | 11 | 12 | # helpers 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def empty(tensor): 18 | return tensor.numel() == 0 19 | 20 | def default(val, d): 21 | return val if exists(val) else d 22 | 23 | 24 | # kernel functions 25 | 26 | # transcribed from jax to pytorch from 27 | # https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py 28 | 29 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): 30 | b, h, *_ = data.shape 31 | 32 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
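    # Added note: data is [batch, heads, seq, dim_head]; the random projection below maps it to
    # positive features whose inner products approximate softmax attention (Performer / FAVOR+).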
33 | 34 | ratio = (projection_matrix.shape[0] ** -0.5) 35 | 36 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 37 | projection = projection.type_as(data) 38 | 39 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 40 | 41 | diag_data = data ** 2 42 | diag_data = torch.sum(diag_data, dim=-1) 43 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 44 | diag_data = diag_data.unsqueeze(dim=-1) 45 | 46 | if is_query: 47 | data_dash = ratio * ( 48 | torch.exp(data_dash - diag_data - 49 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps) 50 | else: 51 | data_dash = ratio * ( 52 | torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps) 53 | 54 | return data_dash.type_as(data) 55 | 56 | def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): 57 | b, h, *_ = data.shape 58 | 59 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 60 | 61 | if projection_matrix is None: 62 | return kernel_fn(data_normalizer * data) + kernel_epsilon 63 | 64 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 65 | projection = projection.type_as(data) 66 | 67 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 68 | 69 | data_prime = kernel_fn(data_dash) + kernel_epsilon 70 | return data_prime.type_as(data) 71 | 72 | def orthogonal_matrix_chunk(cols, device = None): 73 | unstructured_block = torch.randn((cols, cols), device = device) 74 | q, r = torch.qr(unstructured_block.cpu(), some = True) 75 | q, r = map(lambda t: t.to(device), (q, r)) 76 | return q.t() 77 | 78 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): 79 | nb_full_blocks = int(nb_rows / nb_columns) 80 | 81 | block_list = [] 82 | 83 | for _ in range(nb_full_blocks): 84 | q = orthogonal_matrix_chunk(nb_columns, device = device) 85 | block_list.append(q) 86 | 87 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 88 | if remaining_rows > 0: 89 | q = orthogonal_matrix_chunk(nb_columns, device = device) 90 | block_list.append(q[:remaining_rows]) 91 | 92 | final_matrix = torch.cat(block_list) 93 | 94 | if scaling == 0: 95 | multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) 96 | elif scaling == 1: 97 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) 98 | else: 99 | raise ValueError(f'Invalid scaling {scaling}') 100 | 101 | if multiplier.is_cuda: final_matrix=final_matrix.cuda() 102 | return torch.diag(multiplier) @ final_matrix 103 | 104 | 105 | def linear_attention(q, k, v): 106 | k_cumsum = k.sum(dim = -2) 107 | D_inv = 1. 
/ torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) 108 | context = torch.einsum('...nd,...ne->...de', k, v) 109 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 110 | return out 111 | 112 | class FastAttention(nn.Module): 113 | def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, generalized_attention = False, kernel_fn = nn.ReLU()): 114 | super().__init__() 115 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 116 | 117 | self.dim_heads = dim_heads 118 | self.nb_features = nb_features 119 | self.ortho_scaling = ortho_scaling 120 | 121 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) 122 | projection_matrix = self.create_projection() 123 | self.register_buffer('projection_matrix', projection_matrix) 124 | 125 | self.generalized_attention = generalized_attention 126 | self.kernel_fn = kernel_fn 127 | 128 | 129 | def forward(self, q, k, v): 130 | device = q.device 131 | 132 | if self.generalized_attention: 133 | create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) 134 | q, k = map(create_kernel, (q, k)) 135 | 136 | else: 137 | create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) 138 | q = create_kernel(q, is_query = True) 139 | k = create_kernel(k, is_query = False) 140 | 141 | attn_fn = linear_attention 142 | out = attn_fn(q, k, v) 143 | return out 144 | 145 | 146 | class Attention(nn.Module): 147 | def __init__( 148 | self, 149 | dim, 150 | heads = 8, 151 | dim_head = 64, 152 | nb_features = None, 153 | generalized_attention = False, 154 | kernel_fn = nn.ReLU(), 155 | dropout = 0., 156 | qkv_bias = False, 157 | attn_out_bias = True 158 | ): 159 | super().__init__() 160 | assert dim % heads == 0, 'dimension must be divisible by number of heads' 161 | dim_head = default(dim_head, dim // heads) 162 | inner_dim = dim_head * heads 163 | self.fast_attention = FastAttention(dim_head, nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn) 164 | 165 | self.heads = heads 166 | self.global_heads = heads 167 | 168 | self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) 169 | self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) 170 | self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) 171 | self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias) 172 | self.dropout = nn.Dropout(dropout) 173 | 174 | def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, reverse=False, **kwargs): 175 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 176 | 177 | cross_attend = exists(context) 178 | 179 | context = default(context, x) 180 | context_mask = default(context_mask, mask) if not cross_attend else context_mask 181 | 182 | if not reverse: 183 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 184 | else: 185 | q, k, v = self.to_k(x), self.to_q(context), self.to_v(context) 186 | 187 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 188 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 189 | 190 | attn_outs = [] 191 | 192 | if not empty(q): 193 | if exists(context_mask): 194 | global_mask = context_mask[:, None, :, None] 195 | v.masked_fill_(~global_mask, 0.) 
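            # Added note: linear_attention above contracts k with v before involving q, so this
            # attention step scales linearly (not quadratically) with sequence length.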
196 | 197 | out = self.fast_attention(q, k, v) 198 | attn_outs.append(out) 199 | 200 | out = torch.cat(attn_outs, dim = 1) 201 | out = rearrange(out, 'b h n d -> b n (h d)') 202 | out = self.to_out(out) 203 | return self.dropout(out) 204 | 205 | -------------------------------------------------------------------------------- /src/train_urfunny.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import sys 4 | from src import models 5 | from src.utils import * 6 | import torch.optim as optim 7 | import numpy as np 8 | import time 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | import os 11 | import pickle 12 | 13 | from src.eval_metrics import * 14 | 15 | 16 | ####################### 17 | # 18 | # Construct the model 19 | # 20 | ####################### 21 | 22 | text_dim, audio_dim,video_dim = 300, 81, 75 23 | def crop(a): return a[:,:,:, 0:text_dim],a[:, :, :, text_dim:(text_dim + audio_dim)], a[:, :, :, (text_dim+audio_dim)::] 24 | 25 | def initiate(hyp_params, train_loader, valid_loader, test_loader,test_only=False,verbose=True,onlypunch=False): 26 | model = getattr(models, 'SPModel')(hyp_params) 27 | paramnum=sum(p.numel() for p in list(model.parameters()) if p.requires_grad) 28 | if verbose: print('Param num',paramnum) 29 | with open('./pre_trained_models/'+hyp_params.name+'_log.txt','w') as f: 30 | f.write('Param num: '+str(paramnum)+'\n') 31 | 32 | if hyp_params.use_cuda: 33 | model = model.cuda() 34 | 35 | optimizer = getattr(optim, hyp_params.optim)(model.parameters(), lr=hyp_params.lr) 36 | 37 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=hyp_params.when, factor=0.1, verbose=True) 38 | settings = {'model': model, 39 | 'optimizer': optimizer, 40 | 'scheduler': scheduler} 41 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader,test_only,verbose,onlypunch) 42 | 43 | 44 | #################################################################### 45 | # 46 | # Training and evaluation scripts 47 | # 48 | #################################################################### 49 | 50 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader,test_only=False,verbose=True,onlypunch=False): 51 | model = settings['model'] 52 | optimizer = settings['optimizer'] 53 | criterion = nn.BCEWithLogitsLoss() 54 | 55 | scheduler = settings['scheduler'] 56 | 57 | log='' 58 | 59 | def train(model, optimizer, criterion, verbose=True,onlypunch=False): 60 | epoch_loss = 0 61 | model.train() 62 | num_batches = hyp_params.n_train // hyp_params.batch_size 63 | proc_loss, proc_size = 0, 0 64 | start_time = time.time() 65 | for i_batch, (x_c, x_p, eval_attr) in enumerate(train_loader): 66 | x_p = torch.unsqueeze(x_p, dim=1) 67 | combined = x_p if onlypunch else torch.cat([x_c, x_p], dim=1) 68 | t,a,v = crop(combined) 69 | bs,segs,lens,_=a.shape 70 | audio,text,vision=a.reshape(bs,segs*lens,-1),t.reshape(bs,segs*lens,-1),v.reshape(bs,segs*lens,-1) 71 | 72 | model.zero_grad() 73 | 74 | if hyp_params.use_cuda: 75 | with torch.cuda.device(0): 76 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 77 | 78 | batch_size = text.size(0) 79 | batch_chunk = hyp_params.batch_chunk 80 | 81 | combined_loss = 0 82 | net = nn.DataParallel(model) if batch_size > 10 else model 83 | if batch_chunk > 1: 84 | raw_loss = combined_loss = 0 85 | text_chunks = text.chunk(batch_chunk, dim=0) 86 | audio_chunks = audio.chunk(batch_chunk, dim=0) 
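                # Added note: with batch_chunk > 1 the batch is split into chunks and gradients are
                # accumulated (each chunk loss is divided by batch_chunk) before one optimizer step.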
87 | vision_chunks = vision.chunk(batch_chunk, dim=0) 88 | eval_attr_chunks = eval_attr.chunk(batch_chunk, dim=0) 89 | 90 | for i in range(len(text_chunks)): 91 | text_i, audio_i, vision_i = text_chunks[i], audio_chunks[i], vision_chunks[i] 92 | eval_attr_i = eval_attr_chunks[i] 93 | preds_i, hiddens_i = net(text_i, audio_i, vision_i) 94 | raw_loss_i = criterion(preds_i, eval_attr_i) / batch_chunk 95 | raw_loss += raw_loss_i 96 | raw_loss_i.backward() 97 | combined_loss = raw_loss 98 | else: 99 | preds, hiddens = net(text, audio, vision) 100 | raw_loss = criterion(preds, eval_attr) 101 | combined_loss = raw_loss 102 | combined_loss.backward() 103 | 104 | torch.nn.utils.clip_grad_norm_(model.parameters(), hyp_params.clip) 105 | optimizer.step() 106 | 107 | proc_loss += raw_loss.item() * batch_size 108 | proc_size += batch_size 109 | epoch_loss += combined_loss.item() * batch_size 110 | if i_batch % hyp_params.log_interval == 0 and i_batch > 0: 111 | avg_loss = proc_loss / proc_size 112 | elapsed_time = time.time() - start_time 113 | loginfo='Epoch {:2d} | Batch {:3d}/{:3d} | Time/Batch(ms) {:5.2f} | Train Loss {:5.4f}'.format( 114 | epoch, i_batch, num_batches, elapsed_time * 1000 / hyp_params.log_interval, avg_loss) 115 | if verbose: print(loginfo) 116 | proc_loss, proc_size = 0, 0 117 | start_time = time.time() 118 | 119 | return epoch_loss / hyp_params.n_train 120 | 121 | def evaluate(model, criterion, test=False,onlypunch=False): 122 | model.eval() 123 | loader = test_loader if test else valid_loader 124 | total_loss = 0.0 125 | 126 | results = [] 127 | truths = [] 128 | 129 | with torch.no_grad(): 130 | for i_batch, (x_c, x_p, eval_attr) in enumerate(loader): 131 | x_p = torch.unsqueeze(x_p, dim=1) 132 | combined = x_p if onlypunch else torch.cat([x_c, x_p], dim=1) 133 | t,a,v = crop(combined) 134 | bs,segs,lens,_=a.shape 135 | audio,text,vision=a.reshape(bs,segs*lens,-1),t.reshape(bs,segs*lens,-1),v.reshape(bs,segs*lens,-1) 136 | 137 | if hyp_params.use_cuda: 138 | with torch.cuda.device(0): 139 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 140 | 141 | batch_size = text.size(0) 142 | 143 | net = nn.DataParallel(model) if batch_size > 10 else model 144 | preds, _ = net(text, audio, vision) 145 | total_loss += criterion(preds, eval_attr).item() * batch_size 146 | 147 | # Collect the results into dictionary 148 | results.append(preds) 149 | truths.append(eval_attr) 150 | 151 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid) 152 | 153 | results = torch.cat(results) 154 | truths = torch.cat(truths) 155 | return avg_loss, results, truths 156 | 157 | if not test_only: 158 | best_valid = 1e8 159 | for epoch in range(1, hyp_params.num_epochs+1): 160 | start = time.time() 161 | train_loss=train(model, optimizer, criterion, verbose,onlypunch=onlypunch) 162 | val_loss, _, _ = evaluate(model, criterion, test=False,onlypunch=onlypunch) 163 | test_loss, _, _ = evaluate(model, criterion, test=True,onlypunch=onlypunch) 164 | 165 | end = time.time() 166 | duration = end-start 167 | scheduler.step(val_loss) # Decay learning rate by validation loss 168 | 169 | loginfo="-"*50+'\n' 170 | loginfo='Epoch {:2d} | Time {:5.4f} sec | Train Loss {:5.4f} | Valid Loss {:5.4f} | Test Loss {:5.4f}\n'.format( 171 | epoch, duration, train_loss, val_loss, test_loss) 172 | loginfo+="-"*50 173 | if verbose: print(loginfo) 174 | log+=loginfo+'\n' 175 | 176 | if val_loss < best_valid: 177 | loginfo=f"Saved model at 
pre_trained_models/{hyp_params.name}.pt!" 178 | if verbose: print(loginfo) 179 | log+=loginfo+'\n' 180 | save_model(hyp_params, model, name=hyp_params.name) 181 | best_valid = val_loss 182 | 183 | model = load_model(hyp_params, name=hyp_params.name) 184 | _, results, truths = evaluate(model, criterion, test=True, onlypunch=onlypunch) 185 | metric,loginfo=eval_ur_funny(results, truths) 186 | if verbose: print(loginfo) 187 | log+=loginfo 188 | 189 | with open('./pre_trained_models/'+hyp_params.name+'_log.txt','w') as f: f.write(log) 190 | return metric 191 | 192 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import sys 4 | from src import models 5 | from src.utils import * 6 | import torch.optim as optim 7 | import numpy as np 8 | import time 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | import os 11 | import pickle 12 | 13 | from src.eval_metrics import * 14 | 15 | 16 | ###################### 17 | # 18 | # Construct the model 19 | # 20 | ###################### 21 | 22 | 23 | def initiate(hyp_params, train_loader, valid_loader, test_loader,test_only=False,verbose=True,ratio=1): 24 | model = getattr(models, 'Superformer')(hyp_params) 25 | paramnum=sum(p.numel() for p in list(model.parameters()) if p.requires_grad) 26 | if verbose: print('Param num',paramnum) 27 | with open('./pre_trained_models/'+hyp_params.name+'_log.txt','w') as f: 28 | f.write('Param num: '+str(paramnum)+'\n') 29 | 30 | if hyp_params.use_cuda: 31 | model = model.cuda() 32 | 33 | optimizer = getattr(optim, hyp_params.optim)(model.parameters(), lr=hyp_params.lr) 34 | criterion = getattr(nn, hyp_params.criterion)() 35 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=hyp_params.when, factor=0.1, verbose=True) 36 | settings = {'model': model, 37 | 'optimizer': optimizer, 38 | 'criterion': criterion, 39 | 'scheduler': scheduler} 40 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader,test_only,verbose,ratio) 41 | 42 | 43 | #################################################################### 44 | # 45 | # Training and evaluation scripts 46 | # 47 | #################################################################### 48 | 49 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader,test_only=False,verbose=True,ratio=1): 50 | model = settings['model'] 51 | optimizer = settings['optimizer'] 52 | criterion = settings['criterion'] 53 | scheduler = settings['scheduler'] 54 | 55 | log='' 56 | 57 | def train(model, optimizer, criterion, verbose=True,ratio=1): 58 | epoch_loss = 0 59 | model.train() 60 | num_batches = hyp_params.n_train // hyp_params.batch_size 61 | proc_loss, proc_size = 0, 0 62 | start_time = time.time() 63 | for i_batch, (batch_X, batch_Y, batch_META) in enumerate(train_loader): 64 | if i_batch>=int(np.ceil(len(train_loader)*ratio)): break 65 | sample_ind, text, audio, vision = batch_X 66 | eval_attr = batch_Y.squeeze(-1) # if num of labels is 1 67 | 68 | model.zero_grad() 69 | if hyp_params.use_cuda: 70 | with torch.cuda.device(0): 71 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 72 | if hyp_params.dataset == 'iemocap': 73 | eval_attr = eval_attr.long() 74 | 75 | batch_size = text.size(0) 76 | batch_chunk = hyp_params.batch_chunk 77 | 78 | loss = 0 79 | net = nn.DataParallel(model) if batch_size > 10 else model 80 
| if batch_chunk > 1: 81 | loss = 0 82 | text_chunks = text.chunk(batch_chunk, dim=0) 83 | audio_chunks = audio.chunk(batch_chunk, dim=0) 84 | vision_chunks = vision.chunk(batch_chunk, dim=0) 85 | eval_attr_chunks = eval_attr.chunk(batch_chunk, dim=0) 86 | 87 | for i in range(len(text_chunks)): 88 | text_i, audio_i, vision_i = text_chunks[i], audio_chunks[i], vision_chunks[i] 89 | eval_attr_i = eval_attr_chunks[i] 90 | preds_i, hiddens_i = net(text_i, audio_i, vision_i) 91 | 92 | if hyp_params.dataset == 'iemocap': 93 | preds_i = preds_i.view(-1, 2) 94 | eval_attr_i = eval_attr_i.view(-1) 95 | loss_i = criterion(preds_i, eval_attr_i) / batch_chunk 96 | loss += loss_i 97 | loss_i.backward() 98 | else: 99 | preds, hiddens = net(text, audio, vision) 100 | if hyp_params.dataset == 'iemocap': 101 | preds = preds.view(-1, 2) 102 | eval_attr = eval_attr.view(-1) 103 | loss = criterion(preds, eval_attr) 104 | loss.backward() 105 | 106 | torch.nn.utils.clip_grad_norm_(model.parameters(), hyp_params.clip) 107 | optimizer.step() 108 | 109 | proc_loss += loss.item() * batch_size 110 | proc_size += batch_size 111 | epoch_loss += loss.item() * batch_size 112 | if i_batch % hyp_params.log_interval == 0 and i_batch > 0: 113 | avg_loss = proc_loss / proc_size 114 | elapsed_time = time.time() - start_time 115 | loginfo='Epoch {:2d} | Batch {:3d}/{:3d} | Time/Batch(ms) {:5.2f} | Train Loss {:5.4f}'.format( 116 | epoch, i_batch, num_batches, elapsed_time * 1000 / hyp_params.log_interval, avg_loss) 117 | if verbose: print(loginfo) 118 | proc_loss, proc_size = 0, 0 119 | start_time = time.time() 120 | 121 | return epoch_loss / hyp_params.n_train 122 | 123 | def evaluate(model, criterion, test=False): 124 | model.eval() 125 | loader = test_loader if test else valid_loader 126 | total_loss = 0.0 127 | 128 | results = [] 129 | truths = [] 130 | 131 | with torch.no_grad(): 132 | for i_batch, (batch_X, batch_Y, batch_META) in enumerate(loader): 133 | sample_ind, text, audio, vision = batch_X 134 | eval_attr = batch_Y.squeeze(dim=-1) # if num of labels is 1 135 | 136 | if hyp_params.use_cuda: 137 | with torch.cuda.device(0): 138 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 139 | if hyp_params.dataset == 'iemocap': 140 | eval_attr = eval_attr.long() 141 | 142 | batch_size = text.size(0) 143 | 144 | net = nn.DataParallel(model) if batch_size > 10 else model 145 | preds, _ = net(text, audio, vision) 146 | if hyp_params.dataset == 'iemocap': 147 | preds = preds.view(-1, 2) 148 | eval_attr = eval_attr.view(-1) 149 | total_loss += criterion(preds, eval_attr).item() * batch_size 150 | 151 | # Collect the results into dictionary 152 | results.append(preds) 153 | truths.append(eval_attr) 154 | 155 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid) 156 | 157 | results = torch.cat(results) 158 | truths = torch.cat(truths) 159 | return avg_loss, results, truths 160 | 161 | if not test_only: 162 | best_valid = 1e8 163 | for epoch in range(1, hyp_params.num_epochs+1): 164 | start = time.time() 165 | train_loss=train(model, optimizer, criterion, verbose,ratio) 166 | val_loss, _, _ = evaluate(model, criterion, test=False) 167 | test_loss, _, _ = evaluate(model, criterion, test=True) 168 | 169 | end = time.time() 170 | duration = end-start 171 | scheduler.step(val_loss) # Decay learning rate by validation loss 172 | 173 | loginfo="-"*50+'\n' 174 | loginfo='Epoch {:2d} | Time {:5.4f} sec | Train Loss {:5.4f} | Valid Loss {:5.4f} | Test Loss {:5.4f}\n'.format( 
175 | epoch, duration, train_loss, val_loss, test_loss) 176 | loginfo+="-"*50 177 | if verbose: print(loginfo) 178 | log+=loginfo+'\n' 179 | 180 | if val_loss < best_valid: 181 | loginfo=f"Saved model at pre_trained_models/{hyp_params.name}.pt!" 182 | if verbose: print(loginfo) 183 | log+=loginfo+'\n' 184 | save_model(hyp_params, model, name=hyp_params.name) 185 | best_valid = val_loss 186 | 187 | model = load_model(hyp_params, name=hyp_params.name) 188 | _, results, truths = evaluate(model, criterion, test=True) 189 | 190 | if hyp_params.dataset == "mosei_senti": 191 | metric,loginfo=eval_mosei_senti(results, truths, True) 192 | elif hyp_params.dataset == 'mosi': 193 | metric,loginfo=eval_mosi(results, truths, True) 194 | elif hyp_params.dataset == 'iemocap': 195 | metric,loginfo=eval_iemocap(results, truths) 196 | print(loginfo) 197 | log+=loginfo 198 | 199 | with open('./pre_trained_models/'+hyp_params.name+'_log.txt','w') as f: f.write(log) 200 | return metric 201 | 202 | -------------------------------------------------------------------------------- /modules/sp_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from modules.position_embedding import SinusoidalPositionalEmbedding 5 | from modules.fast_attention import Attention 6 | from modules.multihead_attention import MultiheadAttention 7 | import math 8 | 9 | 10 | class SPEncoder(nn.Module): 11 | 12 | def __init__(self, input_dims, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, 13 | res_dropout=0.0, embed_dropout=0.0, S=1, r=[1,1,1], 14 | shift_mode=dict(I=['S','P','R'],X=['S'],S=['S'],C=[1,0.25,0.05]), 15 | use_fast=False,use_dense=False,device='cuda'): 16 | super().__init__() 17 | self.dropout = embed_dropout # Embedding dropout 18 | self.embed_dim = embed_dim 19 | self.embed_scale = math.sqrt(embed_dim) 20 | self.embed_positions = SinusoidalPositionalEmbedding(embed_dim) 21 | self.init_hiddens=nn.Parameter(torch.Tensor(3, embed_dim)) 22 | nn.init.xavier_uniform_(self.init_hiddens) 23 | self.shift_mode=shift_mode 24 | self.use_fast=use_fast 25 | self.use_dense=use_dense 26 | self.device=device 27 | 28 | if self.use_fast: 29 | self.proj_a = nn.Conv1d(input_dims['a'], self.embed_dim, kernel_size=1) 30 | self.proj_t = nn.Conv1d(input_dims['t'], self.embed_dim, kernel_size=1) 31 | self.proj_v = nn.Conv1d(input_dims['v'], self.embed_dim, kernel_size=1) 32 | input_dims=dict(a=self.embed_dim,t=self.embed_dim,v=self.embed_dim) 33 | 34 | self.layers,self.stride=layers,S 35 | self.hiddenlayer = SPEncoderHiddenLayer(embed_dim, 36 | input_dims, 37 | num_heads=num_heads, 38 | attn_dropout=attn_dropout, 39 | relu_dropout=relu_dropout, 40 | res_dropout=res_dropout, 41 | stride=S, 42 | pf_hidden=r[0]*2+1, 43 | use_fast=self.use_fast, 44 | use_dense=self.use_dense, 45 | device=self.device) 46 | self.crosslayer = SPEncoderCrossLayer(embed_dim, 47 | input_dims, 48 | num_heads=num_heads, 49 | attn_dropout=attn_dropout, 50 | relu_dropout=relu_dropout, 51 | res_dropout=res_dropout, 52 | pf_cross=r[1]*2+1, 53 | use_fast=self.use_fast, 54 | use_dense=self.use_dense, 55 | device=self.device) 56 | self.selflayer = SPEncoderSelfLayer(embed_dim, 57 | input_dims, 58 | num_heads=num_heads, 59 | attn_dropout=attn_dropout, 60 | relu_dropout=relu_dropout, 61 | res_dropout=res_dropout, 62 | pf_self=r[2]*2+1, 63 | use_fast=self.use_fast, 64 | use_dense=self.use_dense, 65 | device=self.device) 66 | self.layer_norm = 
nn.LayerNorm(embed_dim) 67 | 68 | def init_emb(self,x): 69 | x = self.embed_scale * x 70 | if self.embed_positions is not None: 71 | x += self.embed_positions(x.transpose(0, 1)[:, :, 0],x.shape[-1]).transpose(0, 1) # Add positional embedding 72 | x = F.dropout(x, p=self.dropout, training=self.training) 73 | return x 74 | 75 | def forward(self, a, t, v): # [BS, SL, D] 76 | sla,slt,slv,bs=a.shape[1],t.shape[1],v.shape[1],a.shape[0] 77 | hla,hlt,hlv=math.ceil(sla/self.stride),math.ceil(slt/self.stride),math.ceil(slv/self.stride) 78 | h_a=self.init_hiddens[0,:].unsqueeze(0).repeat(hla,bs,1) 79 | h_t=self.init_hiddens[1,:].unsqueeze(0).repeat(hlt,bs,1) 80 | h_v=self.init_hiddens[2,:].unsqueeze(0).repeat(hlv,bs,1) 81 | 82 | # embed tokens and positions 83 | if self.use_fast: 84 | a=self.proj_a(a.transpose(1, 2)).permute(2, 0, 1) 85 | t=self.proj_t(t.transpose(1, 2)).permute(2, 0, 1) 86 | v=self.proj_v(v.transpose(1, 2)).permute(2, 0, 1) 87 | else: a,t,v=a.permute(1,0,2),t.permute(1,0,2),v.permute(1,0,2) 88 | a,t,v=self.init_emb(a),self.init_emb(t),self.init_emb(v) 89 | h_a,h_t,h_v=self.init_emb(h_a),self.init_emb(h_t),self.init_emb(h_v) 90 | 91 | # encoder layers 92 | shift_mode=self.shift_mode 93 | for i in range(self.layers): 94 | shift_a,shift_t,shift_v=0,0,0 95 | if 'S' in shift_mode['I']: 96 | A=shift_mode['C'][0] 97 | shift_a+=i*A; shift_t+=i*A; shift_v+=i*A 98 | if 'P' in shift_mode['I']: 99 | B=shift_mode['C'][1] 100 | shift_a+=(sla*torch.sin(torch.arange(hla)*B)).long().to(self.device) 101 | shift_t+=(slt*torch.sin(torch.arange(hlt)*B)).long().to(self.device) 102 | shift_v+=(slv*torch.sin(torch.arange(hlv)*B)).long().to(self.device) 103 | if 'R' in shift_mode['I']: 104 | G=shift_mode['C'][2] 105 | shift_a+=torch.randint(0,math.ceil(G*sla/hla),[hla],device=self.device) 106 | shift_t+=torch.randint(0,math.ceil(G*slt/hlt),[hlt],device=self.device) 107 | shift_v+=torch.randint(0,math.ceil(G*slv/hlt),[hlv],device=self.device) 108 | h_a, h_t, h_v=self.hiddenlayer(a,t,v,h_a,h_t,h_v,shift_a,shift_t,shift_v) 109 | 110 | shift_a,shift_t,shift_v=0,0,0 111 | if 'S' in shift_mode['X']: 112 | A=shift_mode['C'][0] 113 | shift_a+=i*A; shift_t+=i*A; shift_v+=i*A 114 | if 'P' in shift_mode['X']: 115 | B=shift_mode['C'][1] 116 | shift_a+=(sla*torch.sin(torch.arange(hla)*B)).long().to(self.device) 117 | shift_t+=(slt*torch.sin(torch.arange(hlt)*B)).long().to(self.device) 118 | shift_v+=(slv*torch.sin(torch.arange(hlv)*B)).long().to(self.device) 119 | if 'R' in shift_mode['X']: 120 | G=shift_mode['C'][2] 121 | shift_a+=torch.randint(0,math.ceil(G*sla/hla),[hla],device=self.device) 122 | shift_t+=torch.randint(0,math.ceil(G*slt/hlt),[hlt],device=self.device) 123 | shift_v+=torch.randint(0,math.ceil(G*slv/hlt),[hlv],device=self.device) 124 | h_a, h_t, h_v=self.crosslayer(h_a,h_t,h_v,shift_a,shift_t,shift_v) 125 | 126 | shift_a,shift_t,shift_v=0,0,0 127 | if 'S' in shift_mode['S']: 128 | A=shift_mode['C'][0] 129 | shift_a+=i*A; shift_t+=i*A; shift_v+=i*A 130 | if 'P' in shift_mode['S']: 131 | B=shift_mode['C'][1] 132 | shift_a+=(sla*torch.sin(torch.arange(hla)*B)).long().to(self.device) 133 | shift_t+=(slt*torch.sin(torch.arange(hlt)*B)).long().to(self.device) 134 | shift_v+=(slv*torch.sin(torch.arange(hlv)*B)).long().to(self.device) 135 | if 'R' in shift_mode['S']: 136 | G=shift_mode['C'][2] 137 | shift_a+=torch.randint(0,math.ceil(G*sla/hla),[hla],device=self.device) 138 | shift_t+=torch.randint(0,math.ceil(G*slt/hlt),[hlt],device=self.device) 139 | 
shift_v+=torch.randint(0,math.ceil(G*slv/hlt),[hlv],device=self.device) 140 | h_a, h_t, h_v=self.selflayer(h_a,h_t,h_v,shift_a,shift_t,shift_v) 141 | 142 | h_a = self.layer_norm(h_a) 143 | h_t = self.layer_norm(h_t) 144 | h_v = self.layer_norm(h_v) 145 | return h_a,h_t,h_v 146 | 147 | def max_positions(self): 148 | """Maximum input length supported by the encoder.""" 149 | if self.embed_positions is None: 150 | return self.max_source_positions 151 | return min(self.max_source_positions, self.embed_positions.max_positions()) 152 | 153 | 154 | class FFN(nn.Module): 155 | def __init__(self, embed_dim,relu_dropout,res_dropout): 156 | super().__init__() 157 | self.embed_dim=embed_dim 158 | self.relu_dropout=relu_dropout 159 | self.res_dropout=res_dropout 160 | self.fc1 = Linear(self.embed_dim, 4*self.embed_dim) # The "Add & Norm" part in the paper 161 | self.fc2 = Linear(4*self.embed_dim, self.embed_dim) 162 | self.layer_norm = nn.LayerNorm(self.embed_dim) 163 | 164 | def forward(self,x): 165 | residual = x 166 | x = self.layer_norm(x) 167 | x = F.relu(self.fc1(x)) 168 | x = F.dropout(x, p=self.relu_dropout, training=self.training) 169 | x = self.fc2(x) 170 | x = F.dropout(x, p=self.res_dropout, training=self.training) 171 | x = residual + x 172 | return x 173 | 174 | 175 | class SPF(nn.Module): # Sparse Phased Fast attention 176 | def __init__(self,embed_dim,num_heads,attn_dropout,res_dropout,input_dim=None,stride=1,pf=None, 177 | generalized_attention=False,dim_head_down=1,use_fast=True,use_dense=False,device='cuda',): 178 | super().__init__() 179 | self.embed_dim=embed_dim 180 | self.num_heads=num_heads 181 | self.res_dropout=res_dropout 182 | self.stride, self.pf = stride, pf 183 | self.generalized_attention=generalized_attention 184 | self.dim_head=int(embed_dim/num_heads/dim_head_down) 185 | self.use_fast=use_fast 186 | self.use_dense=use_dense 187 | 188 | if not use_dense and use_fast: 189 | self.attn = Attention( 190 | self.embed_dim, 191 | heads = self.num_heads, 192 | dim_head = self.dim_head, 193 | generalized_attention = self.generalized_attention 194 | ) 195 | else: 196 | self.attn = MultiheadAttention(self.embed_dim, self.num_heads, input_dim=input_dim) 197 | self.layer_norm = nn.LayerNorm(self.embed_dim) 198 | input_dim=self.embed_dim if input_dim is None else input_dim 199 | self.layer_norm_kv = nn.LayerNorm(input_dim) 200 | self.device=device 201 | 202 | def forward(self,x,x_k=None,x_v=None,shift=0,reverse=False): 203 | sl,bs,_=x.shape 204 | residual = x 205 | x = self.layer_norm(x) 206 | context=x if x_k is None else x_k 207 | if self.use_dense: 208 | mask=sparse_mask(x, context, self.stride, self.pf); c=context 209 | else: 210 | fetch=sparsify(x, context, self.stride, self.pf, shift, self.device) 211 | x=x.unsqueeze(2).reshape(-1,1,self.embed_dim) 212 | c=fetch.permute(1,2,0,3).reshape(-1,self.pf,fetch.shape[-1]) 213 | if not self.use_fast: x=x.permute(1,0,2); c=c.permute(1,0,2) 214 | if x_k is not None: c = self.layer_norm_kv(c) 215 | if self.use_dense: 216 | x,_ = self.attn(x, c, c, reverse=reverse, attn_mask=mask) 217 | else: 218 | if self.use_fast: x = self.attn(x, context=c, reverse=reverse) 219 | else: x,_ = self.attn(x, c, c, reverse=reverse) 220 | x = F.dropout(x, p=self.res_dropout, training=self.training) 221 | x = x.squeeze(1).reshape(sl,bs,-1) 222 | x = residual + x 223 | return x 224 | 225 | 226 | class SPEncoderHiddenLayer(nn.Module): 227 | def __init__(self, embed_dim, input_dims, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, 228 | res_dropout=0.1, 
stride=None, pf_hidden=None, use_fast=False, use_dense=False, device='cuda'): 229 | super().__init__() 230 | self.mha_a=SPF(embed_dim,num_heads,attn_dropout,res_dropout,input_dims['a'],stride,pf_hidden,use_fast=use_fast,use_dense=use_dense,device=device) 231 | self.ffn_a=FFN(embed_dim,relu_dropout,res_dropout) 232 | self.mha_t=SPF(embed_dim,num_heads,attn_dropout,res_dropout,input_dims['t'],stride,pf_hidden,use_fast=use_fast,use_dense=use_dense,device=device) 233 | self.ffn_t=FFN(embed_dim,relu_dropout,res_dropout) 234 | self.mha_v=SPF(embed_dim,num_heads,attn_dropout,res_dropout,input_dims['v'],stride,pf_hidden,use_fast=use_fast,use_dense=use_dense,device=device) 235 | self.ffn_v=FFN(embed_dim,relu_dropout,res_dropout) 236 | 237 | def forward(self, a, t, v, h_a, h_t, h_v,shift_a=0,shift_t=0,shift_v=0): 238 | h_a=self.ffn_a(self.mha_a(h_a,a,a,shift_a)) 239 | h_t=self.ffn_t(self.mha_t(h_t,t,t,shift_t)) 240 | h_v=self.ffn_v(self.mha_v(h_v,v,v,shift_v)) 241 | return h_a, h_t, h_v 242 | 243 | def sum_fuse(x,y): return x+y 244 | 245 | class SPEncoderCrossLayer(nn.Module): 246 | def __init__(self, embed_dim, input_dims, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, 247 | res_dropout=0.1, pf_cross=None, use_fast=False, use_dense=False, device='cuda'): 248 | super().__init__() 249 | self.mha_at=SPF(embed_dim,num_heads,attn_dropout,res_dropout,pf=pf_cross,use_fast=use_fast,use_dense=use_dense,device=device) 250 | self.mha_tv=SPF(embed_dim,num_heads,attn_dropout,res_dropout,pf=pf_cross,use_fast=use_fast,use_dense=use_dense,device=device) 251 | self.mha_va=SPF(embed_dim,num_heads,attn_dropout,res_dropout,pf=pf_cross,use_fast=use_fast,use_dense=use_dense,device=device) 252 | self.ffn_at=FFN(embed_dim,relu_dropout,res_dropout) 253 | self.ffn_tv=FFN(embed_dim,relu_dropout,res_dropout) 254 | self.ffn_va=FFN(embed_dim,relu_dropout,res_dropout) 255 | self.fuse_a=self.fuse_t=self.fuse_v=sum_fuse 256 | 257 | def forward(self, h_a, h_t, h_v,shift_a=0,shift_t=0,shift_v=0): 258 | h_at=self.ffn_at(self.mha_at(h_a,h_t,h_t,shift_a)) 259 | h_tv=self.ffn_tv(self.mha_tv(h_t,h_v,h_v,shift_t)) 260 | h_va=self.ffn_va(self.mha_va(h_v,h_a,h_a,shift_v)) 261 | h_ta=self.ffn_at(self.mha_at(h_t,h_a,h_a,shift_t,True)) 262 | h_vt=self.ffn_tv(self.mha_tv(h_v,h_t,h_t,shift_v,True)) 263 | h_av=self.ffn_va(self.mha_va(h_a,h_v,h_v,shift_a,True)) 264 | return self.fuse_a(h_at,h_av), self.fuse_t(h_ta,h_tv), self.fuse_v(h_va,h_vt) 265 | 266 | 267 | class SPEncoderSelfLayer(nn.Module): 268 | def __init__(self, embed_dim, input_dims, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, 269 | res_dropout=0.1, pf_self=None, use_fast=False, use_dense=False, device='cuda'): 270 | super().__init__() 271 | self.embed_dim=embed_dim 272 | self.mha_a=SPF(embed_dim,num_heads,attn_dropout,res_dropout,pf=pf_self,use_fast=use_fast,use_dense=use_dense,device=device) 273 | self.ffn_a=FFN(embed_dim,relu_dropout,res_dropout) 274 | self.mha_t=SPF(embed_dim,num_heads,attn_dropout,res_dropout,pf=pf_self,use_fast=use_fast,use_dense=use_dense,device=device) 275 | self.ffn_t=FFN(embed_dim,relu_dropout,res_dropout) 276 | self.mha_v=SPF(embed_dim,num_heads,attn_dropout,res_dropout,pf=pf_self,use_fast=use_fast,use_dense=use_dense,device=device) 277 | self.ffn_v=FFN(embed_dim,relu_dropout,res_dropout) 278 | 279 | def forward(self, h_a, h_t, h_v,shift_a=0,shift_t=0,shift_v=0): 280 | h_a=self.ffn_a(self.mha_a(h_a,shift=shift_a)) 281 | h_t=self.ffn_t(self.mha_t(h_t,shift=shift_t)) 282 | h_v=self.ffn_v(self.mha_v(h_v,shift=shift_v)) 283 | return h_a, h_t, h_v 284 | 285 
| 286 | def sparsify(hidden,context,stride,pf,shift=0,device='cuda'): # gather the pf context positions each hidden state attends to 287 | h,bs,_ = hidden.shape 288 | c,_,dc = context.shape 289 | r=(torch.arange(h).to(device)+1)*stride-pf+shift # start index of each hidden state's sampling window 290 | r=r.unsqueeze(0).repeat(pf,1)+torch.arange(pf).unsqueeze(1).to(device) # expand into pf consecutive positions per window 291 | r=r.reshape([pf*h]) 292 | return context[r%c].reshape(pf,h,bs,dc) # wrap indices modulo the context length 293 | 294 | def Linear(in_features, out_features, bias=True): # nn.Linear with Xavier-initialized weights and zero bias 295 | m = nn.Linear(in_features, out_features, bias) 296 | nn.init.xavier_uniform_(m.weight) 297 | if bias: nn.init.constant_(m.bias, 0.) 298 | return m 299 | 300 | def sparse_mask(hidden=None,context=None,stride=None, 301 | pf=None,h=None,c=None,cuda=None,shift=0): # generate the additive sparse-attention mask (0 = attend, -inf = masked) used by the dense path 302 | h=hidden.size(0) if h is None else h 303 | c=context.size(0) if c is None else c 304 | mask=torch.ones(h,c)*torch.tensor(float('-inf')) 305 | for i in range(pf): 306 | k=(torch.arange(h)+1)*stride-pf+i+shift 307 | mask[torch.arange(h),k%c]=0 308 | if cuda or (context is not None and context.is_cuda): mask = mask.cuda() 309 | return mask --------------------------------------------------------------------------------
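The sampling pattern implemented by `sparsify` and `sparse_mask` is easier to see with small numbers. The sketch below is a minimal, standalone mirror of that index arithmetic (PyTorch only): each of the `h` hidden states gathers a window of `pf` context positions ending at `(i+1)*stride - 1 + shift`, wrapped modulo the context length, and the dense path expresses the same pattern as an additive mask. `sparse_indices`, `dense_equivalent_mask`, and the example sizes are illustrative, not part of the repo.

```python
import torch

def sparse_indices(h, c, stride, pf, shift=0):
    """Context positions gathered for each hidden state, shape [pf, h].

    Mirrors the index arithmetic of sparsify in modules/sp_transformer.py.
    """
    r = (torch.arange(h) + 1) * stride - pf + shift              # start of each window
    r = r.unsqueeze(0).repeat(pf, 1) + torch.arange(pf).unsqueeze(1)
    return r % c                                                 # wrap around the context

def dense_equivalent_mask(h, c, stride, pf, shift=0):
    """Additive mask (0 = attend, -inf = masked), mirroring sparse_mask."""
    mask = torch.full((h, c), float('-inf'))
    for i in range(pf):
        k = (torch.arange(h) + 1) * stride - pf + i + shift
        mask[torch.arange(h), k % c] = 0
    return mask

if __name__ == "__main__":
    h, c, stride, pf = 4, 12, 3, 3   # 4 hidden states over a context of length 12
    idx = sparse_indices(h, c, stride, pf)
    print(idx)
    # tensor([[ 0,  3,  6,  9],
    #         [ 1,  4,  7, 10],
    #         [ 2,  5,  8, 11]])   -> hidden state j attends to the j-th column
    mask = dense_equivalent_mask(h, c, stride, pf)
    allowed = set(map(tuple, (mask == 0).nonzero().tolist()))
    gathered = {(j, int(idx[i, j])) for i in range(pf) for j in range(h)}
    assert allowed == gathered       # both paths describe the same sparse pattern
```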
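`SPEncoder.forward` combines up to three phase-shift terms per layer: a static offset that grows with depth (`'S'`), a sinusoidal offset over the hidden positions (`'P'`), and bounded random jitter (`'R'`), weighted by the `C` coefficients of `shift_mode`. The sketch below reproduces that schedule for a single modality so it can be inspected in isolation; the function name and the CPU device are illustrative. (One detail worth noting: the vision branch of the original loop bounds its random jitter with `hlt`, while the audio and text branches use their own hidden lengths; the sketch follows the per-modality pattern.)

```python
import math
import torch

def layer_shift(i, sl, hl, modes=('S', 'P', 'R'), C=(1, 0.25, 0.05), device='cpu'):
    """Per-layer shift for one modality with sequence length sl and hl hidden states."""
    shift = torch.zeros(hl, dtype=torch.long, device=device)
    if 'S' in modes:   # static: constant offset growing with layer index i
        shift = shift + i * C[0]
    if 'P' in modes:   # phased: sinusoid over the hidden positions
        shift = shift + (sl * torch.sin(torch.arange(hl, device=device) * C[1])).long()
    if 'R' in modes:   # random: jitter bounded by ceil(C[2] * sl / hl)
        shift = shift + torch.randint(0, math.ceil(C[2] * sl / hl), [hl], device=device)
    return shift

# e.g. layer 2 of an audio stream of length 50 compressed to 10 hidden states
print(layer_shift(i=2, sl=50, hl=10))
```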
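Layer-wise parameter sharing in `SPEncoder` comes from constructing a single hidden/cross/self layer and reusing it for every iteration of the `for i in range(self.layers)` loop, rather than stacking independent copies. The toy comparison below isolates that idea with plain linear layers; `SharedStack` and `UnsharedStack` are illustrative names, not repo classes.

```python
import torch
from torch import nn

class SharedStack(nn.Module):
    """One block reused for every layer (the SPEncoder approach)."""
    def __init__(self, dim, layers):
        super().__init__()
        self.block = nn.Linear(dim, dim)
        self.layers = layers
    def forward(self, x):
        for _ in range(self.layers):
            x = torch.relu(self.block(x))
        return x

class UnsharedStack(nn.Module):
    """Independent block per layer (a conventional Transformer stack)."""
    def __init__(self, dim, layers):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(layers)])
    def forward(self, x):
        for block in self.blocks:
            x = torch.relu(block(x))
        return x

def n_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(n_params(SharedStack(40, 5)), "vs", n_params(UnsharedStack(40, 5)))  # 1640 vs 8200
```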
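Finally, the `batch_chunk > 1` branch of the training loop is plain gradient accumulation: the batch is split into chunks, each chunk's loss is scaled by `1/batch_chunk` and back-propagated immediately, and a single clipped optimizer step follows for the whole batch. The sketch below shows the same pattern on a generic model; `model`, `criterion`, and the tensor shapes are stand-ins, not the repo's objects.

```python
import torch
from torch import nn

model = nn.Linear(16, 1)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

inputs, targets = torch.randn(32, 16), torch.randn(32, 1)
batch_chunk, clip = 4, 0.8

model.zero_grad()
total_loss = 0.0
for x_i, y_i in zip(inputs.chunk(batch_chunk, dim=0), targets.chunk(batch_chunk, dim=0)):
    loss_i = criterion(model(x_i), y_i) / batch_chunk  # scale so the accumulated gradient matches a full-batch step
    loss_i.backward()                                  # gradients accumulate across chunks
    total_loss += loss_i.item()

torch.nn.utils.clip_grad_norm_(model.parameters(), clip)  # one clip + step per batch
optimizer.step()
print(total_loss)
```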