├── LICENSE
├── README.md
├── imgs
│   ├── architecture.png
│   └── cm.png
├── main.py
├── modules
│   ├── multihead_attention.py
│   ├── position_embedding.py
│   └── transformer.py
└── src
    ├── README.md
    ├── ctc.py
    ├── dataset.py
    ├── eval_metrics.py
    ├── models.py
    ├── train.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yao-Hung Hubert Tsai and Shaojie Bai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Multimodal Transformer for Unaligned Multimodal Language Sequences
4 |
5 | > PyTorch implementation of the Multimodal Transformer for unaligned multimodal language sequences.
6 |
7 | Correspondence to:
8 | - Yao-Hung Hubert Tsai (yaohungt@cs.cmu.edu)
9 | - Shaojie Bai (shaojieb@andrew.cmu.edu)
10 |
11 | ## Paper
12 | [**Multimodal Transformer for Unaligned Multimodal Language Sequences**](https://arxiv.org/pdf/1906.00295.pdf)
13 | [Yao-Hung Hubert Tsai](https://yaohungt.github.io) *, [Shaojie Bai](https://jerrybai1995.github.io) *, [Paul Pu Liang](http://www.cs.cmu.edu/~pliang/), [J. Zico Kolter](http://zicokolter.com), [Louis-Philippe Morency](https://www.cs.cmu.edu/~morency/), and [Ruslan Salakhutdinov](https://www.cs.cmu.edu/~rsalakhu/)
14 | Association for Computational Linguistics (ACL), 2019. (*equal contribution)
15 |
16 | Please cite our paper if you find our work useful for your research:
17 |
18 | ```tex
19 | @inproceedings{tsai2019MULT,
20 | title={Multimodal Transformer for Unaligned Multimodal Language Sequences},
21 | author={Tsai, Yao-Hung Hubert and Bai, Shaojie and Liang, Paul Pu and Kolter, J. Zico and Morency, Louis-Philippe and Salakhutdinov, Ruslan},
22 | booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
23 | month = {7},
24 | year={2019},
25 | address = {Florence, Italy},
26 | publisher = {Association for Computational Linguistics},
27 | }
28 | ```
29 |
30 | ## Overview
31 |
32 | ### Overall Architecture for Multimodal Transformer
33 |
34 | ![Overall architecture of the Multimodal Transformer (MulT)](imgs/architecture.png)
35 |
36 | Multimodal Transformer (MulT) merges multimodal time-series via a feed-forward fusion process built from multiple directional pairwise crossmodal transformers. Specifically, each crossmodal transformer serves to repeatedly reinforce a *target modality* with the low-level features from another *source modality* by learning the attention across the two modalities' features. A MulT architecture hence models all pairs of modalities with such crossmodal transformers, followed by sequence models (e.g., a self-attention transformer) that predict using the fused features.
37 |
38 |
39 | ### Crossmodal Attention for Two Sequences from Distinct Modalities
40 |
41 | ![Crossmodal attention between two sequences from distinct modalities](imgs/cm.png)
42 |
43 | The core of our proposed model is the crossmodal transformer, which is built around the crossmodal attention module.
44 |
45 | ## Usage
46 |
47 | ### Prerequisites
48 | - Python 3.6/3.7
49 | - [PyTorch (>=1.0.0) and torchvision](https://pytorch.org/)
50 | - CUDA 10.0 or above
51 |
52 | ### Datasets
53 |
54 | Data files (containing processed MOSI, MOSEI and IEMOCAP datasets) can be downloaded from [here](https://www.dropbox.com/sh/hyzpgx1hp9nj37s/AAB7FhBqJOFDw2hEyvv2ZXHxa?dl=0).
55 |
56 | I personally used the command line to download everything:
57 | ~~~~
58 | wget https://www.dropbox.com/sh/hyzpgx1hp9nj37s/AADfY2s7gD_MkR76m03KS0K1a/Archive.zip?dl=1
59 | mv 'Archive.zip?dl=1' Archive.zip
60 | unzip Archive.zip
61 | ~~~~
62 |
63 | To retrieve the meta information and the raw data, please refer to the [SDK for these datasets](https://github.com/A2Zadeh/CMU-MultimodalSDK).
64 |
65 | ### Run the Code
66 |
67 | 1. Create (empty) folders for data and pre-trained models:
68 | ~~~~
69 | mkdir data pre_trained_models
70 | ~~~~
71 |
72 | and put the downloaded data in 'data/'.
73 |
74 | 2. Run the code with the following command:
75 | ~~~~
76 | python main.py [--FLAGS]
77 | ~~~~
78 |
79 | Note that the default arguments are for the unaligned version of MOSEI. For other datasets, please refer to the supplementary material of the paper.
80 |
81 | ### If Using CTC
82 |
83 | The Multimodal Transformer itself requires no CTC module. However, as we describe in the paper, the CTC module offers an alternative for applying other kinds of sequence models (e.g., recurrent architectures) to unaligned multimodal streams.
84 |
85 | If you want to use the CTC module, please install warp-ctc from [here](https://github.com/baidu-research/warp-ctc).
86 |
87 | The quick version:
88 | ~~~~
89 | git clone https://github.com/SeanNaren/warp-ctc.git
90 | cd warp-ctc
91 | mkdir build; cd build
92 | cmake ..
93 | make
94 | cd ../pytorch_binding
95 | python setup.py install
96 | export WARP_CTC_PATH=/home/xxx/warp-ctc/build
97 | ~~~~
98 |
99 | ### Acknowledgement
100 | Portions of the code were adapted from the [fairseq](https://github.com/pytorch/fairseq) repo.
101 |
102 |
103 |
--------------------------------------------------------------------------------
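An illustrative sketch of the crossmodal transformer described in the Overview above: a language sequence (target) is reinforced with audio features (source), mirroring how `src/models.py` builds `trans_l_with_a`. This is an editorial addition, not part of the repo; the widths and sequence lengths are assumptions, and running it on CPU assumes a PyTorch version whose `Tensor.get_device()` returns -1 for CPU tensors (the repo itself defaults to CUDA tensors).

```python
# One directional crossmodal block (audio -> language), built with the
# TransformerEncoder from modules/transformer.py exactly as src/models.py does.
import torch
from modules.transformer import TransformerEncoder

d, batch = 30, 2                      # MulT projects every modality to d=30
l_len, a_len = 50, 375                # unaligned sequence lengths (assumed)

x_l = torch.rand(l_len, batch, d)     # target modality: language (queries)
x_a = torch.rand(a_len, batch, d)     # source modality: audio (keys/values)

trans_l_with_a = TransformerEncoder(embed_dim=d, num_heads=5, layers=4,
                                    attn_dropout=0.1, relu_dropout=0.1,
                                    res_dropout=0.1, embed_dropout=0.25,
                                    attn_mask=False)
h_l = trans_l_with_a(x_l, x_a, x_a)   # language repeatedly attends to audio
print(h_l.shape)                      # torch.Size([50, 2, 30])
```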
/imgs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaohungt/Multimodal-Transformer/a670936824ee722c8494fd98d204977a1d663c7a/imgs/architecture.png
--------------------------------------------------------------------------------
/imgs/cm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaohungt/Multimodal-Transformer/a670936824ee722c8494fd98d204977a1d663c7a/imgs/cm.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from src.utils import *
4 | from torch.utils.data import DataLoader
5 | from src import train
6 |
7 |
8 | parser = argparse.ArgumentParser(description='MOSEI Sentiment Analysis')
9 | parser.add_argument('-f', default='', type=str)
10 |
11 | # Fixed
12 | parser.add_argument('--model', type=str, default='MulT',
13 | help='name of the model to use (Transformer, etc.)')
14 |
15 | # Tasks
16 | parser.add_argument('--vonly', action='store_true',
17 | help='use the crossmodal fusion into v (default: False)')
18 | parser.add_argument('--aonly', action='store_true',
19 | help='use the crossmodal fusion into a (default: False)')
20 | parser.add_argument('--lonly', action='store_true',
21 | help='use the crossmodal fusion into l (default: False)')
22 | parser.add_argument('--aligned', action='store_true',
23 | help='consider aligned experiment or not (default: False)')
24 | parser.add_argument('--dataset', type=str, default='mosei_senti',
25 | help='dataset to use (default: mosei_senti)')
26 | parser.add_argument('--data_path', type=str, default='data',
27 | help='path for storing the dataset')
28 |
29 | # Dropouts
30 | parser.add_argument('--attn_dropout', type=float, default=0.1,
31 | help='attention dropout')
32 | parser.add_argument('--attn_dropout_a', type=float, default=0.0,
33 | help='attention dropout (for audio)')
34 | parser.add_argument('--attn_dropout_v', type=float, default=0.0,
35 | help='attention dropout (for visual)')
36 | parser.add_argument('--relu_dropout', type=float, default=0.1,
37 | help='relu dropout')
38 | parser.add_argument('--embed_dropout', type=float, default=0.25,
39 | help='embedding dropout')
40 | parser.add_argument('--res_dropout', type=float, default=0.1,
41 | help='residual block dropout')
42 | parser.add_argument('--out_dropout', type=float, default=0.0,
43 | help='output layer dropout')
44 |
45 | # Architecture
46 | parser.add_argument('--nlevels', type=int, default=5,
47 | help='number of layers in the network (default: 5)')
48 | parser.add_argument('--num_heads', type=int, default=5,
49 | help='number of heads for the transformer network (default: 5)')
50 | parser.add_argument('--attn_mask', action='store_false',
51 | help='use attention mask for Transformer (default: true)')
52 |
53 | # Tuning
54 | parser.add_argument('--batch_size', type=int, default=24, metavar='N',
55 | help='batch size (default: 24)')
56 | parser.add_argument('--clip', type=float, default=0.8,
57 | help='gradient clip value (default: 0.8)')
58 | parser.add_argument('--lr', type=float, default=1e-3,
59 | help='initial learning rate (default: 1e-3)')
60 | parser.add_argument('--optim', type=str, default='Adam',
61 | help='optimizer to use (default: Adam)')
62 | parser.add_argument('--num_epochs', type=int, default=40,
63 | help='number of epochs (default: 40)')
64 | parser.add_argument('--when', type=int, default=20,
65 | help='when to decay learning rate (default: 20)')
66 | parser.add_argument('--batch_chunk', type=int, default=1,
67 | help='number of chunks per batch (default: 1)')
68 |
69 | # Logistics
70 | parser.add_argument('--log_interval', type=int, default=30,
71 | help='frequency of result logging (default: 30)')
72 | parser.add_argument('--seed', type=int, default=1111,
73 | help='random seed')
74 | parser.add_argument('--no_cuda', action='store_true',
75 | help='do not use cuda')
76 | parser.add_argument('--name', type=str, default='mult',
77 | help='name of the trial (default: "mult")')
78 | args = parser.parse_args()
79 |
80 | torch.manual_seed(args.seed)
81 | dataset = str.lower(args.dataset.strip())
82 | valid_partial_mode = args.lonly + args.vonly + args.aonly
83 |
84 | if valid_partial_mode == 0:
85 | args.lonly = args.vonly = args.aonly = True
86 | elif valid_partial_mode != 1:
87 | raise ValueError("You can only choose one of {l/v/a}only.")
88 |
89 | use_cuda = False
90 |
91 | output_dim_dict = {
92 | 'mosi': 1,
93 | 'mosei_senti': 1,
94 | 'iemocap': 8
95 | }
96 |
97 | criterion_dict = {
98 | 'iemocap': 'CrossEntropyLoss'
99 | }
100 |
101 | torch.set_default_tensor_type('torch.FloatTensor')
102 | if torch.cuda.is_available():
103 | if args.no_cuda:
104 | print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda")
105 | else:
106 | torch.cuda.manual_seed(args.seed)
107 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
108 | use_cuda = True
109 |
110 | ####################################################################
111 | #
112 | # Load the dataset (aligned or non-aligned)
113 | #
114 | ####################################################################
115 |
116 | print("Start loading the data....")
117 |
118 | train_data = get_data(args, dataset, 'train')
119 | valid_data = get_data(args, dataset, 'valid')
120 | test_data = get_data(args, dataset, 'test')
121 |
122 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
123 | valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=True)
124 | test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
125 |
126 | print('Finish loading the data....')
127 | if not args.aligned:
128 | print("### Note: You are running in unaligned mode.")
129 |
130 | ####################################################################
131 | #
132 | # Hyperparameters
133 | #
134 | ####################################################################
135 |
136 | hyp_params = args
137 | hyp_params.orig_d_l, hyp_params.orig_d_a, hyp_params.orig_d_v = train_data.get_dim()
138 | hyp_params.l_len, hyp_params.a_len, hyp_params.v_len = train_data.get_seq_len()
139 | hyp_params.layers = args.nlevels
140 | hyp_params.use_cuda = use_cuda
141 | hyp_params.dataset = dataset
142 | hyp_params.when = args.when
143 | hyp_params.batch_chunk = args.batch_chunk
144 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_data), len(valid_data), len(test_data)
145 | hyp_params.model = str.upper(args.model.strip())
146 | hyp_params.output_dim = output_dim_dict.get(dataset, 1)
147 | hyp_params.criterion = criterion_dict.get(dataset, 'L1Loss')
148 |
149 |
150 | if __name__ == '__main__':
151 | test_loss = train.initiate(hyp_params, train_loader, valid_loader, test_loader)
152 |
153 |
--------------------------------------------------------------------------------
/modules/multihead_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 | import torch.nn.functional as F
5 | import sys
6 |
7 | # Code adapted from the fairseq repo.
8 |
9 | class MultiheadAttention(nn.Module):
10 | """Multi-headed attention.
11 | See "Attention Is All You Need" for more details.
12 | """
13 |
14 | def __init__(self, embed_dim, num_heads, attn_dropout=0.,
15 | bias=True, add_bias_kv=False, add_zero_attn=False):
16 | super().__init__()
17 | self.embed_dim = embed_dim
18 | self.num_heads = num_heads
19 | self.attn_dropout = attn_dropout
20 | self.head_dim = embed_dim // num_heads
21 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
22 | self.scaling = self.head_dim ** -0.5
23 |
24 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
25 | self.register_parameter('in_proj_bias', None)
26 | if bias:
27 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
28 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
29 |
30 | if add_bias_kv:
31 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
32 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
33 | else:
34 | self.bias_k = self.bias_v = None
35 |
36 | self.add_zero_attn = add_zero_attn
37 |
38 | self.reset_parameters()
39 |
40 | def reset_parameters(self):
41 | nn.init.xavier_uniform_(self.in_proj_weight)
42 | nn.init.xavier_uniform_(self.out_proj.weight)
43 | if self.in_proj_bias is not None:
44 | nn.init.constant_(self.in_proj_bias, 0.)
45 | nn.init.constant_(self.out_proj.bias, 0.)
46 | if self.bias_k is not None:
47 | nn.init.xavier_normal_(self.bias_k)
48 | if self.bias_v is not None:
49 | nn.init.xavier_normal_(self.bias_v)
50 |
51 | def forward(self, query, key, value, attn_mask=None):
52 | """Input shape: Time x Batch x Channel
53 | Self-attention can be implemented by passing in the same arguments for
54 | query, key and value. Timesteps can be masked by supplying a T x T mask in the
55 | `attn_mask` argument. Padding elements can be excluded from
56 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
57 | batch x src_len, where padding elements are indicated by 1s.
58 | """
59 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
60 | kv_same = key.data_ptr() == value.data_ptr()
61 |
62 | tgt_len, bsz, embed_dim = query.size()
63 | assert embed_dim == self.embed_dim
64 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
65 | assert key.size() == value.size()
66 |
67 | saved_state = None  # unused; retained from the original fairseq implementation
68 |
69 | if qkv_same:
70 | # self-attention
71 | q, k, v = self.in_proj_qkv(query)
72 | elif kv_same:
73 | # encoder-decoder attention
74 | q = self.in_proj_q(query)
75 |
76 | if key is None:
77 | assert value is None
78 | k = v = None
79 | else:
80 | k, v = self.in_proj_kv(key)
81 | else:
82 | q = self.in_proj_q(query)
83 | k = self.in_proj_k(key)
84 | v = self.in_proj_v(value)
85 | q = q * self.scaling
86 |
87 | if self.bias_k is not None:
88 | assert self.bias_v is not None
89 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
90 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
91 | if attn_mask is not None:
92 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
93 |
94 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
95 | if k is not None:
96 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
97 | if v is not None:
98 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
99 |
100 | src_len = k.size(1)
101 |
102 | if self.add_zero_attn:
103 | src_len += 1
104 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
105 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
106 | if attn_mask is not None:
107 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
108 |
109 | attn_weights = torch.bmm(q, k.transpose(1, 2))
110 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
111 |
112 | if attn_mask is not None:
113 | try:
114 | attn_weights += attn_mask.unsqueeze(0)
115 | except RuntimeError:
116 | print(attn_weights.shape)
117 | print(attn_mask.unsqueeze(0).shape)
118 | assert False
119 |
120 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
121 | # attn_weights = F.relu(attn_weights)
122 | # attn_weights = attn_weights / torch.max(attn_weights)
123 | attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)
124 |
125 | attn = torch.bmm(attn_weights, v)
126 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
127 |
128 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
129 | attn = self.out_proj(attn)
130 |
131 | # average attention weights over heads
132 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
133 | attn_weights = attn_weights.sum(dim=1) / self.num_heads
134 | return attn, attn_weights
135 |
136 | def in_proj_qkv(self, query):
137 | return self._in_proj(query).chunk(3, dim=-1)
138 |
139 | def in_proj_kv(self, key):
140 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
141 |
142 | def in_proj_q(self, query, **kwargs):
143 | return self._in_proj(query, end=self.embed_dim, **kwargs)
144 |
145 | def in_proj_k(self, key):
146 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
147 |
148 | def in_proj_v(self, value):
149 | return self._in_proj(value, start=2 * self.embed_dim)
150 |
151 | def _in_proj(self, input, start=0, end=None, **kwargs):
152 | weight = kwargs.get('weight', self.in_proj_weight)
153 | bias = kwargs.get('bias', self.in_proj_bias)
154 | weight = weight[start:end, :]
155 | if bias is not None:
156 | bias = bias[start:end]
157 | return F.linear(input, weight, bias)
158 |
--------------------------------------------------------------------------------
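A small usage sketch of the module above (an editorial addition; sizes are assumed): a short query sequence cross-attends over a longer key/value sequence, following the documented Time x Batch x Channel layout.

```python
# Cross-attention with MultiheadAttention: 10 query steps attend over 20 key/value steps.
import torch
from modules.multihead_attention import MultiheadAttention

attn = MultiheadAttention(embed_dim=32, num_heads=4, attn_dropout=0.1)
query = torch.rand(10, 3, 32)          # (tgt_len, batch, embed_dim)
key = value = torch.rand(20, 3, 32)    # (src_len, batch, embed_dim)

out, weights = attn(query, key, value)
print(out.shape)                       # torch.Size([10, 3, 32])
print(weights.shape)                   # torch.Size([3, 10, 20]); weights are averaged over heads
```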
/modules/position_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | # Code adapted from the fairseq repo.
7 |
8 | def make_positions(tensor, padding_idx, left_pad):
9 | """Replace non-padding symbols with their position numbers.
10 | Position numbers begin at padding_idx+1.
11 | Padding symbols are ignored, but it is necessary to specify whether padding
12 | is added on the left side (left_pad=True) or right side (left_pad=False).
13 | """
14 | max_pos = padding_idx + 1 + tensor.size(1)
15 | device = tensor.get_device()
16 | buf_name = f'range_buf_{device}'
17 | if not hasattr(make_positions, buf_name):
18 | setattr(make_positions, buf_name, tensor.new())
19 | setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor))
20 | if getattr(make_positions, buf_name).numel() < max_pos:
21 | torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name))
22 | mask = tensor.ne(padding_idx)
23 | positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor)
24 | if left_pad:
25 | positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
26 | new_tensor = tensor.clone()
27 | return new_tensor.masked_scatter_(mask, positions[mask]).long()
28 |
29 |
30 | class SinusoidalPositionalEmbedding(nn.Module):
31 | """This module produces sinusoidal positional embeddings of any length.
32 | Padding symbols are ignored, but it is necessary to specify whether padding
33 | is added on the left side (left_pad=True) or right side (left_pad=False).
34 | """
35 |
36 | def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128):
37 | super().__init__()
38 | self.embedding_dim = embedding_dim
39 | self.padding_idx = padding_idx
40 | self.left_pad = left_pad
41 | self.weights = dict() # device --> actual weight; due to nn.DataParallel :-(
42 | self.register_buffer('_float_tensor', torch.FloatTensor(1))
43 |
44 | @staticmethod
45 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
46 | """Build sinusoidal embeddings.
47 | This matches the implementation in tensor2tensor, but differs slightly
48 | from the description in Section 3.5 of "Attention Is All You Need".
49 | """
50 | half_dim = embedding_dim // 2
51 | emb = math.log(10000) / (half_dim - 1)
52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
53 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
55 | if embedding_dim % 2 == 1:
56 | # zero pad
57 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
58 | if padding_idx is not None:
59 | emb[padding_idx, :] = 0
60 | return emb
61 |
62 | def forward(self, input):
63 | """Input is expected to be of size [bsz x seqlen]."""
64 | bsz, seq_len = input.size()
65 | max_pos = self.padding_idx + 1 + seq_len
66 | device = input.get_device()
67 | if device not in self.weights or max_pos > self.weights[device].size(0):
68 | # recompute/expand embeddings if needed
69 | self.weights[device] = SinusoidalPositionalEmbedding.get_embedding(
70 | max_pos,
71 | self.embedding_dim,
72 | self.padding_idx,
73 | )
74 | self.weights[device] = self.weights[device].type_as(self._float_tensor)
75 | positions = make_positions(input, self.padding_idx, self.left_pad)
76 | return self.weights[device].index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
77 |
78 | def max_positions(self):
79 | """Maximum number of supported positions."""
80 | return int(1e5) # an arbitrary large number
--------------------------------------------------------------------------------
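A quick sketch of the positional embedding above (an editorial addition; sizes are assumed). Note that `make_positions` calls `Tensor.get_device()`, so running this on CPU assumes a PyTorch version that returns -1 for CPU tensors; the repo itself defaults to CUDA tensors.

```python
# Sinusoidal positions for a (batch, seq_len) input; non-zero entries count as non-padding.
import torch
from modules.position_embedding import SinusoidalPositionalEmbedding

pos_emb = SinusoidalPositionalEmbedding(embedding_dim=30)
dummy = torch.rand(2, 7)       # (batch, seq_len)
print(pos_emb(dummy).shape)    # torch.Size([2, 7, 30])
```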
/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from modules.position_embedding import SinusoidalPositionalEmbedding
5 | from modules.multihead_attention import MultiheadAttention
6 | import math
7 |
8 |
9 | class TransformerEncoder(nn.Module):
10 | """
11 | Transformer encoder consisting of *args.encoder_layers* layers. Each layer
12 | is a :class:`TransformerEncoderLayer`.
13 | Args:
14 | embed_tokens (torch.nn.Embedding): input embedding
15 | num_heads (int): number of heads
16 | layers (int): number of layers
17 | attn_dropout (float): dropout applied on the attention weights
18 | relu_dropout (float): dropout applied on the first layer of the residual block
19 | res_dropout (float): dropout applied on the residual block
20 | attn_mask (bool): whether to apply mask on the attention weights
21 | """
22 |
23 | def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0,
24 | embed_dropout=0.0, attn_mask=False):
25 | super().__init__()
26 | self.dropout = embed_dropout # Embedding dropout
27 | self.attn_dropout = attn_dropout
28 | self.embed_dim = embed_dim
29 | self.embed_scale = math.sqrt(embed_dim)
30 | self.embed_positions = SinusoidalPositionalEmbedding(embed_dim)
31 |
32 | self.attn_mask = attn_mask
33 |
34 | self.layers = nn.ModuleList([])
35 | for layer in range(layers):
36 | new_layer = TransformerEncoderLayer(embed_dim,
37 | num_heads=num_heads,
38 | attn_dropout=attn_dropout,
39 | relu_dropout=relu_dropout,
40 | res_dropout=res_dropout,
41 | attn_mask=attn_mask)
42 | self.layers.append(new_layer)
43 |
44 | self.register_buffer('version', torch.Tensor([2]))
45 | self.normalize = True
46 | if self.normalize:
47 | self.layer_norm = LayerNorm(embed_dim)
48 |
49 | def forward(self, x_in, x_in_k = None, x_in_v = None):
50 | """
51 | Args:
52 | x_in (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
53 | x_in_k (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
54 | x_in_v (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
55 | Returns:
56 | dict:
57 | - **encoder_out** (Tensor): the last encoder layer's output of
58 | shape `(src_len, batch, embed_dim)`
59 | - **encoder_padding_mask** (ByteTensor): the positions of
60 | padding elements of shape `(batch, src_len)`
61 | """
62 | # embed tokens and positions
63 | x = self.embed_scale * x_in
64 | if self.embed_positions is not None:
65 | x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
66 | x = F.dropout(x, p=self.dropout, training=self.training)
67 |
68 | if x_in_k is not None and x_in_v is not None:
69 | # embed tokens and positions
70 | x_k = self.embed_scale * x_in_k
71 | x_v = self.embed_scale * x_in_v
72 | if self.embed_positions is not None:
73 | x_k += self.embed_positions(x_in_k.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
74 | x_v += self.embed_positions(x_in_v.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
75 | x_k = F.dropout(x_k, p=self.dropout, training=self.training)
76 | x_v = F.dropout(x_v, p=self.dropout, training=self.training)
77 |
78 | # encoder layers
79 | intermediates = [x]
80 | for layer in self.layers:
81 | if x_in_k is not None and x_in_v is not None:
82 | x = layer(x, x_k, x_v)
83 | else:
84 | x = layer(x)
85 | intermediates.append(x)
86 |
87 | if self.normalize:
88 | x = self.layer_norm(x)
89 |
90 | return x
91 |
92 | def max_positions(self):
93 | """Maximum input length supported by the encoder."""
94 | if self.embed_positions is None:
95 | return int(1e5)  # no explicit limit when positional embeddings are absent
96 | return self.embed_positions.max_positions()  # note: max_source_positions is never set in this repo
97 |
98 |
99 | class TransformerEncoderLayer(nn.Module):
100 | """Encoder layer block.
101 | In the original paper each operation (multi-head attention or FFN) is
102 | postprocessed with: `dropout -> add residual -> layernorm`. In the
103 | tensor2tensor code they suggest that learning is more robust when
104 | preprocessing each layer with layernorm and postprocessing with:
105 | `dropout -> add residual`. We default to the approach in the paper, but the
106 | tensor2tensor approach can be enabled by setting
107 | *args.encoder_normalize_before* to ``True``.
108 | Args:
109 | embed_dim: Embedding dimension
110 | """
111 |
112 | def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1,
113 | attn_mask=False):
114 | super().__init__()
115 | self.embed_dim = embed_dim
116 | self.num_heads = num_heads
117 |
118 | self.self_attn = MultiheadAttention(
119 | embed_dim=self.embed_dim,
120 | num_heads=self.num_heads,
121 | attn_dropout=attn_dropout
122 | )
123 | self.attn_mask = attn_mask
124 |
125 | self.relu_dropout = relu_dropout
126 | self.res_dropout = res_dropout
127 | self.normalize_before = True
128 |
129 | self.fc1 = Linear(self.embed_dim, 4*self.embed_dim) # The position-wise feed-forward sublayer from the paper
130 | self.fc2 = Linear(4*self.embed_dim, self.embed_dim)
131 | self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)])
132 |
133 | def forward(self, x, x_k=None, x_v=None):
134 | """
135 | Args:
136 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
137 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape
138 | `(batch, src_len)` where padding elements are indicated by ``1``.
139 | x_k (Tensor): same as x
140 | x_v (Tensor): same as x
141 | Returns:
142 | encoded output of shape `(batch, src_len, embed_dim)`
143 | """
144 | residual = x
145 | x = self.maybe_layer_norm(0, x, before=True)
146 | mask = buffered_future_mask(x, x_k) if self.attn_mask else None
147 | if x_k is None and x_v is None:
148 | x, _ = self.self_attn(query=x, key=x, value=x, attn_mask=mask)
149 | else:
150 | x_k = self.maybe_layer_norm(0, x_k, before=True)
151 | x_v = self.maybe_layer_norm(0, x_v, before=True)
152 | x, _ = self.self_attn(query=x, key=x_k, value=x_v, attn_mask=mask)
153 | x = F.dropout(x, p=self.res_dropout, training=self.training)
154 | x = residual + x
155 | x = self.maybe_layer_norm(0, x, after=True)
156 |
157 | residual = x
158 | x = self.maybe_layer_norm(1, x, before=True)
159 | x = F.relu(self.fc1(x))
160 | x = F.dropout(x, p=self.relu_dropout, training=self.training)
161 | x = self.fc2(x)
162 | x = F.dropout(x, p=self.res_dropout, training=self.training)
163 | x = residual + x
164 | x = self.maybe_layer_norm(1, x, after=True)
165 | return x
166 |
167 | def maybe_layer_norm(self, i, x, before=False, after=False):
168 | assert before ^ after
169 | if after ^ self.normalize_before:
170 | return self.layer_norms[i](x)
171 | else:
172 | return x
173 |
174 | def fill_with_neg_inf(t):
175 | """FP16-compatible function that fills a tensor with -inf."""
176 | return t.float().fill_(float('-inf')).type_as(t)
177 |
178 |
179 | def buffered_future_mask(tensor, tensor2=None):
180 | dim1 = dim2 = tensor.size(0)
181 | if tensor2 is not None:
182 | dim2 = tensor2.size(0)
183 | future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1+abs(dim2-dim1))
184 | if tensor.is_cuda:
185 | future_mask = future_mask.cuda()
186 | return future_mask[:dim1, :dim2]
187 |
188 |
189 | def Linear(in_features, out_features, bias=True):
190 | m = nn.Linear(in_features, out_features, bias)
191 | nn.init.xavier_uniform_(m.weight)
192 | if bias:
193 | nn.init.constant_(m.bias, 0.)
194 | return m
195 |
196 |
197 | def LayerNorm(embedding_dim):
198 | m = nn.LayerNorm(embedding_dim)
199 | return m
200 |
201 |
202 | if __name__ == '__main__':
203 | encoder = TransformerEncoder(300, 4, 2)
204 | x = torch.rand(20, 2, 300)
205 | print(encoder(x).shape)
206 |
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # MulT model and Multimodal Sentiment Analysis Benchmarks
2 |
3 | This directory contains the model architecture for Multimodal Transformer (MulT) as well as the three major multimodal sentiment analysis benchmarks we used in the paper. All datasets should be put in the `../data` folder (if it does not exist, use `mkdir` to create one). Depending on the dataset, we may have different sets of evaluation metrics (see `eval_metrics.py`).
4 |
--------------------------------------------------------------------------------
/src/ctc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class CTCModule(nn.Module):
6 | def __init__(self, in_dim, out_seq_len):
7 | '''
8 | This module is performing alignment from A (e.g., audio) to B (e.g., text).
9 | :param in_dim: Dimension for input modality A
10 | :param out_seq_len: Sequence length for output modality B
11 | '''
12 | super(CTCModule, self).__init__()
13 | # Use LSTM for predicting the position from A to B
14 | self.pred_output_position_inclu_blank = nn.LSTM(in_dim, out_seq_len+1, num_layers=2, batch_first=True) # the +1 accounts for the CTC blank symbol
15 |
16 | self.out_seq_len = out_seq_len
17 |
18 | self.softmax = nn.Softmax(dim=2)
19 | def forward(self, x):
20 | '''
21 | :input x: Input with shape [batch_size x in_seq_len x in_dim]
22 | '''
23 | # NOTE that the index 0 refers to blank.
24 | pred_output_position_inclu_blank, _ = self.pred_output_position_inclu_blank(x)
25 |
26 | prob_pred_output_position_inclu_blank = self.softmax(pred_output_position_inclu_blank) # batch_size x in_seq_len x out_seq_len+1
27 | prob_pred_output_position = prob_pred_output_position_inclu_blank[:, :, 1:] # batch_size x in_seq_len x out_seq_len
28 | prob_pred_output_position = prob_pred_output_position.transpose(1,2) # batch_size x out_seq_len x in_seq_len
29 | pseudo_aligned_out = torch.bmm(prob_pred_output_position, x) # batch_size x out_seq_len x in_dim
30 |
31 | # pseudo_aligned_out is regarded as the aligned A (w.r.t B)
32 | return pseudo_aligned_out, (pred_output_position_inclu_blank)
33 |
34 |
35 |
--------------------------------------------------------------------------------
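A minimal sketch of the pseudo-alignment performed by `CTCModule` (an editorial addition): a 375-step audio stream is softly aligned to a 50-step text sequence. The feature size and lengths are assumptions.

```python
# Pseudo-align audio (modality A) to the text length (modality B) with CTCModule.
import torch
from src.ctc import CTCModule

ctc_a2l = CTCModule(in_dim=74, out_seq_len=50)
audio = torch.rand(2, 375, 74)                   # batch x in_seq_len x in_dim
aligned_audio, position_logits = ctc_a2l(audio)
print(aligned_audio.shape)                       # torch.Size([2, 50, 74])
print(position_logits.shape)                     # torch.Size([2, 375, 51]); index 0 is the blank
```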
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data.dataset import Dataset
3 | import pickle
4 | import os
5 | from scipy import signal
6 | import torch
7 |
8 | if torch.cuda.is_available():
9 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
10 | else:
11 | torch.set_default_tensor_type('torch.FloatTensor')
12 |
13 | ############################################################################################
14 | # This file provides basic processing script for the multimodal datasets we use. For other
15 | # datasets, small modifications may be needed (depending on the type of the data, etc.)
16 | ############################################################################################
17 |
18 |
19 | class Multimodal_Datasets(Dataset):
20 | def __init__(self, dataset_path, data='mosei_senti', split_type='train', if_align=False):
21 | super(Multimodal_Datasets, self).__init__()
22 | dataset_path = os.path.join(dataset_path, data+'_data.pkl' if if_align else data+'_data_noalign.pkl' )
23 | dataset = pickle.load(open(dataset_path, 'rb'))
24 |
25 | # These are torch tensors
26 | self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach()
27 | self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach()
28 | self.audio = dataset[split_type]['audio'].astype(np.float32)
29 | self.audio[self.audio == -np.inf] = 0
30 | self.audio = torch.tensor(self.audio).cpu().detach()
31 | self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach()
32 |
33 | # Note: this is STILL a numpy array
34 | self.meta = dataset[split_type]['id'] if 'id' in dataset[split_type].keys() else None
35 |
36 | self.data = data
37 |
38 | self.n_modalities = 3 # vision/ text/ audio
39 | def get_n_modalities(self):
40 | return self.n_modalities
41 | def get_seq_len(self):
42 | return self.text.shape[1], self.audio.shape[1], self.vision.shape[1]
43 | def get_dim(self):
44 | return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]
45 | def get_lbl_info(self):
46 | # return number_of_labels, label_dim
47 | return self.labels.shape[1], self.labels.shape[2]
48 | def __len__(self):
49 | return len(self.labels)
50 | def __getitem__(self, index):
51 | X = (index, self.text[index], self.audio[index], self.vision[index])
52 | Y = self.labels[index]
53 | META = (0,0,0) if self.meta is None else (self.meta[index][0], self.meta[index][1], self.meta[index][2])
54 | if self.data == 'mosi':
55 | META = (self.meta[index][0].decode('UTF-8'), self.meta[index][1].decode('UTF-8'), self.meta[index][2].decode('UTF-8'))
56 | if self.data == 'iemocap':
57 | Y = torch.argmax(Y, dim=-1)
58 | return X, Y, META
59 |
60 |
--------------------------------------------------------------------------------
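Assuming the pickles from the README's Datasets section have been unzipped into `data/`, the dataset class above can be exercised directly; this sketch (an editorial addition) just shows the field layout returned by `__getitem__`.

```python
# Requires data/mosei_senti_data_noalign.pkl, downloaded as described in the README.
from src.dataset import Multimodal_Datasets

train_set = Multimodal_Datasets('data', data='mosei_senti', split_type='train', if_align=False)
print(train_set.get_dim())        # (text_dim, audio_dim, vision_dim)
print(train_set.get_seq_len())    # (text_len, audio_len, vision_len)

(idx, text, audio, vision), label, meta = train_set[0]
print(text.shape, audio.shape, vision.shape, label.shape)
```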
/src/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import classification_report
4 | from sklearn.metrics import confusion_matrix
5 | from sklearn.metrics import precision_recall_fscore_support
6 | from sklearn.metrics import accuracy_score, f1_score
7 |
8 |
9 | def multiclass_acc(preds, truths):
10 | """
11 | Compute the multiclass accuracy w.r.t. groundtruth
12 |
13 | :param preds: Float array representing the predictions, dimension (N,)
14 | :param truths: Float/int array representing the groundtruth classes, dimension (N,)
15 | :return: Classification accuracy
16 | """
17 | return np.sum(np.round(preds) == np.round(truths)) / float(len(truths))
18 |
19 |
20 | def weighted_accuracy(test_preds_emo, test_truth_emo):
21 | true_label = (test_truth_emo > 0)
22 | predicted_label = (test_preds_emo > 0)
23 | tp = float(np.sum((true_label==1) & (predicted_label==1)))
24 | tn = float(np.sum((true_label==0) & (predicted_label==0)))
25 | p = float(np.sum(true_label==1))
26 | n = float(np.sum(true_label==0))
27 |
28 | return (tp * (n/p) +tn) / (2*n)
29 |
30 |
31 | def eval_mosei_senti(results, truths, exclude_zero=False):
32 | test_preds = results.view(-1).cpu().detach().numpy()
33 | test_truth = truths.view(-1).cpu().detach().numpy()
34 |
35 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0 or (not exclude_zero)])
36 |
37 | test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
38 | test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
39 | test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
40 | test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)
41 |
42 | mae = np.mean(np.absolute(test_preds - test_truth)) # Average L1 distance between preds and truths
43 | corr = np.corrcoef(test_preds, test_truth)[0][1]
44 | mult_a7 = multiclass_acc(test_preds_a7, test_truth_a7)
45 | mult_a5 = multiclass_acc(test_preds_a5, test_truth_a5)
46 | f_score = f1_score((test_preds[non_zeros] > 0), (test_truth[non_zeros] > 0), average='weighted')
47 | binary_truth = (test_truth[non_zeros] > 0)
48 | binary_preds = (test_preds[non_zeros] > 0)
49 |
50 | print("MAE: ", mae)
51 | print("Correlation Coefficient: ", corr)
52 | print("mult_acc_7: ", mult_a7)
53 | print("mult_acc_5: ", mult_a5)
54 | print("F1 score: ", f_score)
55 | print("Accuracy: ", accuracy_score(binary_truth, binary_preds))
56 |
57 | print("-" * 50)
58 |
59 |
60 | def eval_mosi(results, truths, exclude_zero=False):
61 | return eval_mosei_senti(results, truths, exclude_zero)
62 |
63 |
64 | def eval_iemocap(results, truths, single=-1):
65 | emos = ["Neutral", "Happy", "Sad", "Angry"]
66 | if single < 0:
67 | test_preds = results.view(-1, 4, 2).cpu().detach().numpy()
68 | test_truth = truths.view(-1, 4).cpu().detach().numpy()
69 |
70 | for emo_ind in range(4):
71 | print(f"{emos[emo_ind]}: ")
72 | test_preds_i = np.argmax(test_preds[:,emo_ind],axis=1)
73 | test_truth_i = test_truth[:,emo_ind]
74 | f1 = f1_score(test_truth_i, test_preds_i, average='weighted')
75 | acc = accuracy_score(test_truth_i, test_preds_i)
76 | print(" - F1 Score: ", f1)
77 | print(" - Accuracy: ", acc)
78 | else:
79 | test_preds = results.view(-1, 2).cpu().detach().numpy()
80 | test_truth = truths.view(-1).cpu().detach().numpy()
81 |
82 | print(f"{emos[single]}: ")
83 | test_preds_i = np.argmax(test_preds,axis=1)
84 | test_truth_i = test_truth
85 | f1 = f1_score(test_truth_i, test_preds_i, average='weighted')
86 | acc = accuracy_score(test_truth_i, test_preds_i)
87 | print(" - F1 Score: ", f1)
88 | print(" - Accuracy: ", acc)
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
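The MOSEI/MOSI metrics above can be sanity-checked on random tensors; this sketch (an editorial addition) only illustrates the expected input shapes.

```python
# eval_mosei_senti expects (N, 1) prediction and truth tensors (regression outputs).
import torch
from src.eval_metrics import eval_mosei_senti

preds  = torch.randn(100, 1)
truths = torch.randn(100, 1) * 3.0   # MOSEI sentiment labels lie in [-3, 3]
eval_mosei_senti(preds, truths)      # prints MAE, correlation, acc-7, acc-5, F1, binary accuracy
```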
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from modules.transformer import TransformerEncoder
6 |
7 |
8 | class MULTModel(nn.Module):
9 | def __init__(self, hyp_params):
10 | """
11 | Construct a MulT model.
12 | """
13 | super(MULTModel, self).__init__()
14 | self.orig_d_l, self.orig_d_a, self.orig_d_v = hyp_params.orig_d_l, hyp_params.orig_d_a, hyp_params.orig_d_v
15 | self.d_l, self.d_a, self.d_v = 30, 30, 30
16 | self.vonly = hyp_params.vonly
17 | self.aonly = hyp_params.aonly
18 | self.lonly = hyp_params.lonly
19 | self.num_heads = hyp_params.num_heads
20 | self.layers = hyp_params.layers
21 | self.attn_dropout = hyp_params.attn_dropout
22 | self.attn_dropout_a = hyp_params.attn_dropout_a
23 | self.attn_dropout_v = hyp_params.attn_dropout_v
24 | self.relu_dropout = hyp_params.relu_dropout
25 | self.res_dropout = hyp_params.res_dropout
26 | self.out_dropout = hyp_params.out_dropout
27 | self.embed_dropout = hyp_params.embed_dropout
28 | self.attn_mask = hyp_params.attn_mask
29 |
30 | combined_dim = self.d_l + self.d_a + self.d_v
31 |
32 | self.partial_mode = self.lonly + self.aonly + self.vonly
33 | if self.partial_mode == 1:
34 | combined_dim = 2 * self.d_l # assuming d_l == d_a == d_v
35 | else:
36 | combined_dim = 2 * (self.d_l + self.d_a + self.d_v)
37 |
38 | output_dim = hyp_params.output_dim # This is actually not a hyperparameter :-)
39 |
40 | # 1. Temporal convolutional layers
41 | self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=1, padding=0, bias=False)
42 | self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=1, padding=0, bias=False)
43 | self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=1, padding=0, bias=False)
44 |
45 | # 2. Crossmodal Attentions
46 | if self.lonly:
47 | self.trans_l_with_a = self.get_network(self_type='la')
48 | self.trans_l_with_v = self.get_network(self_type='lv')
49 | if self.aonly:
50 | self.trans_a_with_l = self.get_network(self_type='al')
51 | self.trans_a_with_v = self.get_network(self_type='av')
52 | if self.vonly:
53 | self.trans_v_with_l = self.get_network(self_type='vl')
54 | self.trans_v_with_a = self.get_network(self_type='va')
55 |
56 | # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.)
57 | # [e.g., self.trans_x_mem = nn.LSTM(self.d_x, self.d_x, 1)]
58 | self.trans_l_mem = self.get_network(self_type='l_mem', layers=3)
59 | self.trans_a_mem = self.get_network(self_type='a_mem', layers=3)
60 | self.trans_v_mem = self.get_network(self_type='v_mem', layers=3)
61 |
62 | # Projection layers
63 | self.proj1 = nn.Linear(combined_dim, combined_dim)
64 | self.proj2 = nn.Linear(combined_dim, combined_dim)
65 | self.out_layer = nn.Linear(combined_dim, output_dim)
66 |
67 | def get_network(self, self_type='l', layers=-1):
68 | if self_type in ['l', 'al', 'vl']:
69 | embed_dim, attn_dropout = self.d_l, self.attn_dropout
70 | elif self_type in ['a', 'la', 'va']:
71 | embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
72 | elif self_type in ['v', 'lv', 'av']:
73 | embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
74 | elif self_type == 'l_mem':
75 | embed_dim, attn_dropout = 2*self.d_l, self.attn_dropout
76 | elif self_type == 'a_mem':
77 | embed_dim, attn_dropout = 2*self.d_a, self.attn_dropout
78 | elif self_type == 'v_mem':
79 | embed_dim, attn_dropout = 2*self.d_v, self.attn_dropout
80 | else:
81 | raise ValueError("Unknown network type")
82 |
83 | return TransformerEncoder(embed_dim=embed_dim,
84 | num_heads=self.num_heads,
85 | layers=max(self.layers, layers),
86 | attn_dropout=attn_dropout,
87 | relu_dropout=self.relu_dropout,
88 | res_dropout=self.res_dropout,
89 | embed_dropout=self.embed_dropout,
90 | attn_mask=self.attn_mask)
91 |
92 | def forward(self, x_l, x_a, x_v):
93 | """
94 | text, audio, and vision should have dimension [batch_size, seq_len, n_features]
95 | """
96 | x_l = F.dropout(x_l.transpose(1, 2), p=self.embed_dropout, training=self.training)
97 | x_a = x_a.transpose(1, 2)
98 | x_v = x_v.transpose(1, 2)
99 |
100 | # Project the textual/visual/audio features
101 | proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
102 | proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
103 | proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)
104 | proj_x_a = proj_x_a.permute(2, 0, 1)
105 | proj_x_v = proj_x_v.permute(2, 0, 1)
106 | proj_x_l = proj_x_l.permute(2, 0, 1)
107 |
108 | if self.lonly:
109 | # (V,A) --> L
110 | h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a) # Dimension (L, N, d_l)
111 | h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v) # Dimension (L, N, d_l)
112 | h_ls = torch.cat([h_l_with_as, h_l_with_vs], dim=2)
113 | h_ls = self.trans_l_mem(h_ls)
114 | if type(h_ls) == tuple:
115 | h_ls = h_ls[0]
116 | last_h_l = last_hs = h_ls[-1] # Take the last output for prediction
117 |
118 | if self.aonly:
119 | # (L,V) --> A
120 | h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l)
121 | h_a_with_vs = self.trans_a_with_v(proj_x_a, proj_x_v, proj_x_v)
122 | h_as = torch.cat([h_a_with_ls, h_a_with_vs], dim=2)
123 | h_as = self.trans_a_mem(h_as)
124 | if type(h_as) == tuple:
125 | h_as = h_as[0]
126 | last_h_a = last_hs = h_as[-1]
127 |
128 | if self.vonly:
129 | # (L,A) --> V
130 | h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l)
131 | h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a)
132 | h_vs = torch.cat([h_v_with_ls, h_v_with_as], dim=2)
133 | h_vs = self.trans_v_mem(h_vs)
134 | if type(h_vs) == tuple:
135 | h_vs = h_vs[0]
136 | last_h_v = last_hs = h_vs[-1]
137 |
138 | if self.partial_mode == 3:
139 | last_hs = torch.cat([last_h_l, last_h_a, last_h_v], dim=1)
140 |
141 | # A residual block
142 | last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training))
143 | last_hs_proj += last_hs
144 |
145 | output = self.out_layer(last_hs_proj)
146 | return output, last_hs
--------------------------------------------------------------------------------
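A minimal end-to-end sketch (an editorial addition) of constructing and calling `MULTModel` with random inputs, mirroring the hyperparameter namespace that `main.py` assembles. All sizes are assumptions, and CPU execution assumes a PyTorch version where `Tensor.get_device()` returns -1 for CPU tensors; the repo normally runs with CUDA default tensors.

```python
# Build MULTModel from a plain namespace and run one forward pass on random data.
from types import SimpleNamespace
import torch
from src.models import MULTModel

hyp = SimpleNamespace(
    orig_d_l=300, orig_d_a=74, orig_d_v=35,     # raw feature sizes (assumed)
    lonly=True, aonly=True, vonly=True,         # keep all three target branches
    num_heads=5, layers=2, output_dim=1,
    attn_dropout=0.1, attn_dropout_a=0.0, attn_dropout_v=0.0,
    relu_dropout=0.1, res_dropout=0.1, out_dropout=0.0, embed_dropout=0.25,
    attn_mask=False,
)
model = MULTModel(hyp)

batch, l_len, a_len, v_len = 2, 50, 375, 500    # unaligned sequence lengths (assumed)
x_l = torch.rand(batch, l_len, hyp.orig_d_l)    # [batch_size, seq_len, n_features]
x_a = torch.rand(batch, a_len, hyp.orig_d_a)
x_v = torch.rand(batch, v_len, hyp.orig_d_v)

preds, last_hs = model(x_l, x_a, x_v)
print(preds.shape)                              # torch.Size([2, 1])
```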
/src/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import sys
4 | from src import models
5 | from src import ctc
6 | from src.utils import *
7 | import torch.optim as optim
8 | import numpy as np
9 | import time
10 | from torch.optim.lr_scheduler import ReduceLROnPlateau
11 | import os
12 | import pickle
13 |
14 | from sklearn.metrics import classification_report
15 | from sklearn.metrics import confusion_matrix
16 | from sklearn.metrics import precision_recall_fscore_support
17 | from sklearn.metrics import accuracy_score, f1_score
18 | from src.eval_metrics import *
19 |
20 |
21 | ####################################################################
22 | #
23 | # Construct the model and the CTC module (which may not be needed)
24 | #
25 | ####################################################################
26 |
27 | def get_CTC_module(hyp_params):
28 | a2l_module = getattr(ctc, 'CTCModule')(in_dim=hyp_params.orig_d_a, out_seq_len=hyp_params.l_len)
29 | v2l_module = getattr(ctc, 'CTCModule')(in_dim=hyp_params.orig_d_v, out_seq_len=hyp_params.l_len)
30 | return a2l_module, v2l_module
31 |
32 | def initiate(hyp_params, train_loader, valid_loader, test_loader):
33 | model = getattr(models, hyp_params.model+'Model')(hyp_params)
34 |
35 | if hyp_params.use_cuda:
36 | model = model.cuda()
37 |
38 | optimizer = getattr(optim, hyp_params.optim)(model.parameters(), lr=hyp_params.lr)
39 | criterion = getattr(nn, hyp_params.criterion)()
40 | if hyp_params.aligned or hyp_params.model=='MULT':
41 | ctc_criterion = None
42 | ctc_a2l_module, ctc_v2l_module = None, None
43 | ctc_a2l_optimizer, ctc_v2l_optimizer = None, None
44 | else:
45 | from warpctc_pytorch import CTCLoss
46 | ctc_criterion = CTCLoss()
47 | ctc_a2l_module, ctc_v2l_module = get_CTC_module(hyp_params)
48 | if hyp_params.use_cuda:
49 | ctc_a2l_module, ctc_v2l_module = ctc_a2l_module.cuda(), ctc_v2l_module.cuda()
50 | ctc_a2l_optimizer = getattr(optim, hyp_params.optim)(ctc_a2l_module.parameters(), lr=hyp_params.lr)
51 | ctc_v2l_optimizer = getattr(optim, hyp_params.optim)(ctc_v2l_module.parameters(), lr=hyp_params.lr)
52 |
53 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=hyp_params.when, factor=0.1, verbose=True)
54 | settings = {'model': model,
55 | 'optimizer': optimizer,
56 | 'criterion': criterion,
57 | 'ctc_a2l_module': ctc_a2l_module,
58 | 'ctc_v2l_module': ctc_v2l_module,
59 | 'ctc_a2l_optimizer': ctc_a2l_optimizer,
60 | 'ctc_v2l_optimizer': ctc_v2l_optimizer,
61 | 'ctc_criterion': ctc_criterion,
62 | 'scheduler': scheduler}
63 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader)
64 |
65 |
66 | ####################################################################
67 | #
68 | # Training and evaluation scripts
69 | #
70 | ####################################################################
71 |
72 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader):
73 | model = settings['model']
74 | optimizer = settings['optimizer']
75 | criterion = settings['criterion']
76 |
77 | ctc_a2l_module = settings['ctc_a2l_module']
78 | ctc_v2l_module = settings['ctc_v2l_module']
79 | ctc_a2l_optimizer = settings['ctc_a2l_optimizer']
80 | ctc_v2l_optimizer = settings['ctc_v2l_optimizer']
81 | ctc_criterion = settings['ctc_criterion']
82 |
83 | scheduler = settings['scheduler']
84 |
85 |
86 | def train(model, optimizer, criterion, ctc_a2l_module, ctc_v2l_module, ctc_a2l_optimizer, ctc_v2l_optimizer, ctc_criterion):
87 | epoch_loss = 0
88 | model.train()
89 | num_batches = hyp_params.n_train // hyp_params.batch_size
90 | proc_loss, proc_size = 0, 0
91 | start_time = time.time()
92 | for i_batch, (batch_X, batch_Y, batch_META) in enumerate(train_loader):
93 | sample_ind, text, audio, vision = batch_X
94 | eval_attr = batch_Y.squeeze(-1) # if num of labels is 1
95 |
96 | model.zero_grad()
97 | if ctc_criterion is not None:
98 | ctc_a2l_module.zero_grad()
99 | ctc_v2l_module.zero_grad()
100 |
101 | if hyp_params.use_cuda:
102 | with torch.cuda.device(0):
103 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda()
104 | if hyp_params.dataset == 'iemocap':
105 | eval_attr = eval_attr.long()
106 |
107 | batch_size = text.size(0)
108 | batch_chunk = hyp_params.batch_chunk
109 |
110 | ######## CTC STARTS ######## Do not worry about this if not working on CTC
111 | if ctc_criterion is not None:
112 | ctc_a2l_net = nn.DataParallel(ctc_a2l_module) if batch_size > 10 else ctc_a2l_module
113 | ctc_v2l_net = nn.DataParallel(ctc_v2l_module) if batch_size > 10 else ctc_v2l_module
114 |
115 | audio, a2l_position = ctc_a2l_net(audio) # audio is now aligned to the text length
116 | vision, v2l_position = ctc_v2l_net(vision)
117 |
118 | ## Compute the ctc loss
119 | l_len, a_len, v_len = hyp_params.l_len, hyp_params.a_len, hyp_params.v_len
120 | # Output Labels
121 | l_position = torch.tensor([i+1 for i in range(l_len)]*batch_size).int().cpu()
122 | # Specifying each output length
123 | l_length = torch.tensor([l_len]*batch_size).int().cpu()
124 | # Specifying each input length
125 | a_length = torch.tensor([a_len]*batch_size).int().cpu()
126 | v_length = torch.tensor([v_len]*batch_size).int().cpu()
127 |
128 | ctc_a2l_loss = ctc_criterion(a2l_position.transpose(0,1).cpu(), l_position, a_length, l_length)
129 | ctc_v2l_loss = ctc_criterion(v2l_position.transpose(0,1).cpu(), l_position, v_length, l_length)
130 | ctc_loss = ctc_a2l_loss + ctc_v2l_loss
131 | ctc_loss = ctc_loss.cuda() if hyp_params.use_cuda else ctc_loss
132 | else:
133 | ctc_loss = 0
134 | ######## CTC ENDS ########
135 |
136 | combined_loss = 0
137 | net = nn.DataParallel(model) if batch_size > 10 else model
138 | if batch_chunk > 1:
139 | raw_loss = combined_loss = 0
140 | text_chunks = text.chunk(batch_chunk, dim=0)
141 | audio_chunks = audio.chunk(batch_chunk, dim=0)
142 | vision_chunks = vision.chunk(batch_chunk, dim=0)
143 | eval_attr_chunks = eval_attr.chunk(batch_chunk, dim=0)
144 |
145 | for i in range(batch_chunk):
146 | text_i, audio_i, vision_i = text_chunks[i], audio_chunks[i], vision_chunks[i]
147 | eval_attr_i = eval_attr_chunks[i]
148 | preds_i, hiddens_i = net(text_i, audio_i, vision_i)
149 |
150 | if hyp_params.dataset == 'iemocap':
151 | preds_i = preds_i.view(-1, 2)
152 | eval_attr_i = eval_attr_i.view(-1)
153 | raw_loss_i = criterion(preds_i, eval_attr_i) / batch_chunk
154 | raw_loss += raw_loss_i
155 | raw_loss_i.backward()
156 | ctc_loss.backward()
157 | combined_loss = raw_loss + ctc_loss
158 | else:
159 | preds, hiddens = net(text, audio, vision)
160 | if hyp_params.dataset == 'iemocap':
161 | preds = preds.view(-1, 2)
162 | eval_attr = eval_attr.view(-1)
163 | raw_loss = criterion(preds, eval_attr)
164 | combined_loss = raw_loss + ctc_loss
165 | combined_loss.backward()
166 |
167 | if ctc_criterion is not None:
168 | torch.nn.utils.clip_grad_norm_(ctc_a2l_module.parameters(), hyp_params.clip)
169 | torch.nn.utils.clip_grad_norm_(ctc_v2l_module.parameters(), hyp_params.clip)
170 | ctc_a2l_optimizer.step()
171 | ctc_v2l_optimizer.step()
172 |
173 | torch.nn.utils.clip_grad_norm_(model.parameters(), hyp_params.clip)
174 | optimizer.step()
175 |
176 | proc_loss += raw_loss.item() * batch_size
177 | proc_size += batch_size
178 | epoch_loss += combined_loss.item() * batch_size
179 | if i_batch % hyp_params.log_interval == 0 and i_batch > 0:
180 | avg_loss = proc_loss / proc_size
181 | elapsed_time = time.time() - start_time
182 | print('Epoch {:2d} | Batch {:3d}/{:3d} | Time/Batch(ms) {:5.2f} | Train Loss {:5.4f}'.
183 | format(epoch, i_batch, num_batches, elapsed_time * 1000 / hyp_params.log_interval, avg_loss))
184 | proc_loss, proc_size = 0, 0
185 | start_time = time.time()
186 |
187 | return epoch_loss / hyp_params.n_train
188 |
189 | def evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=False):
190 | model.eval()
191 | loader = test_loader if test else valid_loader
192 | total_loss = 0.0
193 |
194 | results = []
195 | truths = []
196 |
197 | with torch.no_grad():
198 | for i_batch, (batch_X, batch_Y, batch_META) in enumerate(loader):
199 | sample_ind, text, audio, vision = batch_X
200 | eval_attr = batch_Y.squeeze(dim=-1) # if num of labels is 1
201 |
202 | if hyp_params.use_cuda:
203 | with torch.cuda.device(0):
204 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda()
205 | if hyp_params.dataset == 'iemocap':
206 | eval_attr = eval_attr.long()
207 |
208 | batch_size = text.size(0)
209 |
210 | if (ctc_a2l_module is not None) and (ctc_v2l_module is not None):
211 | ctc_a2l_net = nn.DataParallel(ctc_a2l_module) if batch_size > 10 else ctc_a2l_module
212 | ctc_v2l_net = nn.DataParallel(ctc_v2l_module) if batch_size > 10 else ctc_v2l_module
213 | audio, _ = ctc_a2l_net(audio) # audio aligned to text
214 | vision, _ = ctc_v2l_net(vision) # vision aligned to text
215 |
216 | net = nn.DataParallel(model) if batch_size > 10 else model
217 | preds, _ = net(text, audio, vision)
218 | if hyp_params.dataset == 'iemocap':
219 | preds = preds.view(-1, 2)
220 | eval_attr = eval_attr.view(-1)
221 | total_loss += criterion(preds, eval_attr).item() * batch_size
222 |
223 | # Collect the predictions and ground truths
224 | results.append(preds)
225 | truths.append(eval_attr)
226 |
227 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid)
228 |
229 | results = torch.cat(results)
230 | truths = torch.cat(truths)
231 | return avg_loss, results, truths
232 |
233 | best_valid = 1e8
234 | for epoch in range(1, hyp_params.num_epochs+1):
235 | start = time.time()
236 | train(model, optimizer, criterion, ctc_a2l_module, ctc_v2l_module, ctc_a2l_optimizer, ctc_v2l_optimizer, ctc_criterion)
237 | val_loss, _, _ = evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=False)
238 | test_loss, _, _ = evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=True)
239 |
240 | end = time.time()
241 | duration = end-start
242 | scheduler.step(val_loss) # Decay learning rate by validation loss
243 |
244 | print("-"*50)
245 | print('Epoch {:2d} | Time {:5.4f} sec | Valid Loss {:5.4f} | Test Loss {:5.4f}'.format(epoch, duration, val_loss, test_loss))
246 | print("-"*50)
247 |
248 | if val_loss < best_valid:
249 | print(f"Saved model at pre_trained_models/{hyp_params.name}.pt!")
250 | save_model(hyp_params, model, name=hyp_params.name)
251 | best_valid = val_loss
252 |
253 | model = load_model(hyp_params, name=hyp_params.name)
254 | _, results, truths = evaluate(model, ctc_a2l_module, ctc_v2l_module, criterion, test=True)
255 |
256 | if hyp_params.dataset == "mosei_senti":
257 | eval_mosei_senti(results, truths, True)
258 | elif hyp_params.dataset == 'mosi':
259 | eval_mosi(results, truths, True)
260 | elif hyp_params.dataset == 'iemocap':
261 | eval_iemocap(results, truths)
262 |
263 | sys.stdout.flush()
264 | input('[Press Any Key to start another run]')
265 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from src.dataset import Multimodal_Datasets
4 |
5 |
6 | def get_data(args, dataset, split='train'):
7 | alignment = 'a' if args.aligned else 'na'
8 | data_path = os.path.join(args.data_path, dataset) + f'_{split}_{alignment}.dt'
9 | if not os.path.exists(data_path):
10 | print(f" - Creating new {split} data")
11 | data = Multimodal_Datasets(args.data_path, dataset, split, args.aligned)
12 | torch.save(data, data_path)
13 | else:
14 | print(f" - Found cached {split} data")
15 | data = torch.load(data_path)
16 | return data
17 |
18 |
19 | def save_load_name(args, name=''):
20 | if args.aligned:
21 | name = name if len(name) > 0 else 'aligned_model'
22 | elif not args.aligned:
23 | name = name if len(name) > 0 else 'nonaligned_model'
24 |
25 | return name + '_' + args.model
26 |
27 |
28 | def save_model(args, model, name=''):
29 | name = save_load_name(args, name)
30 | torch.save(model, f'pre_trained_models/{name}.pt')
31 |
32 |
33 | def load_model(args, name=''):
34 | name = save_load_name(args, name)
35 | model = torch.load(f'pre_trained_models/{name}.pt')
36 | return model
37 |
--------------------------------------------------------------------------------