├── LICENSE ├── README.md ├── imgs ├── architecture.png └── cm.png ├── main.py ├── modules ├── multihead_attention.py ├── position_embedding.py └── transformer.py └── src ├── README.md ├── ctc.py ├── dataset.py ├── eval_metrics.py ├── models.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yao-Hung Hubert Tsai and Shaojie Bai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 2 | 3 | # Multimodal Transformer for Unaligned Multimodal Language Sequences 4 | 5 | > Pytorch implementation for learning Multimodal Transformer for unaligned multimodal language sequences. 6 | 7 | Correspondence to: 8 | - Yao-Hung Hubert Tsai (yaohungt@cs.cmu.edu) 9 | - Shaojie Bai (shaojieb@andrew.cmu.edu) 10 | 11 | ## Paper 12 | [**Multimodal Transformer for Unaligned Multimodal Language Sequences**](https://arxiv.org/pdf/1906.00295.pdf)
13 | [Yao-Hung Hubert Tsai](https://yaohungt.github.io) *, [Shaojie Bai](https://jerrybai1995.github.io) *, [Paul Pu Liang](http://www.cs.cmu.edu/~pliang/), [J. Zico Kolter](http://zicokolter.com), [Louis-Philippe Morency](https://www.cs.cmu.edu/~morency/), and [Ruslan Salakhutdinov](https://www.cs.cmu.edu/~rsalakhu/)
14 | Association for Computational Linguistics (ACL), 2019. (*equal contribution) 15 | 16 | Please cite our paper if you find our work useful for your research: 17 | 18 | ```tex 19 | @inproceedings{tsai2019MULT, 20 | title={Multimodal Transformer for Unaligned Multimodal Language Sequences}, 21 | author={Tsai, Yao-Hung Hubert and Bai, Shaojie and Liang, Paul Pu and Kolter, J. Zico and Morency, Louis-Philippe and Salakhutdinov, Ruslan}, 22 | booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 23 | month = {7}, 24 | year={2019}, 25 | address = {Florence, Italy}, 26 | publisher = {Association for Computational Linguistics}, 27 | } 28 | ``` 29 | 30 | ## Overview 31 | 32 | ### Overall Architecture for Multimodal Transformer 33 |
![Overall Architecture for Multimodal Transformer (MulT)](imgs/architecture.png)
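For orientation, here is an illustrative sketch of how the architecture above is exposed as `MULTModel` in `src/models.py`. The hyperparameters mirror the (effective) defaults in `main.py`, while the feature dimensions and sequence lengths are placeholders (in practice they come from the dataset); `main.py` remains the real entry point:

```python
import torch
from argparse import Namespace
from src.models import MULTModel

# Hyperparameters follow main.py's (effective) defaults; input dimensions are placeholders.
hp = Namespace(orig_d_l=300, orig_d_a=74, orig_d_v=35,
               lonly=True, aonly=True, vonly=True,
               num_heads=5, layers=5,
               attn_dropout=0.1, attn_dropout_a=0.0, attn_dropout_v=0.0,
               relu_dropout=0.1, res_dropout=0.1, out_dropout=0.0,
               embed_dropout=0.25, attn_mask=True, output_dim=1)
model = MULTModel(hp)

# Batch-first inputs: [batch_size, seq_len, n_features] for language, audio, vision.
x_l, x_a, x_v = torch.rand(8, 50, 300), torch.rand(8, 500, 74), torch.rand(8, 500, 35)
preds, last_hidden = model(x_l, x_a, x_v)   # preds: (8, 1)
print(preds.shape)
```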
34 | 35 | 36 | Multimodal Transformer (MulT) merges multimodal time-series via a feed-forward fusion process from multiple directional pairwise crossmodal transformers. Specifically, each crossmodal transformer serves to repeatedly reinforce a *target modality* with the low-level features from another *source modality* by learning the attention across the two modalities' features. A MulT architecture hence models all pairs of modalities with such crossmodal transformers, followed by sequence models (e.g., a self-attention transformer) that predict using the fused features. 37 | 38 | 39 | ### Crossmodal Attention for Two Sequences from Distinct Modalities 40 |
![Crossmodal attention between two sequences from distinct modalities](imgs/cm.png)
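As a complement to the figure, the crossmodal attention above can be exercised directly through `modules.transformer.TransformerEncoder`: the query sequence comes from the *target* modality while the keys and values come from the *source* modality, exactly as `src/models.py` wires it. The sequence lengths, batch size, and the 40-dimensional common feature space below are made-up illustration values:

```python
import torch
from modules.transformer import TransformerEncoder

# Target modality (e.g., language): 20 steps; source modality (e.g., audio): 50 steps.
# Both are assumed to be already projected to a common 40-dim space, shaped (seq_len, batch, dim).
x_l = torch.rand(20, 2, 40)   # provides the queries
x_a = torch.rand(50, 2, 40)   # provides the keys and values

cross_l_with_a = TransformerEncoder(embed_dim=40, num_heads=4, layers=2)
h_l_with_a = cross_l_with_a(x_l, x_a, x_a)   # language reinforced by audio
print(h_l_with_a.shape)                      # torch.Size([20, 2, 40])
```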
41 | 42 | 43 | The core of our proposed model is the crossmodal transformer, which is built on the crossmodal attention module illustrated above. 44 | 45 | ## Usage 46 | 47 | ### Prerequisites 48 | - Python 3.6/3.7 49 | - [PyTorch (>=1.0.0) and torchvision](https://pytorch.org/) 50 | - CUDA 10.0 or above 51 | 52 | ### Datasets 53 | 54 | Data files (containing the processed MOSI, MOSEI and IEMOCAP datasets) can be downloaded from [here](https://www.dropbox.com/sh/hyzpgx1hp9nj37s/AAB7FhBqJOFDw2hEyvv2ZXHxa?dl=0). 55 | 56 | Alternatively, everything can be downloaded from the command line: 57 | ~~~~ 58 | wget https://www.dropbox.com/sh/hyzpgx1hp9nj37s/AADfY2s7gD_MkR76m03KS0K1a/Archive.zip?dl=1 59 | mv 'Archive.zip?dl=1' Archive.zip 60 | unzip Archive.zip 61 | ~~~~ 62 | 63 | To retrieve the meta information and the raw data, please refer to the [SDK for these datasets](https://github.com/A2Zadeh/CMU-MultimodalSDK). 64 | 65 | ### Run the Code 66 | 67 | 1. Create (empty) folders for the data and the pre-trained models: 68 | ~~~~ 69 | mkdir data pre_trained_models 70 | ~~~~ 71 | 72 | and put the downloaded data in `data/`. 73 | 74 | 2. Run the training script: 75 | ~~~~ 76 | python main.py [--FLAGS] 77 | ~~~~ 78 | 79 | Note that the default arguments are for the unaligned version of MOSEI. For other datasets, please refer to the supplementary material of the paper. 80 | 81 | ### If Using CTC 82 | 83 | MulT itself requires no CTC module. However, as we describe in the paper, the CTC module offers an alternative: it allows other kinds of sequence models (e.g., recurrent architectures) to be applied to unaligned multimodal streams. A minimal usage sketch of the repository's `CTCModule` is given at the end of this README. 84 | 85 | If you want to use the CTC module, please install warp-ctc from [here](https://github.com/baidu-research/warp-ctc). 86 | 87 | The quick version: 88 | ~~~~ 89 | git clone https://github.com/SeanNaren/warp-ctc.git 90 | cd warp-ctc 91 | mkdir build; cd build 92 | cmake .. 93 | make 94 | cd ../pytorch_binding 95 | python setup.py install 96 | export WARP_CTC_PATH=/home/xxx/warp-ctc/build 97 | ~~~~ 98 | 99 | ### Acknowledgement 100 | Some portions of the code were adapted from the [fairseq](https://github.com/pytorch/fairseq) repo.
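### A Minimal `CTCModule` Sketch

As referenced in the CTC section above, `src/ctc.py` learns a soft (pseudo-)alignment from a source modality (e.g., audio) to the text length, so that a recurrent model can afterwards be applied as if the streams were aligned. The shapes below are invented purely for illustration, and the snippet omits the CTC loss wiring, which lives in `src/train.py`:

```python
import torch
from src.ctc import CTCModule

# Hypothetical shapes: align a 500-step, 74-dim audio stream to a 50-step text stream.
a2l = CTCModule(in_dim=74, out_seq_len=50)
audio = torch.rand(8, 500, 74)                 # batch_size x in_seq_len x in_dim
pseudo_aligned, position_logits = a2l(audio)   # (8, 50, 74) and (8, 500, 51), respectively
print(pseudo_aligned.shape)
```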
101 | 102 | 103 | -------------------------------------------------------------------------------- /imgs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaohungt/Multimodal-Transformer/a670936824ee722c8494fd98d204977a1d663c7a/imgs/architecture.png -------------------------------------------------------------------------------- /imgs/cm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaohungt/Multimodal-Transformer/a670936824ee722c8494fd98d204977a1d663c7a/imgs/cm.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from src.utils import * 4 | from torch.utils.data import DataLoader 5 | from src import train 6 | 7 | 8 | parser = argparse.ArgumentParser(description='MOSEI Sentiment Analysis') 9 | parser.add_argument('-f', default='', type=str) 10 | 11 | # Fixed 12 | parser.add_argument('--model', type=str, default='MulT', 13 | help='name of the model to use (Transformer, etc.)') 14 | 15 | # Tasks 16 | parser.add_argument('--vonly', action='store_true', 17 | help='use the crossmodal fusion into v (default: False)') 18 | parser.add_argument('--aonly', action='store_true', 19 | help='use the crossmodal fusion into a (default: False)') 20 | parser.add_argument('--lonly', action='store_true', 21 | help='use the crossmodal fusion into l (default: False)') 22 | parser.add_argument('--aligned', action='store_true', 23 | help='consider aligned experiment or not (default: False)') 24 | parser.add_argument('--dataset', type=str, default='mosei_senti', 25 | help='dataset to use (default: mosei_senti)') 26 | parser.add_argument('--data_path', type=str, default='data', 27 | help='path for storing the dataset') 28 | 29 | # Dropouts 30 | parser.add_argument('--attn_dropout', type=float, default=0.1, 31 | help='attention dropout') 32 | parser.add_argument('--attn_dropout_a', type=float, default=0.0, 33 | help='attention dropout (for audio)') 34 | parser.add_argument('--attn_dropout_v', type=float, default=0.0, 35 | help='attention dropout (for visual)') 36 | parser.add_argument('--relu_dropout', type=float, default=0.1, 37 | help='relu dropout') 38 | parser.add_argument('--embed_dropout', type=float, default=0.25, 39 | help='embedding dropout') 40 | parser.add_argument('--res_dropout', type=float, default=0.1, 41 | help='residual block dropout') 42 | parser.add_argument('--out_dropout', type=float, default=0.0, 43 | help='output layer dropout') 44 | 45 | # Architecture 46 | parser.add_argument('--nlevels', type=int, default=5, 47 | help='number of layers in the network (default: 5)') 48 | parser.add_argument('--num_heads', type=int, default=5, 49 | help='number of heads for the transformer network (default: 5)') 50 | parser.add_argument('--attn_mask', action='store_false', 51 | help='use attention mask for Transformer (default: true)') 52 | 53 | # Tuning 54 | parser.add_argument('--batch_size', type=int, default=24, metavar='N', 55 | help='batch size (default: 24)') 56 | parser.add_argument('--clip', type=float, default=0.8, 57 | help='gradient clip value (default: 0.8)') 58 | parser.add_argument('--lr', type=float, default=1e-3, 59 | help='initial learning rate (default: 1e-3)') 60 | parser.add_argument('--optim', type=str, default='Adam', 61 | help='optimizer to use (default: Adam)') 62 | 
parser.add_argument('--num_epochs', type=int, default=40, 63 | help='number of epochs (default: 40)') 64 | parser.add_argument('--when', type=int, default=20, 65 | help='when to decay learning rate (default: 20)') 66 | parser.add_argument('--batch_chunk', type=int, default=1, 67 | help='number of chunks per batch (default: 1)') 68 | 69 | # Logistics 70 | parser.add_argument('--log_interval', type=int, default=30, 71 | help='frequency of result logging (default: 30)') 72 | parser.add_argument('--seed', type=int, default=1111, 73 | help='random seed') 74 | parser.add_argument('--no_cuda', action='store_true', 75 | help='do not use cuda') 76 | parser.add_argument('--name', type=str, default='mult', 77 | help='name of the trial (default: "mult")') 78 | args = parser.parse_args() 79 | 80 | torch.manual_seed(args.seed) 81 | dataset = str.lower(args.dataset.strip()) 82 | valid_partial_mode = args.lonly + args.vonly + args.aonly 83 | 84 | if valid_partial_mode == 0: 85 | args.lonly = args.vonly = args.aonly = True 86 | elif valid_partial_mode != 1: 87 | raise ValueError("You can only choose one of {l/v/a}only.") 88 | 89 | use_cuda = False 90 | 91 | output_dim_dict = { 92 | 'mosi': 1, 93 | 'mosei_senti': 1, 94 | 'iemocap': 8 95 | } 96 | 97 | criterion_dict = { 98 | 'iemocap': 'CrossEntropyLoss' 99 | } 100 | 101 | torch.set_default_tensor_type('torch.FloatTensor') 102 | if torch.cuda.is_available(): 103 | if args.no_cuda: 104 | print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda") 105 | else: 106 | torch.cuda.manual_seed(args.seed) 107 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 108 | use_cuda = True 109 | 110 | #################################################################### 111 | # 112 | # Load the dataset (aligned or non-aligned) 113 | # 114 | #################################################################### 115 | 116 | print("Start loading the data....") 117 | 118 | train_data = get_data(args, dataset, 'train') 119 | valid_data = get_data(args, dataset, 'valid') 120 | test_data = get_data(args, dataset, 'test') 121 | 122 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True) 123 | valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=True) 124 | test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True) 125 | 126 | print('Finish loading the data....') 127 | if not args.aligned: 128 | print("### Note: You are running in unaligned mode.") 129 | 130 | #################################################################### 131 | # 132 | # Hyperparameters 133 | # 134 | #################################################################### 135 | 136 | hyp_params = args 137 | hyp_params.orig_d_l, hyp_params.orig_d_a, hyp_params.orig_d_v = train_data.get_dim() 138 | hyp_params.l_len, hyp_params.a_len, hyp_params.v_len = train_data.get_seq_len() 139 | hyp_params.layers = args.nlevels 140 | hyp_params.use_cuda = use_cuda 141 | hyp_params.dataset = dataset 142 | hyp_params.when = args.when 143 | hyp_params.batch_chunk = args.batch_chunk 144 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_data), len(valid_data), len(test_data) 145 | hyp_params.model = str.upper(args.model.strip()) 146 | hyp_params.output_dim = output_dim_dict.get(dataset, 1) 147 | hyp_params.criterion = criterion_dict.get(dataset, 'L1Loss') 148 | 149 | 150 | if __name__ == '__main__': 151 | test_loss = train.initiate(hyp_params, train_loader, valid_loader, test_loader) 152 | 153 | 
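# ---------------------------------------------------------------------------
# Illustrative note (not part of the original script): each batch produced by
# the loaders above unpacks as ((sample_ind, text, audio, vision), labels, meta),
# mirroring how src/train.py consumes it.  For example (sketch only):
#
#     batch_X, batch_Y, batch_META = next(iter(train_loader))
#     sample_ind, text, audio, vision = batch_X   # text: [batch, l_len, orig_d_l], etc.
#     eval_attr = batch_Y.squeeze(-1)             # labels, squeezed as in train.py
# ---------------------------------------------------------------------------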
-------------------------------------------------------------------------------- /modules/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | # Code adapted from the fairseq repo. 8 | 9 | class MultiheadAttention(nn.Module): 10 | """Multi-headed attention. 11 | See "Attention Is All You Need" for more details. 12 | """ 13 | 14 | def __init__(self, embed_dim, num_heads, attn_dropout=0., 15 | bias=True, add_bias_kv=False, add_zero_attn=False): 16 | super().__init__() 17 | self.embed_dim = embed_dim 18 | self.num_heads = num_heads 19 | self.attn_dropout = attn_dropout 20 | self.head_dim = embed_dim // num_heads 21 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 22 | self.scaling = self.head_dim ** -0.5 23 | 24 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 25 | self.register_parameter('in_proj_bias', None) 26 | if bias: 27 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 28 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 29 | 30 | if add_bias_kv: 31 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 32 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 33 | else: 34 | self.bias_k = self.bias_v = None 35 | 36 | self.add_zero_attn = add_zero_attn 37 | 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | nn.init.xavier_uniform_(self.in_proj_weight) 42 | nn.init.xavier_uniform_(self.out_proj.weight) 43 | if self.in_proj_bias is not None: 44 | nn.init.constant_(self.in_proj_bias, 0.) 45 | nn.init.constant_(self.out_proj.bias, 0.) 46 | if self.bias_k is not None: 47 | nn.init.xavier_normal_(self.bias_k) 48 | if self.bias_v is not None: 49 | nn.init.xavier_normal_(self.bias_v) 50 | 51 | def forward(self, query, key, value, attn_mask=None): 52 | """Input shape: Time x Batch x Channel 53 | Self-attention can be implemented by passing in the same arguments for 54 | query, key and value. Timesteps can be masked by supplying a T x T mask in the 55 | `attn_mask` argument. Padding elements can be excluded from 56 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 57 | batch x src_len, where padding elements are indicated by 1s. 
58 | """ 59 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 60 | kv_same = key.data_ptr() == value.data_ptr() 61 | 62 | tgt_len, bsz, embed_dim = query.size() 63 | assert embed_dim == self.embed_dim 64 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 65 | assert key.size() == value.size() 66 | 67 | aved_state = None 68 | 69 | if qkv_same: 70 | # self-attention 71 | q, k, v = self.in_proj_qkv(query) 72 | elif kv_same: 73 | # encoder-decoder attention 74 | q = self.in_proj_q(query) 75 | 76 | if key is None: 77 | assert value is None 78 | k = v = None 79 | else: 80 | k, v = self.in_proj_kv(key) 81 | else: 82 | q = self.in_proj_q(query) 83 | k = self.in_proj_k(key) 84 | v = self.in_proj_v(value) 85 | q = q * self.scaling 86 | 87 | if self.bias_k is not None: 88 | assert self.bias_v is not None 89 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 90 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 91 | if attn_mask is not None: 92 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 93 | 94 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 95 | if k is not None: 96 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 97 | if v is not None: 98 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 99 | 100 | src_len = k.size(1) 101 | 102 | if self.add_zero_attn: 103 | src_len += 1 104 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 105 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 106 | if attn_mask is not None: 107 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 108 | 109 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 110 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 111 | 112 | if attn_mask is not None: 113 | try: 114 | attn_weights += attn_mask.unsqueeze(0) 115 | except: 116 | print(attn_weights.shape) 117 | print(attn_mask.unsqueeze(0).shape) 118 | assert False 119 | 120 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) 121 | # attn_weights = F.relu(attn_weights) 122 | # attn_weights = attn_weights / torch.max(attn_weights) 123 | attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training) 124 | 125 | attn = torch.bmm(attn_weights, v) 126 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 127 | 128 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 129 | attn = self.out_proj(attn) 130 | 131 | # average attention weights over heads 132 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 133 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 134 | return attn, attn_weights 135 | 136 | def in_proj_qkv(self, query): 137 | return self._in_proj(query).chunk(3, dim=-1) 138 | 139 | def in_proj_kv(self, key): 140 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 141 | 142 | def in_proj_q(self, query, **kwargs): 143 | return self._in_proj(query, end=self.embed_dim, **kwargs) 144 | 145 | def in_proj_k(self, key): 146 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 147 | 148 | def in_proj_v(self, value): 149 | return self._in_proj(value, start=2 * self.embed_dim) 150 | 151 | def _in_proj(self, input, start=0, end=None, **kwargs): 152 | weight = kwargs.get('weight', self.in_proj_weight) 153 | bias = kwargs.get('bias', self.in_proj_bias) 154 | weight = 
weight[start:end, :] 155 | if bias is not None: 156 | bias = bias[start:end] 157 | return F.linear(input, weight, bias) 158 | -------------------------------------------------------------------------------- /modules/position_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Code adapted from the fairseq repo. 7 | 8 | def make_positions(tensor, padding_idx, left_pad): 9 | """Replace non-padding symbols with their position numbers. 10 | Position numbers begin at padding_idx+1. 11 | Padding symbols are ignored, but it is necessary to specify whether padding 12 | is added on the left side (left_pad=True) or right side (left_pad=False). 13 | """ 14 | max_pos = padding_idx + 1 + tensor.size(1) 15 | device = tensor.get_device() 16 | buf_name = f'range_buf_{device}' 17 | if not hasattr(make_positions, buf_name): 18 | setattr(make_positions, buf_name, tensor.new()) 19 | setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor)) 20 | if getattr(make_positions, buf_name).numel() < max_pos: 21 | torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name)) 22 | mask = tensor.ne(padding_idx) 23 | positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor) 24 | if left_pad: 25 | positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1) 26 | new_tensor = tensor.clone() 27 | return new_tensor.masked_scatter_(mask, positions[mask]).long() 28 | 29 | 30 | class SinusoidalPositionalEmbedding(nn.Module): 31 | """This module produces sinusoidal positional embeddings of any length. 32 | Padding symbols are ignored, but it is necessary to specify whether padding 33 | is added on the left side (left_pad=True) or right side (left_pad=False). 34 | """ 35 | 36 | def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128): 37 | super().__init__() 38 | self.embedding_dim = embedding_dim 39 | self.padding_idx = padding_idx 40 | self.left_pad = left_pad 41 | self.weights = dict() # device --> actual weight; due to nn.DataParallel :-( 42 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 43 | 44 | @staticmethod 45 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 46 | """Build sinusoidal embeddings. 47 | This matches the implementation in tensor2tensor, but differs slightly 48 | from the description in Section 3.5 of "Attention Is All You Need". 
49 | """ 50 | half_dim = embedding_dim // 2 51 | emb = math.log(10000) / (half_dim - 1) 52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 53 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 55 | if embedding_dim % 2 == 1: 56 | # zero pad 57 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 58 | if padding_idx is not None: 59 | emb[padding_idx, :] = 0 60 | return emb 61 | 62 | def forward(self, input): 63 | """Input is expected to be of size [bsz x seqlen].""" 64 | bsz, seq_len = input.size() 65 | max_pos = self.padding_idx + 1 + seq_len 66 | device = input.get_device() 67 | if device not in self.weights or max_pos > self.weights[device].size(0): 68 | # recompute/expand embeddings if needed 69 | self.weights[device] = SinusoidalPositionalEmbedding.get_embedding( 70 | max_pos, 71 | self.embedding_dim, 72 | self.padding_idx, 73 | ) 74 | self.weights[device] = self.weights[device].type_as(self._float_tensor) 75 | positions = make_positions(input, self.padding_idx, self.left_pad) 76 | return self.weights[device].index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 77 | 78 | def max_positions(self): 79 | """Maximum number of supported positions.""" 80 | return int(1e5) # an arbitrary large number -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from modules.position_embedding import SinusoidalPositionalEmbedding 5 | from modules.multihead_attention import MultiheadAttention 6 | import math 7 | 8 | 9 | class TransformerEncoder(nn.Module): 10 | """ 11 | Transformer encoder consisting of *args.encoder_layers* layers. Each layer 12 | is a :class:`TransformerEncoderLayer`. 
13 | Args: 14 | embed_tokens (torch.nn.Embedding): input embedding 15 | num_heads (int): number of heads 16 | layers (int): number of layers 17 | attn_dropout (float): dropout applied on the attention weights 18 | relu_dropout (float): dropout applied on the first layer of the residual block 19 | res_dropout (float): dropout applied on the residual block 20 | attn_mask (bool): whether to apply mask on the attention weights 21 | """ 22 | 23 | def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0, 24 | embed_dropout=0.0, attn_mask=False): 25 | super().__init__() 26 | self.dropout = embed_dropout # Embedding dropout 27 | self.attn_dropout = attn_dropout 28 | self.embed_dim = embed_dim 29 | self.embed_scale = math.sqrt(embed_dim) 30 | self.embed_positions = SinusoidalPositionalEmbedding(embed_dim) 31 | 32 | self.attn_mask = attn_mask 33 | 34 | self.layers = nn.ModuleList([]) 35 | for layer in range(layers): 36 | new_layer = TransformerEncoderLayer(embed_dim, 37 | num_heads=num_heads, 38 | attn_dropout=attn_dropout, 39 | relu_dropout=relu_dropout, 40 | res_dropout=res_dropout, 41 | attn_mask=attn_mask) 42 | self.layers.append(new_layer) 43 | 44 | self.register_buffer('version', torch.Tensor([2])) 45 | self.normalize = True 46 | if self.normalize: 47 | self.layer_norm = LayerNorm(embed_dim) 48 | 49 | def forward(self, x_in, x_in_k = None, x_in_v = None): 50 | """ 51 | Args: 52 | x_in (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)` 53 | x_in_k (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)` 54 | x_in_v (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)` 55 | Returns: 56 | dict: 57 | - **encoder_out** (Tensor): the last encoder layer's output of 58 | shape `(src_len, batch, embed_dim)` 59 | - **encoder_padding_mask** (ByteTensor): the positions of 60 | padding elements of shape `(batch, src_len)` 61 | """ 62 | # embed tokens and positions 63 | x = self.embed_scale * x_in 64 | if self.embed_positions is not None: 65 | x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 66 | x = F.dropout(x, p=self.dropout, training=self.training) 67 | 68 | if x_in_k is not None and x_in_v is not None: 69 | # embed tokens and positions 70 | x_k = self.embed_scale * x_in_k 71 | x_v = self.embed_scale * x_in_v 72 | if self.embed_positions is not None: 73 | x_k += self.embed_positions(x_in_k.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 74 | x_v += self.embed_positions(x_in_v.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 75 | x_k = F.dropout(x_k, p=self.dropout, training=self.training) 76 | x_v = F.dropout(x_v, p=self.dropout, training=self.training) 77 | 78 | # encoder layers 79 | intermediates = [x] 80 | for layer in self.layers: 81 | if x_in_k is not None and x_in_v is not None: 82 | x = layer(x, x_k, x_v) 83 | else: 84 | x = layer(x) 85 | intermediates.append(x) 86 | 87 | if self.normalize: 88 | x = self.layer_norm(x) 89 | 90 | return x 91 | 92 | def max_positions(self): 93 | """Maximum input length supported by the encoder.""" 94 | if self.embed_positions is None: 95 | return self.max_source_positions 96 | return min(self.max_source_positions, self.embed_positions.max_positions()) 97 | 98 | 99 | class TransformerEncoderLayer(nn.Module): 100 | """Encoder layer block. 101 | In the original paper each operation (multi-head attention or FFN) is 102 | postprocessed with: `dropout -> add residual -> layernorm`. 
In the 103 | tensor2tensor code they suggest that learning is more robust when 104 | preprocessing each layer with layernorm and postprocessing with: 105 | `dropout -> add residual`. We default to the approach in the paper, but the 106 | tensor2tensor approach can be enabled by setting 107 | *args.encoder_normalize_before* to ``True``. 108 | Args: 109 | embed_dim: Embedding dimension 110 | """ 111 | 112 | def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1, 113 | attn_mask=False): 114 | super().__init__() 115 | self.embed_dim = embed_dim 116 | self.num_heads = num_heads 117 | 118 | self.self_attn = MultiheadAttention( 119 | embed_dim=self.embed_dim, 120 | num_heads=self.num_heads, 121 | attn_dropout=attn_dropout 122 | ) 123 | self.attn_mask = attn_mask 124 | 125 | self.relu_dropout = relu_dropout 126 | self.res_dropout = res_dropout 127 | self.normalize_before = True 128 | 129 | self.fc1 = Linear(self.embed_dim, 4*self.embed_dim) # The "Add & Norm" part in the paper 130 | self.fc2 = Linear(4*self.embed_dim, self.embed_dim) 131 | self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)]) 132 | 133 | def forward(self, x, x_k=None, x_v=None): 134 | """ 135 | Args: 136 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 137 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape 138 | `(batch, src_len)` where padding elements are indicated by ``1``. 139 | x_k (Tensor): same as x 140 | x_v (Tensor): same as x 141 | Returns: 142 | encoded output of shape `(batch, src_len, embed_dim)` 143 | """ 144 | residual = x 145 | x = self.maybe_layer_norm(0, x, before=True) 146 | mask = buffered_future_mask(x, x_k) if self.attn_mask else None 147 | if x_k is None and x_v is None: 148 | x, _ = self.self_attn(query=x, key=x, value=x, attn_mask=mask) 149 | else: 150 | x_k = self.maybe_layer_norm(0, x_k, before=True) 151 | x_v = self.maybe_layer_norm(0, x_v, before=True) 152 | x, _ = self.self_attn(query=x, key=x_k, value=x_v, attn_mask=mask) 153 | x = F.dropout(x, p=self.res_dropout, training=self.training) 154 | x = residual + x 155 | x = self.maybe_layer_norm(0, x, after=True) 156 | 157 | residual = x 158 | x = self.maybe_layer_norm(1, x, before=True) 159 | x = F.relu(self.fc1(x)) 160 | x = F.dropout(x, p=self.relu_dropout, training=self.training) 161 | x = self.fc2(x) 162 | x = F.dropout(x, p=self.res_dropout, training=self.training) 163 | x = residual + x 164 | x = self.maybe_layer_norm(1, x, after=True) 165 | return x 166 | 167 | def maybe_layer_norm(self, i, x, before=False, after=False): 168 | assert before ^ after 169 | if after ^ self.normalize_before: 170 | return self.layer_norms[i](x) 171 | else: 172 | return x 173 | 174 | def fill_with_neg_inf(t): 175 | """FP16-compatible function that fills a tensor with -inf.""" 176 | return t.float().fill_(float('-inf')).type_as(t) 177 | 178 | 179 | def buffered_future_mask(tensor, tensor2=None): 180 | dim1 = dim2 = tensor.size(0) 181 | if tensor2 is not None: 182 | dim2 = tensor2.size(0) 183 | future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1+abs(dim2-dim1)) 184 | if tensor.is_cuda: 185 | future_mask = future_mask.cuda() 186 | return future_mask[:dim1, :dim2] 187 | 188 | 189 | def Linear(in_features, out_features, bias=True): 190 | m = nn.Linear(in_features, out_features, bias) 191 | nn.init.xavier_uniform_(m.weight) 192 | if bias: 193 | nn.init.constant_(m.bias, 0.) 
194 | return m 195 | 196 | 197 | def LayerNorm(embedding_dim): 198 | m = nn.LayerNorm(embedding_dim) 199 | return m 200 | 201 | 202 | if __name__ == '__main__': 203 | encoder = TransformerEncoder(300, 4, 2) 204 | x = torch.tensor(torch.rand(20, 2, 300)) 205 | print(encoder(x).shape) 206 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # MulT model and Multimodal Sentiment Analysis Benchmarks 2 | 3 | This directory contains the model architecture for Multimodal Transformer (MulT) as well as the three major multimodal sentiment analysis benchmarks we used in the paper. All datasets should be put in the `../data` folder (if it does not exist, use `mkdir` to create one). Depending on the dataset, we may have different sets of evaluation metrics (see `eval_metrics.py`). 4 | -------------------------------------------------------------------------------- /src/ctc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CTCModule(nn.Module): 6 | def __init__(self, in_dim, out_seq_len): 7 | ''' 8 | This module is performing alignment from A (e.g., audio) to B (e.g., text). 9 | :param in_dim: Dimension for input modality A 10 | :param out_seq_len: Sequence length for output modality B 11 | ''' 12 | super(CTCModule, self).__init__() 13 | # Use LSTM for predicting the position from A to B 14 | self.pred_output_position_inclu_blank = nn.LSTM(in_dim, out_seq_len+1, num_layers=2, batch_first=True) # 1 denoting blank 15 | 16 | self.out_seq_len = out_seq_len 17 | 18 | self.softmax = nn.Softmax(dim=2) 19 | def forward(self, x): 20 | ''' 21 | :input x: Input with shape [batch_size x in_seq_len x in_dim] 22 | ''' 23 | # NOTE that the index 0 refers to blank. 24 | pred_output_position_inclu_blank, _ = self.pred_output_position_inclu_blank(x) 25 | 26 | prob_pred_output_position_inclu_blank = self.softmax(pred_output_position_inclu_blank) # batch_size x in_seq_len x out_seq_len+1 27 | prob_pred_output_position = prob_pred_output_position_inclu_blank[:, :, 1:] # batch_size x in_seq_len x out_seq_len 28 | prob_pred_output_position = prob_pred_output_position.transpose(1,2) # batch_size x out_seq_len x in_seq_len 29 | pseudo_aligned_out = torch.bmm(prob_pred_output_position, x) # batch_size x out_seq_len x in_dim 30 | 31 | # pseudo_aligned_out is regarded as the aligned A (w.r.t B) 32 | return pseudo_aligned_out, (pred_output_position_inclu_blank) 33 | 34 | 35 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.dataset import Dataset 3 | import pickle 4 | import os 5 | from scipy import signal 6 | import torch 7 | 8 | if torch.cuda.is_available(): 9 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 10 | else: 11 | torch.set_default_tensor_type('torch.FloatTensor') 12 | 13 | ############################################################################################ 14 | # This file provides basic processing script for the multimodal datasets we use. For other 15 | # datasets, small modifications may be needed (depending on the type of the data, etc.) 
16 | ############################################################################################ 17 | 18 | 19 | class Multimodal_Datasets(Dataset): 20 | def __init__(self, dataset_path, data='mosei_senti', split_type='train', if_align=False): 21 | super(Multimodal_Datasets, self).__init__() 22 | dataset_path = os.path.join(dataset_path, data+'_data.pkl' if if_align else data+'_data_noalign.pkl' ) 23 | dataset = pickle.load(open(dataset_path, 'rb')) 24 | 25 | # These are torch tensors 26 | self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach() 27 | self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach() 28 | self.audio = dataset[split_type]['audio'].astype(np.float32) 29 | self.audio[self.audio == -np.inf] = 0 30 | self.audio = torch.tensor(self.audio).cpu().detach() 31 | self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach() 32 | 33 | # Note: this is STILL an numpy array 34 | self.meta = dataset[split_type]['id'] if 'id' in dataset[split_type].keys() else None 35 | 36 | self.data = data 37 | 38 | self.n_modalities = 3 # vision/ text/ audio 39 | def get_n_modalities(self): 40 | return self.n_modalities 41 | def get_seq_len(self): 42 | return self.text.shape[1], self.audio.shape[1], self.vision.shape[1] 43 | def get_dim(self): 44 | return self.text.shape[2], self.audio.shape[2], self.vision.shape[2] 45 | def get_lbl_info(self): 46 | # return number_of_labels, label_dim 47 | return self.labels.shape[1], self.labels.shape[2] 48 | def __len__(self): 49 | return len(self.labels) 50 | def __getitem__(self, index): 51 | X = (index, self.text[index], self.audio[index], self.vision[index]) 52 | Y = self.labels[index] 53 | META = (0,0,0) if self.meta is None else (self.meta[index][0], self.meta[index][1], self.meta[index][2]) 54 | if self.data == 'mosi': 55 | META = (self.meta[index][0].decode('UTF-8'), self.meta[index][1].decode('UTF-8'), self.meta[index][2].decode('UTF-8')) 56 | if self.data == 'iemocap': 57 | Y = torch.argmax(Y, dim=-1) 58 | return X, Y, META 59 | 60 | -------------------------------------------------------------------------------- /src/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import classification_report 4 | from sklearn.metrics import confusion_matrix 5 | from sklearn.metrics import precision_recall_fscore_support 6 | from sklearn.metrics import accuracy_score, f1_score 7 | 8 | 9 | def multiclass_acc(preds, truths): 10 | """ 11 | Compute the multiclass accuracy w.r.t. 
groundtruth 12 | 13 | :param preds: Float array representing the predictions, dimension (N,) 14 | :param truths: Float/int array representing the groundtruth classes, dimension (N,) 15 | :return: Classification accuracy 16 | """ 17 | return np.sum(np.round(preds) == np.round(truths)) / float(len(truths)) 18 | 19 | 20 | def weighted_accuracy(test_preds_emo, test_truth_emo): 21 | true_label = (test_truth_emo > 0) 22 | predicted_label = (test_preds_emo > 0) 23 | tp = float(np.sum((true_label==1) & (predicted_label==1))) 24 | tn = float(np.sum((true_label==0) & (predicted_label==0))) 25 | p = float(np.sum(true_label==1)) 26 | n = float(np.sum(true_label==0)) 27 | 28 | return (tp * (n/p) +tn) / (2*n) 29 | 30 | 31 | def eval_mosei_senti(results, truths, exclude_zero=False): 32 | test_preds = results.view(-1).cpu().detach().numpy() 33 | test_truth = truths.view(-1).cpu().detach().numpy() 34 | 35 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0 or (not exclude_zero)]) 36 | 37 | test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.) 38 | test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.) 39 | test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.) 40 | test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.) 41 | 42 | mae = np.mean(np.absolute(test_preds - test_truth)) # Average L1 distance between preds and truths 43 | corr = np.corrcoef(test_preds, test_truth)[0][1] 44 | mult_a7 = multiclass_acc(test_preds_a7, test_truth_a7) 45 | mult_a5 = multiclass_acc(test_preds_a5, test_truth_a5) 46 | f_score = f1_score((test_preds[non_zeros] > 0), (test_truth[non_zeros] > 0), average='weighted') 47 | binary_truth = (test_truth[non_zeros] > 0) 48 | binary_preds = (test_preds[non_zeros] > 0) 49 | 50 | print("MAE: ", mae) 51 | print("Correlation Coefficient: ", corr) 52 | print("mult_acc_7: ", mult_a7) 53 | print("mult_acc_5: ", mult_a5) 54 | print("F1 score: ", f_score) 55 | print("Accuracy: ", accuracy_score(binary_truth, binary_preds)) 56 | 57 | print("-" * 50) 58 | 59 | 60 | def eval_mosi(results, truths, exclude_zero=False): 61 | return eval_mosei_senti(results, truths, exclude_zero) 62 | 63 | 64 | def eval_iemocap(results, truths, single=-1): 65 | emos = ["Neutral", "Happy", "Sad", "Angry"] 66 | if single < 0: 67 | test_preds = results.view(-1, 4, 2).cpu().detach().numpy() 68 | test_truth = truths.view(-1, 4).cpu().detach().numpy() 69 | 70 | for emo_ind in range(4): 71 | print(f"{emos[emo_ind]}: ") 72 | test_preds_i = np.argmax(test_preds[:,emo_ind],axis=1) 73 | test_truth_i = test_truth[:,emo_ind] 74 | f1 = f1_score(test_truth_i, test_preds_i, average='weighted') 75 | acc = accuracy_score(test_truth_i, test_preds_i) 76 | print(" - F1 Score: ", f1) 77 | print(" - Accuracy: ", acc) 78 | else: 79 | test_preds = results.view(-1, 2).cpu().detach().numpy() 80 | test_truth = truths.view(-1).cpu().detach().numpy() 81 | 82 | print(f"{emos[single]}: ") 83 | test_preds_i = np.argmax(test_preds,axis=1) 84 | test_truth_i = test_truth 85 | f1 = f1_score(test_truth_i, test_preds_i, average='weighted') 86 | acc = accuracy_score(test_truth_i, test_preds_i) 87 | print(" - F1 Score: ", f1) 88 | print(" - Accuracy: ", acc) 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from modules.transformer import TransformerEncoder 6 | 7 | 8 | class MULTModel(nn.Module): 9 | 
def __init__(self, hyp_params): 10 | """ 11 | Construct a MulT model. 12 | """ 13 | super(MULTModel, self).__init__() 14 | self.orig_d_l, self.orig_d_a, self.orig_d_v = hyp_params.orig_d_l, hyp_params.orig_d_a, hyp_params.orig_d_v 15 | self.d_l, self.d_a, self.d_v = 30, 30, 30 16 | self.vonly = hyp_params.vonly 17 | self.aonly = hyp_params.aonly 18 | self.lonly = hyp_params.lonly 19 | self.num_heads = hyp_params.num_heads 20 | self.layers = hyp_params.layers 21 | self.attn_dropout = hyp_params.attn_dropout 22 | self.attn_dropout_a = hyp_params.attn_dropout_a 23 | self.attn_dropout_v = hyp_params.attn_dropout_v 24 | self.relu_dropout = hyp_params.relu_dropout 25 | self.res_dropout = hyp_params.res_dropout 26 | self.out_dropout = hyp_params.out_dropout 27 | self.embed_dropout = hyp_params.embed_dropout 28 | self.attn_mask = hyp_params.attn_mask 29 | 30 | combined_dim = self.d_l + self.d_a + self.d_v 31 | 32 | self.partial_mode = self.lonly + self.aonly + self.vonly 33 | if self.partial_mode == 1: 34 | combined_dim = 2 * self.d_l # assuming d_l == d_a == d_v 35 | else: 36 | combined_dim = 2 * (self.d_l + self.d_a + self.d_v) 37 | 38 | output_dim = hyp_params.output_dim # This is actually not a hyperparameter :-) 39 | 40 | # 1. Temporal convolutional layers 41 | self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=1, padding=0, bias=False) 42 | self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=1, padding=0, bias=False) 43 | self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=1, padding=0, bias=False) 44 | 45 | # 2. Crossmodal Attentions 46 | if self.lonly: 47 | self.trans_l_with_a = self.get_network(self_type='la') 48 | self.trans_l_with_v = self.get_network(self_type='lv') 49 | if self.aonly: 50 | self.trans_a_with_l = self.get_network(self_type='al') 51 | self.trans_a_with_v = self.get_network(self_type='av') 52 | if self.vonly: 53 | self.trans_v_with_l = self.get_network(self_type='vl') 54 | self.trans_v_with_a = self.get_network(self_type='va') 55 | 56 | # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.) 
57 | # [e.g., self.trans_x_mem = nn.LSTM(self.d_x, self.d_x, 1) 58 | self.trans_l_mem = self.get_network(self_type='l_mem', layers=3) 59 | self.trans_a_mem = self.get_network(self_type='a_mem', layers=3) 60 | self.trans_v_mem = self.get_network(self_type='v_mem', layers=3) 61 | 62 | # Projection layers 63 | self.proj1 = nn.Linear(combined_dim, combined_dim) 64 | self.proj2 = nn.Linear(combined_dim, combined_dim) 65 | self.out_layer = nn.Linear(combined_dim, output_dim) 66 | 67 | def get_network(self, self_type='l', layers=-1): 68 | if self_type in ['l', 'al', 'vl']: 69 | embed_dim, attn_dropout = self.d_l, self.attn_dropout 70 | elif self_type in ['a', 'la', 'va']: 71 | embed_dim, attn_dropout = self.d_a, self.attn_dropout_a 72 | elif self_type in ['v', 'lv', 'av']: 73 | embed_dim, attn_dropout = self.d_v, self.attn_dropout_v 74 | elif self_type == 'l_mem': 75 | embed_dim, attn_dropout = 2*self.d_l, self.attn_dropout 76 | elif self_type == 'a_mem': 77 | embed_dim, attn_dropout = 2*self.d_a, self.attn_dropout 78 | elif self_type == 'v_mem': 79 | embed_dim, attn_dropout = 2*self.d_v, self.attn_dropout 80 | else: 81 | raise ValueError("Unknown network type") 82 | 83 | return TransformerEncoder(embed_dim=embed_dim, 84 | num_heads=self.num_heads, 85 | layers=max(self.layers, layers), 86 | attn_dropout=attn_dropout, 87 | relu_dropout=self.relu_dropout, 88 | res_dropout=self.res_dropout, 89 | embed_dropout=self.embed_dropout, 90 | attn_mask=self.attn_mask) 91 | 92 | def forward(self, x_l, x_a, x_v): 93 | """ 94 | text, audio, and vision should have dimension [batch_size, seq_len, n_features] 95 | """ 96 | x_l = F.dropout(x_l.transpose(1, 2), p=self.embed_dropout, training=self.training) 97 | x_a = x_a.transpose(1, 2) 98 | x_v = x_v.transpose(1, 2) 99 | 100 | # Project the textual/visual/audio features 101 | proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l) 102 | proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a) 103 | proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v) 104 | proj_x_a = proj_x_a.permute(2, 0, 1) 105 | proj_x_v = proj_x_v.permute(2, 0, 1) 106 | proj_x_l = proj_x_l.permute(2, 0, 1) 107 | 108 | if self.lonly: 109 | # (V,A) --> L 110 | h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a) # Dimension (L, N, d_l) 111 | h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v) # Dimension (L, N, d_l) 112 | h_ls = torch.cat([h_l_with_as, h_l_with_vs], dim=2) 113 | h_ls = self.trans_l_mem(h_ls) 114 | if type(h_ls) == tuple: 115 | h_ls = h_ls[0] 116 | last_h_l = last_hs = h_ls[-1] # Take the last output for prediction 117 | 118 | if self.aonly: 119 | # (L,V) --> A 120 | h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l) 121 | h_a_with_vs = self.trans_a_with_v(proj_x_a, proj_x_v, proj_x_v) 122 | h_as = torch.cat([h_a_with_ls, h_a_with_vs], dim=2) 123 | h_as = self.trans_a_mem(h_as) 124 | if type(h_as) == tuple: 125 | h_as = h_as[0] 126 | last_h_a = last_hs = h_as[-1] 127 | 128 | if self.vonly: 129 | # (L,A) --> V 130 | h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l) 131 | h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a) 132 | h_vs = torch.cat([h_v_with_ls, h_v_with_as], dim=2) 133 | h_vs = self.trans_v_mem(h_vs) 134 | if type(h_vs) == tuple: 135 | h_vs = h_vs[0] 136 | last_h_v = last_hs = h_vs[-1] 137 | 138 | if self.partial_mode == 3: 139 | last_hs = torch.cat([last_h_l, last_h_a, last_h_v], dim=1) 140 | 141 | # A residual block 142 | last_hs_proj = 
self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training)) 143 | last_hs_proj += last_hs 144 | 145 | output = self.out_layer(last_hs_proj) 146 | return output, last_hs -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import sys 4 | from src import models 5 | from src import ctc 6 | from src.utils import * 7 | import torch.optim as optim 8 | import numpy as np 9 | import time 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | import os 12 | import pickle 13 | 14 | from sklearn.metrics import classification_report 15 | from sklearn.metrics import confusion_matrix 16 | from sklearn.metrics import precision_recall_fscore_support 17 | from sklearn.metrics import accuracy_score, f1_score 18 | from src.eval_metrics import * 19 | 20 | 21 | #################################################################### 22 | # 23 | # Construct the model and the CTC module (which may not be needed) 24 | # 25 | #################################################################### 26 | 27 | def get_CTC_module(hyp_params): 28 | a2l_module = getattr(ctc, 'CTCModule')(in_dim=hyp_params.orig_d_a, out_seq_len=hyp_params.l_len) 29 | v2l_module = getattr(ctc, 'CTCModule')(in_dim=hyp_params.orig_d_v, out_seq_len=hyp_params.l_len) 30 | return a2l_module, v2l_module 31 | 32 | def initiate(hyp_params, train_loader, valid_loader, test_loader): 33 | model = getattr(models, hyp_params.model+'Model')(hyp_params) 34 | 35 | if hyp_params.use_cuda: 36 | model = model.cuda() 37 | 38 | optimizer = getattr(optim, hyp_params.optim)(model.parameters(), lr=hyp_params.lr) 39 | criterion = getattr(nn, hyp_params.criterion)() 40 | if hyp_params.aligned or hyp_params.model=='MULT': 41 | ctc_criterion = None 42 | ctc_a2l_module, ctc_v2l_module = None, None 43 | ctc_a2l_optimizer, ctc_v2l_optimizer = None, None 44 | else: 45 | from warpctc_pytorch import CTCLoss 46 | ctc_criterion = CTCLoss() 47 | ctc_a2l_module, ctc_v2l_module = get_CTC_module(hyp_params) 48 | if hyp_params.use_cuda: 49 | ctc_a2l_module, ctc_v2l_module = ctc_a2l_module.cuda(), ctc_v2l_module.cuda() 50 | ctc_a2l_optimizer = getattr(optim, hyp_params.optim)(ctc_a2l_module.parameters(), lr=hyp_params.lr) 51 | ctc_v2l_optimizer = getattr(optim, hyp_params.optim)(ctc_v2l_module.parameters(), lr=hyp_params.lr) 52 | 53 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=hyp_params.when, factor=0.1, verbose=True) 54 | settings = {'model': model, 55 | 'optimizer': optimizer, 56 | 'criterion': criterion, 57 | 'ctc_a2l_module': ctc_a2l_module, 58 | 'ctc_v2l_module': ctc_v2l_module, 59 | 'ctc_a2l_optimizer': ctc_a2l_optimizer, 60 | 'ctc_v2l_optimizer': ctc_v2l_optimizer, 61 | 'ctc_criterion': ctc_criterion, 62 | 'scheduler': scheduler} 63 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader) 64 | 65 | 66 | #################################################################### 67 | # 68 | # Training and evaluation scripts 69 | # 70 | #################################################################### 71 | 72 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader): 73 | model = settings['model'] 74 | optimizer = settings['optimizer'] 75 | criterion = settings['criterion'] 76 | 77 | ctc_a2l_module = settings['ctc_a2l_module'] 78 | ctc_v2l_module = settings['ctc_v2l_module'] 79 | ctc_a2l_optimizer = 
settings['ctc_a2l_optimizer'] 80 | ctc_v2l_optimizer = settings['ctc_v2l_optimizer'] 81 | ctc_criterion = settings['ctc_criterion'] 82 | 83 | scheduler = settings['scheduler'] 84 | 85 | 86 | def train(model, optimizer, criterion, ctc_a2l_module, ctc_v2l_module, ctc_a2l_optimizer, ctc_v2l_optimizer, ctc_criterion): 87 | epoch_loss = 0 88 | model.train() 89 | num_batches = hyp_params.n_train // hyp_params.batch_size 90 | proc_loss, proc_size = 0, 0 91 | start_time = time.time() 92 | for i_batch, (batch_X, batch_Y, batch_META) in enumerate(train_loader): 93 | sample_ind, text, audio, vision = batch_X 94 | eval_attr = batch_Y.squeeze(-1) # if num of labels is 1 95 | 96 | model.zero_grad() 97 | if ctc_criterion is not None: 98 | ctc_a2l_module.zero_grad() 99 | ctc_v2l_module.zero_grad() 100 | 101 | if hyp_params.use_cuda: 102 | with torch.cuda.device(0): 103 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 104 | if hyp_params.dataset == 'iemocap': 105 | eval_attr = eval_attr.long() 106 | 107 | batch_size = text.size(0) 108 | batch_chunk = hyp_params.batch_chunk 109 | 110 | ######## CTC STARTS ######## Do not worry about this if not working on CTC 111 | if ctc_criterion is not None: 112 | ctc_a2l_net = nn.DataParallel(ctc_a2l_module) if batch_size > 10 else ctc_a2l_module 113 | ctc_v2l_net = nn.DataParallel(ctc_v2l_module) if batch_size > 10 else ctc_v2l_module 114 | 115 | audio, a2l_position = ctc_a2l_net(audio) # audio now is the aligned to text 116 | vision, v2l_position = ctc_v2l_net(vision) 117 | 118 | ## Compute the ctc loss 119 | l_len, a_len, v_len = hyp_params.l_len, hyp_params.a_len, hyp_params.v_len 120 | # Output Labels 121 | l_position = torch.tensor([i+1 for i in range(l_len)]*batch_size).int().cpu() 122 | # Specifying each output length 123 | l_length = torch.tensor([l_len]*batch_size).int().cpu() 124 | # Specifying each input length 125 | a_length = torch.tensor([a_len]*batch_size).int().cpu() 126 | v_length = torch.tensor([v_len]*batch_size).int().cpu() 127 | 128 | ctc_a2l_loss = ctc_criterion(a2l_position.transpose(0,1).cpu(), l_position, a_length, l_length) 129 | ctc_v2l_loss = ctc_criterion(v2l_position.transpose(0,1).cpu(), l_position, v_length, l_length) 130 | ctc_loss = ctc_a2l_loss + ctc_v2l_loss 131 | ctc_loss = ctc_loss.cuda() if hyp_params.use_cuda else ctc_loss 132 | else: 133 | ctc_loss = 0 134 | ######## CTC ENDS ######## 135 | 136 | combined_loss = 0 137 | net = nn.DataParallel(model) if batch_size > 10 else model 138 | if batch_chunk > 1: 139 | raw_loss = combined_loss = 0 140 | text_chunks = text.chunk(batch_chunk, dim=0) 141 | audio_chunks = audio.chunk(batch_chunk, dim=0) 142 | vision_chunks = vision.chunk(batch_chunk, dim=0) 143 | eval_attr_chunks = eval_attr.chunk(batch_chunk, dim=0) 144 | 145 | for i in range(batch_chunk): 146 | text_i, audio_i, vision_i = text_chunks[i], audio_chunks[i], vision_chunks[i] 147 | eval_attr_i = eval_attr_chunks[i] 148 | preds_i, hiddens_i = net(text_i, audio_i, vision_i) 149 | 150 | if hyp_params.dataset == 'iemocap': 151 | preds_i = preds_i.view(-1, 2) 152 | eval_attr_i = eval_attr_i.view(-1) 153 | raw_loss_i = criterion(preds_i, eval_attr_i) / batch_chunk 154 | raw_loss += raw_loss_i 155 | raw_loss_i.backward() 156 | ctc_loss.backward() 157 | combined_loss = raw_loss + ctc_loss 158 | else: 159 | preds, hiddens = net(text, audio, vision) 160 | if hyp_params.dataset == 'iemocap': 161 | preds = preds.view(-1, 2) 162 | eval_attr = eval_attr.view(-1) 163 | raw_loss = criterion(preds, 
eval_attr) 164 | combined_loss = raw_loss + ctc_loss 165 | combined_loss.backward() 166 | 167 | if ctc_criterion is not None: 168 | torch.nn.utils.clip_grad_norm_(ctc_a2l_module.parameters(), hyp_params.clip) 169 | torch.nn.utils.clip_grad_norm_(ctc_v2l_module.parameters(), hyp_params.clip) 170 | ctc_a2l_optimizer.step() 171 | ctc_v2l_optimizer.step() 172 | 173 | torch.nn.utils.clip_grad_norm_(model.parameters(), hyp_params.clip) 174 | optimizer.step() 175 | 176 | proc_loss += raw_loss.item() * batch_size 177 | proc_size += batch_size 178 | epoch_loss += combined_loss.item() * batch_size 179 | if i_batch % hyp_params.log_interval == 0 and i_batch > 0: 180 | avg_loss = proc_loss / proc_size 181 | elapsed_time = time.time() - start_time 182 | print('Epoch {:2d} | Batch {:3d}/{:3d} | Time/Batch(ms) {:5.2f} | Train Loss {:5.4f}'. 183 | format(epoch, i_batch, num_batches, elapsed_time * 1000 / hyp_params.log_interval, avg_loss)) 184 | proc_loss, proc_size = 0, 0 185 | start_time = time.time() 186 | 187 | return epoch_loss / hyp_params.n_train 188 | 189 | def evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=False): 190 | model.eval() 191 | loader = test_loader if test else valid_loader 192 | total_loss = 0.0 193 | 194 | results = [] 195 | truths = [] 196 | 197 | with torch.no_grad(): 198 | for i_batch, (batch_X, batch_Y, batch_META) in enumerate(loader): 199 | sample_ind, text, audio, vision = batch_X 200 | eval_attr = batch_Y.squeeze(dim=-1) # if num of labels is 1 201 | 202 | if hyp_params.use_cuda: 203 | with torch.cuda.device(0): 204 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 205 | if hyp_params.dataset == 'iemocap': 206 | eval_attr = eval_attr.long() 207 | 208 | batch_size = text.size(0) 209 | 210 | if (ctc_a2l_module is not None) and (ctc_v2l_module is not None): 211 | ctc_a2l_net = nn.DataParallel(ctc_a2l_module) if batch_size > 10 else ctc_a2l_module 212 | ctc_v2l_net = nn.DataParallel(ctc_v2l_module) if batch_size > 10 else ctc_v2l_module 213 | audio, _ = ctc_a2l_net(audio) # audio aligned to text 214 | vision, _ = ctc_v2l_net(vision) # vision aligned to text 215 | 216 | net = nn.DataParallel(model) if batch_size > 10 else model 217 | preds, _ = net(text, audio, vision) 218 | if hyp_params.dataset == 'iemocap': 219 | preds = preds.view(-1, 2) 220 | eval_attr = eval_attr.view(-1) 221 | total_loss += criterion(preds, eval_attr).item() * batch_size 222 | 223 | # Collect the results into dictionary 224 | results.append(preds) 225 | truths.append(eval_attr) 226 | 227 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid) 228 | 229 | results = torch.cat(results) 230 | truths = torch.cat(truths) 231 | return avg_loss, results, truths 232 | 233 | best_valid = 1e8 234 | for epoch in range(1, hyp_params.num_epochs+1): 235 | start = time.time() 236 | train(model, optimizer, criterion, ctc_a2l_module, ctc_v2l_module, ctc_a2l_optimizer, ctc_v2l_optimizer, ctc_criterion) 237 | val_loss, _, _ = evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=False) 238 | test_loss, _, _ = evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=True) 239 | 240 | end = time.time() 241 | duration = end-start 242 | scheduler.step(val_loss) # Decay learning rate by validation loss 243 | 244 | print("-"*50) 245 | print('Epoch {:2d} | Time {:5.4f} sec | Valid Loss {:5.4f} | Test Loss {:5.4f}'.format(epoch, duration, val_loss, test_loss)) 246 | print("-"*50) 247 | 248 | if val_loss < best_valid: 249 | print(f"Saved 
model at pre_trained_models/{hyp_params.name}.pt!") 250 | save_model(hyp_params, model, name=hyp_params.name) 251 | best_valid = val_loss 252 | 253 | model = load_model(hyp_params, name=hyp_params.name) 254 | _, results, truths = evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=True) 255 | 256 | if hyp_params.dataset == "mosei_senti": 257 | eval_mosei_senti(results, truths, True) 258 | elif hyp_params.dataset == 'mosi': 259 | eval_mosi(results, truths, True) 260 | elif hyp_params.dataset == 'iemocap': 261 | eval_iemocap(results, truths) 262 | 263 | sys.stdout.flush() 264 | input('[Press Any Key to start another run]') 265 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from src.dataset import Multimodal_Datasets 4 | 5 | 6 | def get_data(args, dataset, split='train'): 7 | alignment = 'a' if args.aligned else 'na' 8 | data_path = os.path.join(args.data_path, dataset) + f'_{split}_{alignment}.dt' 9 | if not os.path.exists(data_path): 10 | print(f" - Creating new {split} data") 11 | data = Multimodal_Datasets(args.data_path, dataset, split, args.aligned) 12 | torch.save(data, data_path) 13 | else: 14 | print(f" - Found cached {split} data") 15 | data = torch.load(data_path) 16 | return data 17 | 18 | 19 | def save_load_name(args, name=''): 20 | if args.aligned: 21 | name = name if len(name) > 0 else 'aligned_model' 22 | elif not args.aligned: 23 | name = name if len(name) > 0 else 'nonaligned_model' 24 | 25 | return name + '_' + args.model 26 | 27 | 28 | def save_model(args, model, name=''): 29 | name = save_load_name(args, name) 30 | torch.save(model, f'pre_trained_models/{name}.pt') 31 | 32 | 33 | def load_model(args, name=''): 34 | name = save_load_name(args, name) 35 | model = torch.load(f'pre_trained_models/{name}.pt') 36 | return model 37 | --------------------------------------------------------------------------------
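# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): reloading a trained
# checkpoint with the helpers in src/utils.py above.  `args` only needs the
# `aligned` and `model` attributes read by save_load_name(); with main.py's
# defaults the trial name is 'mult' and the model string is 'MULT', so the
# file loaded is pre_trained_models/mult_MULT.pt.
#
#     from argparse import Namespace
#     from src.utils import load_model
#
#     args = Namespace(aligned=False, model='MULT')
#     model = load_model(args, name='mult')
#     model.eval()
# ---------------------------------------------------------------------------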