├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── iemocap_dataset.py
│   ├── meld_dataset.py
│   ├── mosei_dataset.py
│   ├── mosi_dataset.py
│   └── vgaf_dataset.py
├── ensembling.py
├── layers
│   ├── fc.py
│   └── layer_norm.py
├── main.py
├── models
│   ├── __init__.py
│   ├── model_MAT.py
│   └── model_MNT.py
├── train.py
└── utils
    ├── compute_args.py
    ├── optim.py
    ├── plot.py
    ├── pred_func.py
    └── tokenize.py
/README.md:
--------------------------------------------------------------------------------
1 |

2 | EMNLP2020
3 |
4 | PyTorch implementation of the paper "Modulated Fusion using Transformer for Linguistic-Acoustic Emotion Recognition"
5 | ```
6 | @inproceedings{delbrouck-etal-2020-modulated,
7 | title = "Modulated Fusion using Transformer for Linguistic-Acoustic Emotion Recognition",
8 | author = "Delbrouck, Jean-Benoit and
9 | Tits, No{\'e} and
10 | Dupont, St{\'e}phane",
11 | booktitle = "Proceedings of the First International Workshop on Natural Language Processing Beyond Text",
12 | month = nov,
13 | year = "2020",
14 | address = "Online",
15 | publisher = "Association for Computational Linguistics",
16 | url = "https://www.aclweb.org/anthology/2020.nlpbt-1.1",
17 | doi = "10.18653/v1/2020.nlpbt-1.1",
18 | pages = "1--10",
19 | abstract = "This paper aims to bring a new lightweight yet powerful solution for the task of Emotion Recognition and Sentiment Analysis. Our motivation is to propose two architectures based on Transformers and modulation that combine the linguistic and acoustic inputs from a wide range of datasets to challenge, and sometimes surpass, the state-of-the-art in the field. To demonstrate the efficiency of our models, we carefully evaluate their performances on the IEMOCAP, MOSI, MOSEI and MELD dataset. The experiments can be directly replicated and the code is fully open for future researches.",
20 | }
21 | ```
22 |
23 | #### Environment
24 |
25 | Create a Python 3.6 environment with:
26 | ```
27 | torch 1.2.0
28 | torchvision 0.4.0
29 | numpy 1.18.1
30 | ```
31 |
32 | We use the GloVe vectors shipped with spaCy. They can be installed into your environment with the following commands:
33 | ```
34 | wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz
35 | pip install en_vectors_web_lg-2.1.0.tar.gz
36 | ```
37 | #### Data
38 |
39 | Create a data folder and get the data:
40 | ```
41 | mkdir -p data
42 | cd data
43 | wget -O data.zip https://www.dropbox.com/s/tz25q3xxfraw2r3/data.zip?dl=1
44 | unzip data.zip
45 | ```
46 |
47 | #### Training
48 |
49 | Here is an example to train a MAT model on IEMOCAP:
50 |
51 | ```
52 | mkdir -p ckpt
53 | for i in {1..10}
54 | do
55 | python main.py --dataset IEMOCAP \
56 | --model Model_MAT \
57 | --multi_head 4 \
58 | --ff_size 1024 \
59 | --hidden_size 512 \
60 | --layer 2 \
61 | --batch_size 32 \
62 | --lr_base 0.0001 \
63 | --dropout_r 0.1 \
64 | --dropout_o 0.5 \
65 | --name mymodel
66 | done
67 |
68 | ```
69 | Checkpoints will be stored in the folder `ckpt/mymodel`.
70 |
71 | #### Evaluation
72 |
73 | You can evaluate a model by typing:
74 | ```
75 | python ensembling.py --name mymodel --sets test
76 | ```
77 | The task settings are saved in the checkpoint, so the evaluation will be carried out on the dataset you trained your model on.
78 |
79 | By default, the script globs all the training checkpoints inside the folder and performs ensembling over them.
80 | To show further details of the evaluation for a specific ensemble size, use the `--index` argument:
81 | ```
82 | python ensembling.py --name mymodel --sets test --index 5
83 | ```
84 |
85 | #### Pre-trained model
86 | We release pre-trained models to replicate the results as shown in the paper. Models should be placed in the `ckpt` folder.
87 | ```
88 | mkdir -p ckpt
89 | ```
90 |
91 | [IEMOCAP 4-class emotions](https://www.dropbox.com/s/wzoiwrtc9m3nb78/IEMOCAP_pretrained.zip?dl=1)
92 | ```
93 | python ensembling.py --name IEMOCAP_pretrained --index 5 --sets test
94 |
95 | precision recall f1-score support
96 |
97 | 0 0.70 0.66 0.68 384
98 | 1 0.68 0.75 0.71 278
99 | 2 0.79 0.71 0.75 194
100 | 3 0.78 0.81 0.79 229
101 |
102 | accuracy 0.73 1085
103 | macro avg 0.74 0.73 0.73 1085
104 | weighted avg 0.73 0.73 0.73 1085
105 |
106 | Max ensemble w-accuracies for test : 72.53456221198157
107 | ```
108 |
109 | [MOSEI 2-class sentiment](https://www.dropbox.com/s/t2p8soswt9t1ii4/MOSEI_pretrained.zip?dl=1)
110 | ```
111 | python ensembling.py --name MOSEI_pretrained --index 9 --sets test
112 |
113 | precision recall f1-score support
114 |
115 | 0 0.75 0.57 0.65 1350
116 | 1 0.84 0.92 0.88 3312
117 |
118 | accuracy 0.82 4662
119 | macro avg 0.80 0.75 0.77 4662
120 | weighted avg 0.82 0.82 0.81 4662
121 |
122 | Max ensemble w-accuracies for test : 82.15358215358215
123 | ```
124 |
125 | [MOSI 2-class sentiment](https://www.dropbox.com/s/zw4a9ukk1npzt9r/MOSI_pretrained.zip?dl=1)
126 | ```
127 | python ensembling.py --name MOSI_pretrained --index 2 --sets test
128 |
129 |
130 | precision recall f1-score support
131 |
132 | 0 0.77 0.91 0.84 379
133 | 1 0.84 0.63 0.72 277
134 |
135 | accuracy 0.79 656
136 | macro avg 0.81 0.77 0.78 656
137 | weighted avg 0.80 0.79 0.79 656
138 |
139 | Max ensemble w-accuracies for test : 79.26829268292683
140 | ```
141 | [MELD 7-class emotions](https://www.dropbox.com/s/458h1ze6cic3h1l/MELD_pretrained.zip?dl=1)
142 | ```
143 | python ensembling.py --name MELD_pretrained --index 9 --sets test
144 |
145 |
146 | precision recall f1-score support
147 |
148 | 0 0.64 0.52 0.58 1256
149 | 1 0.36 0.58 0.45 281
150 | 2 0.08 0.18 0.11 50
151 | 3 0.23 0.25 0.24 208
152 | 4 0.44 0.47 0.46 402
153 | 5 0.23 0.24 0.23 68
154 | 6 0.31 0.27 0.29 345
155 |
156 | accuracy 0.45 2610
157 | macro avg 0.33 0.36 0.34 2610
158 | weighted avg 0.48 0.45 0.46 2610
159 | ```
160 |
161 |
--------------------------------------------------------------------------------
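
As the README notes, the task settings travel with the checkpoint. A minimal sketch for inspecting which dataset, task and model a checkpoint was trained with (the run name `mymodel` is the hypothetical one used above; the keys come from the dict saved by `train.py`):

```
# Inspect a checkpoint produced by train.py (keys: 'state_dict', 'optimizer', 'args').
import glob
import torch

ckpt_path = sorted(glob.glob('ckpt/mymodel/best*.pkl'), reverse=True)[0]
state = torch.load(ckpt_path, map_location='cpu')

print(list(state.keys()))     # ['state_dict', 'optimizer', 'args']
print(state['args'].dataset)  # e.g. 'IEMOCAP'
print(state['args'].model)    # e.g. 'Model_MAT'
```
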
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .meld_dataset import Meld_Dataset
2 | from .mosei_dataset import Mosei_Dataset
3 | from .mosi_dataset import Mosi_Dataset
4 | from .iemocap_dataset import Iemocap_Dataset
5 | from .vgaf_dataset import Vgaf_Dataset
--------------------------------------------------------------------------------
/dataloaders/iemocap_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import pickle
4 | import numpy as np
5 | import torch
6 | from utils.plot import plot
7 | from utils.tokenize import tokenize, create_dict, sent_to_ix, pad_feature
8 | from torch.utils.data import Dataset
9 |
10 | class Iemocap_Dataset(Dataset):
11 | def __init__(self, name, args, token_to_ix=None, dataroot='data'):
12 | super(Iemocap_Dataset, self).__init__()
13 | assert name in ['train', 'valid', 'test']
14 | self.name = name
15 | self.args = args
16 | self.dataroot = os.path.join(dataroot, 'IEMOCAP')
17 | self.private_set = name == 'private'
18 |
19 | if name == 'train':
20 | name = 'traindev'
21 | if name == 'valid':
22 | name = 'test'
23 |
24 | word_file = os.path.join(self.dataroot, name + "_sentences.p")
25 | audio_file = os.path.join(self.dataroot, name + "_mels.p")
26 | y_s_file = os.path.join(self.dataroot, name + "_emotions.p")
27 |
28 | self.key_to_word = pickle.load(open(word_file, "rb"))
29 | self.key_to_audio = pickle.load(open(audio_file, "rb"))
30 | self.key_to_label = pickle.load(open(y_s_file, "rb"))
31 | self.set = list(self.key_to_label.keys())
32 |
33 | for key in list(self.set):  # iterate over a copy: removing from the list while iterating over it would skip keys
34 | if not (key in self.key_to_word and
35 | key in self.key_to_audio and
36 | key in self.key_to_label):
37 | print("Not present everywhere, removing key ", key)
38 | self.set.remove(key)
39 |
40 | # Plot temporal dimension of feature
41 | # t = []
42 | # for key in self.key_to_word.keys():
43 | # x = np.array(self.key_to_word[key]).shape[0]
44 | # t.append(x)
45 | # plot(t)
46 | # sys.exit()
47 |
48 | # Creating embeddings and word indexes
49 | self.key_to_sentence = tokenize(self.key_to_word)
50 | if token_to_ix is not None:
51 | self.token_to_ix = token_to_ix
52 | else: # Train
53 | self.token_to_ix, self.pretrained_emb = create_dict(self.key_to_sentence, self.dataroot)
54 | self.vocab_size = len(self.token_to_ix)
55 |
56 | self.l_max_len = 15
57 | self.a_max_len = 40
58 |
59 | def __getitem__(self, idx):
60 | key = self.set[idx]
61 | L = sent_to_ix(self.key_to_sentence[key], self.token_to_ix, max_token=self.l_max_len)
62 | A = pad_feature(self.key_to_audio[key], self.a_max_len)
63 | V = np.zeros(1) # not using video, insert dummy
64 |
65 | y = self.key_to_label[key]
66 | y = np.array(y)
67 | return key, torch.from_numpy(L), torch.from_numpy(A), torch.from_numpy(V).float(), torch.from_numpy(y)
68 |
69 | def __len__(self):
70 | return len(self.set)
--------------------------------------------------------------------------------
/dataloaders/mosi_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import pickle
4 | import numpy as np
5 | import torch
6 | from utils.plot import plot
7 | from utils.tokenize import tokenize, create_dict, sent_to_ix, pad_feature
8 | from torch.utils.data import Dataset
9 |
10 | class Mosi_Dataset(Dataset):
11 | def __init__(self, name, args, token_to_ix=None, dataroot='data'):
12 | super(Mosi_Dataset, self).__init__()
13 | assert name in ['train', 'valid', 'test']
14 | self.name = name
15 | self.args = args
16 | self.dataroot = os.path.join(dataroot, 'MOSI')
17 | self.private_set = name == 'private'
18 | if name == 'train':
19 | name = 'trainval'
20 | if name == 'valid':
21 | name = 'test'
22 | word_file = os.path.join(self.dataroot, name + "_sentences.p")
23 | audio_file = os.path.join(self.dataroot, name + "_mels.p")
24 | y_s_file = os.path.join(self.dataroot, name + "_sentiment.p")
25 |
26 | self.key_to_word = pickle.load(open(word_file, "rb"))
27 | self.key_to_audio = pickle.load(open(audio_file, "rb"))
28 | self.key_to_label = pickle.load(open(y_s_file, "rb"))
29 | self.set = list(self.key_to_label.keys())
30 |
31 | # filter y = 0 for binary task (https://github.com/A2Zadeh/CMU-MultimodalSDK/tree/master/mmsdk/mmdatasdk/dataset/standard_datasets/CMU_MOSI)
32 | if self.args.task_binary:
33 | for key in self.key_to_label.keys():
34 | if self.key_to_label[key] == 0.0:
35 | print("2-class Sentiment, removing key ", key)
36 | self.set.remove(key)
37 |
38 | for key in list(self.set):  # iterate over a copy: removing from the list while iterating over it would skip keys
39 | if not (key in self.key_to_word and
40 | key in self.key_to_audio and
41 | key in self.key_to_label):
42 | print("Not present everywhere, removing key ", key)
43 | self.set.remove(key)
44 |
45 | # Plot temporal dimension of feature
46 | # t = []
47 | # for key in self.key_to_word.keys():
48 | # x = np.array(self.key_to_word[key]).shape[0]
49 | # t.append(x)
50 | # plot(t)
51 | # sys.exit()
52 |
53 | # Creating embeddings and word indexes
54 | self.key_to_sentence = tokenize(self.key_to_word)
55 | if token_to_ix is not None:
56 | self.token_to_ix = token_to_ix
57 | else: # Train
58 | self.token_to_ix, self.pretrained_emb = create_dict(self.key_to_sentence, self.dataroot)
59 | self.vocab_size = len(self.token_to_ix)
60 |
61 | self.l_max_len = 30
62 | self.a_max_len = 60
63 |
64 | def __getitem__(self, idx):
65 | key = self.set[idx]
66 | L = sent_to_ix(self.key_to_sentence[key], self.token_to_ix, max_token=self.l_max_len)
67 | A = pad_feature(self.key_to_audio[key], self.a_max_len)
68 | V = np.zeros(1) # not using video, insert dummy
69 |
70 | y = self.key_to_label[key]
71 | if self.args.task_binary:
72 | c = 0 if y < 0.0 else 1
73 | else:
74 | c = int(round(y)) + 3 # from -3;3 to 0;6
75 | y = np.array(c)
76 | return key, torch.from_numpy(L), torch.from_numpy(A), torch.from_numpy(V).float(), torch.from_numpy(y)
77 |
78 | def __len__(self):
79 | return len(self.set)
--------------------------------------------------------------------------------
/dataloaders/vgaf_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import pickle
4 | import numpy as np
5 | import torch
6 | from utils.plot import plot
7 | from utils.tokenize import tokenize, create_dict, sent_to_ix, pad_feature
8 | from torch.utils.data import Dataset
9 |
10 | class Vgaf_Dataset(Dataset):
11 | def __init__(self, name, args, token_to_ix=None, dataroot='data'):
12 | super(Vgaf_Dataset, self).__init__()
13 | assert name in ['train', 'valid', 'test']
14 | self.name = name
15 | self.args = args
16 | self.dataroot = os.path.join(dataroot, 'VGAF')
17 | self.private_set = name == 'private'
18 |
19 | if name == 'test':
20 | name = 'valid'
21 |
22 | word_file = os.path.join(self.dataroot, name + "_sentences.p")
23 | audio_file = os.path.join(self.dataroot, name + "_mels.p")
24 | y_file = os.path.join(self.dataroot, name + "_emotions.p")
25 |
26 | self.key_to_word = pickle.load(open(word_file, "rb"))
27 | self.key_to_audio = pickle.load(open(audio_file, "rb"))
28 | self.key_to_label = pickle.load(open(y_file, "rb"))
29 | self.set = list(self.key_to_label.keys())
30 |
31 | for key in list(self.set):  # iterate over a copy: removing from the list while iterating over it would skip keys
32 | if not (key in self.key_to_word and
33 | key in self.key_to_audio and
34 | key in self.key_to_label):
35 | print("Not present everywhere, removing key ", key)
36 | self.set.remove(key)
37 |
38 | # Plot temporal dimension of feature
39 | # t = []
40 | # for key in self.key_to_word.keys():
41 | # x = np.array(self.key_to_word[key]).shape[0]
42 | # t.append(x)
43 | # print(max(t))
44 | # plot(t)
45 | # sys.exit()
46 |
47 | # Creating embeddings and word indexes
48 | self.key_to_sentence = tokenize(self.key_to_word)
49 | if token_to_ix is not None:
50 | self.token_to_ix = token_to_ix
51 | else: # Train
52 | self.token_to_ix, self.pretrained_emb = create_dict(self.key_to_sentence, self.dataroot)
53 | self.vocab_size = len(self.token_to_ix)
54 |
55 | self.l_max_len = 30
56 | self.a_max_len = 26
57 |
58 | def __getitem__(self, idx):
59 | key = self.set[idx]
60 | L = sent_to_ix(self.key_to_sentence[key], self.token_to_ix, max_token=self.l_max_len)
61 | A = pad_feature(self.key_to_audio[key], self.a_max_len)
62 | V = np.zeros(1) # not using video, insert dummy
63 |
64 | y = self.key_to_label[key]
65 | y = np.array(int(y)-1) #from 1,3 to 0,2
66 | return key, torch.from_numpy(L), torch.from_numpy(A), torch.from_numpy(V).float(), torch.from_numpy(y)
67 |
68 | def __len__(self):
69 | return len(self.set)
--------------------------------------------------------------------------------
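
The dataloaders above all return the same 5-tuple per item: the sample key, padded token indices `L`, padded mel-spectrogram features `A`, a dummy visual tensor `V`, and the label `y`. A minimal sketch to check the batched shapes (it assumes the IEMOCAP pickles sit in `data/IEMOCAP`; the `args` fields below are illustrative, as `Iemocap_Dataset` only stores them):

```
from argparse import Namespace

from torch.utils.data import DataLoader

from dataloaders import Iemocap_Dataset

args = Namespace(task='emotion', task_binary=False)  # hypothetical minimal args
train_dset = Iemocap_Dataset('train', args)
loader = DataLoader(train_dset, batch_size=4, shuffle=True)

keys, L, A, V, y = next(iter(loader))
print(L.shape)  # (4, 15)         token indices, l_max_len = 15
print(A.shape)  # (4, 40, n_mels) mel features, a_max_len = 40
print(V.shape)  # (4, 1)          dummy visual placeholder
print(y.shape)  # (4,)            emotion class ids
```
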
/ensembling.py:
--------------------------------------------------------------------------------
1 | import argparse, os, glob, warnings, torch
2 | import numpy as np
3 | from utils.pred_func import *
4 | from sklearn.metrics import classification_report
5 | from torch.utils.data import DataLoader
6 | from dataloaders import *
7 | from models import *
8 | from utils.compute_args import compute_args
9 | from train import evaluate
10 |
11 | warnings.filterwarnings("ignore")
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--output', type=str, default='ckpt/')
17 | parser.add_argument('--name', type=str, default='exp0/')
18 | parser.add_argument('--sets', nargs='+', default=["valid", "test"])
19 |
20 | parser.add_argument('--index', type=int, default=99)
21 | parser.add_argument('--private_set', type=str, default=None)
22 |
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | if __name__ == '__main__':
28 | args = parse_args()
29 |
30 | # Save vars
31 | private_set = args.private_set
32 | index = args.index
33 | sets = args.sets
34 |
35 | # Listing sorted checkpoints
36 | ckpts = sorted(glob.glob(os.path.join(args.output, args.name,'best*')), reverse=True)
37 |
38 | # Load original args
39 | args = torch.load(ckpts[0])['args']
40 | args = compute_args(args)
41 |
42 |
43 | # Define the splits to be evaluated
44 | evaluation_sets = list(sets) + ([private_set] if private_set is not None else [])
45 | print("Evaluated sets: ", str(evaluation_sets))
46 | # Creating dataloader
47 | train_dset = eval(args.dataloader)('train', args)
48 | loaders = {set: DataLoader(eval(args.dataloader)(set, args, train_dset.token_to_ix),
49 | args.batch_size,
50 | num_workers=8,
51 | pin_memory=True) for set in evaluation_sets}
52 |
53 | # Creating net
54 | net = eval(args.model)(args, train_dset.vocab_size, train_dset.pretrained_emb).cuda()
55 |
56 | # Ensembling sets
57 | ensemble_preds = {set: {} for set in evaluation_sets}
58 | ensemble_accuracies = {set: [] for set in evaluation_sets}
59 |
60 | # Iterating over checkpoints
61 | for i, ckpt in enumerate(ckpts):
62 |
63 | if i >= index:
64 | break
65 |
66 | print("###### Ensembling " + str(i+1))
67 | state_dict = torch.load(ckpt)['state_dict']
68 | net.load_state_dict(state_dict)
69 |
70 | # Evaluation per checkpoint predictions
71 | for set in evaluation_sets:
72 | accuracy, preds = evaluate(net, loaders[set], args)
73 | print('Accuracy for ' + set + ' for model ' + ckpt + ":", accuracy)
74 | for id, pred in preds.items():
75 | if id not in ensemble_preds[set]:
76 | ensemble_preds[set][id] = []
77 | ensemble_preds[set][id].append(pred)
78 |
79 | # Compute set ensembling accuracy
80 | # Get all ids and answers
81 | ids = [id for ids, _, _, _, _ in loaders[set] for id in ids]
82 | ans = [np.array(a) for _, _, _, _, ans in loaders[set] for a in ans]
83 |
84 | # for all id, get averaged probabilities
85 | avg_preds = np.array([np.mean(np.array(ensemble_preds[set][id]), axis=0) for id in ids])
86 | # Compute accuracies
87 | if set != private_set:
88 | accuracy = np.mean(eval(args.pred_func)(avg_preds) == ans) * 100
89 | print("New " + set + " ens. Accuracy :", accuracy)
90 | ensemble_accuracies[set].append(accuracy)
91 |
92 | if i + 1 == index:
93 | print(classification_report(ans, eval(args.pred_func)(avg_preds)))
94 |
95 | # Printing overall results
96 | for set in sets:
97 | print("Max ensemble w-accuracies for " + set + " : " + str(max(ensemble_accuracies[set])))
--------------------------------------------------------------------------------
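
The ensembling above is plain score averaging: for each sample id, the per-checkpoint outputs are averaged and the dataset's `pred_func` (here `amax`) picks the class. A toy numpy illustration with made-up scores for one 4-class sample:

```
import numpy as np

from utils.pred_func import amax

ckpt_scores = np.array([
    [0.2, 1.4, -0.3, 0.1],   # checkpoint 1 outputs for the sample
    [0.9, 0.8, -0.5, 0.2],   # checkpoint 2
    [0.1, 1.1,  0.0, 0.3],   # checkpoint 3
])
avg = ckpt_scores.mean(axis=0, keepdims=True)  # like avg_preds for a single id, shape (1, 4)
print(amax(avg))                               # [1] -> class 1 wins after averaging
```
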
/layers/fc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class FC(nn.Module):
4 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
5 | super(FC, self).__init__()
6 | self.dropout_r = dropout_r
7 | self.use_relu = use_relu
8 |
9 | self.linear = nn.Linear(in_size, out_size)
10 |
11 | if use_relu:
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | if dropout_r > 0:
15 | self.dropout = nn.Dropout(dropout_r)
16 |
17 | def forward(self, x):
18 | x = self.linear(x)
19 |
20 | if self.use_relu:
21 | x = self.relu(x)
22 |
23 | if self.dropout_r > 0:
24 | x = self.dropout(x)
25 |
26 | return x
27 |
28 |
29 | class MLP(nn.Module):
30 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
31 | super(MLP, self).__init__()
32 |
33 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
34 | self.linear = nn.Linear(mid_size, out_size)
35 |
36 | def forward(self, x):
37 | return self.linear(self.fc(x))
38 |
--------------------------------------------------------------------------------
/layers/layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | class LayerNorm(nn.Module):
5 | def __init__(self, size, eps=1e-6):
6 | super(LayerNorm, self).__init__()
7 | self.eps = eps
8 |
9 | self.a_2 = nn.Parameter(torch.ones(size))
10 | self.b_2 = nn.Parameter(torch.zeros(size))
11 |
12 | def forward(self, x):
13 | mean = x.mean(-1, keepdim=True)
14 | std = x.std(-1, keepdim=True)
15 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
16 |
17 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse, os, random
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from dataloaders import *
6 | from models import *
7 | from train import train
8 | from utils.compute_args import compute_args
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | # Model
13 | parser.add_argument('--model', type=str, default="Model_MCAN")  # only Model_MAT and Model_MNT are provided; pass --model explicitly
14 | parser.add_argument('--layer', type=int, default=4)
15 | parser.add_argument('--hidden_size', type=int, default=512)
16 | parser.add_argument('--dropout_i', type=float, default=0.0)
17 | parser.add_argument('--dropout_r', type=float, default=0.1)
18 | parser.add_argument('--dropout_o', type=float, default=0.5)
19 | parser.add_argument('--multi_head', type=int, default=8)
20 | parser.add_argument('--ff_size', type=int, default=2048)
21 | parser.add_argument('--word_embed_size', type=int, default=300)
22 | parser.add_argument('--bidirectional', type=bool, default=False)
23 |
24 | # Data
25 | parser.add_argument('--lang_seq_len', type=int, default=50)
26 | parser.add_argument('--audio_seq_len', type=int, default=50)
27 | parser.add_argument('--video_seq_len', type=int, default=60)
28 | parser.add_argument('--audio_feat_size', type=int, default=80)
29 | parser.add_argument('--video_feat_size', type=int, default=512)
30 |
31 | # Training
32 | parser.add_argument('--output', type=str, default='ckpt/')
33 | parser.add_argument('--name', type=str, default='exp0/')
34 | parser.add_argument('--batch_size', type=int, default=64)
35 | parser.add_argument('--max_epoch', type=int, default=99)
36 | parser.add_argument('--lr_base', type=float, default=0.00005)
37 | parser.add_argument('--lr_decay', type=float, default=0.2)
38 | parser.add_argument('--lr_decay_times', type=int, default=2)
39 | parser.add_argument('--grad_norm_clip', type=float, default=-1)
40 | parser.add_argument('--eval_start', type=int, default=0)
41 | parser.add_argument('--early_stop', type=int, default=3)
42 | parser.add_argument('--seed', type=int, default=random.randint(0, 9999999))
43 |
44 | # Dataset and task
45 | parser.add_argument('--dataset', type=str, choices=['MELD', 'MOSEI', 'MOSI', 'IEMOCAP', 'VGAF'], default='MOSEI')
46 | parser.add_argument('--task', type=str, choices=['sentiment', 'emotion'], default='sentiment')
47 | parser.add_argument('--task_binary', type=bool, default=False)
48 |
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | if __name__ == '__main__':
54 | # Based on the args given, compute the derived args
55 | args = compute_args(parse_args())
56 |
57 | # Seed
58 | torch.manual_seed(args.seed)
59 | np.random.seed(args.seed)
60 | torch.backends.cudnn.deterministic = True
61 | torch.backends.cudnn.benchmark = False
62 |
63 | # DataLoader
64 | train_dset = eval(args.dataloader)('train', args)
65 | eval_dset = eval(args.dataloader)('valid', args, train_dset.token_to_ix)
66 |
67 | train_loader = DataLoader(train_dset, args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
68 | eval_loader = DataLoader(eval_dset, args.batch_size, num_workers=8, pin_memory=True)
69 |
70 | # Net
71 | net = eval(args.model)(args, train_dset.vocab_size, train_dset.pretrained_emb).cuda()
72 | print("Total number of parameters : " + str(sum([p.numel() for p in net.parameters()]) / 1e6) + "M")
73 | net = net.cuda()
74 |
75 | # Create Checkpoint dir
76 | if not os.path.exists(os.path.join(args.output, args.name)):
77 | os.makedirs(os.path.join(args.output, args.name))
78 |
79 | # Run training
80 | eval_accuracies = train(net, train_loader, eval_loader, args)
81 | open('best_scores.txt', 'a+').write(args.output + "/" + args.name + ","
82 | + str(max(eval_accuracies)) + "\n")
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_MAT import Model_MAT
2 | from .model_MNT import Model_MNT
--------------------------------------------------------------------------------
/models/model_MAT.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from layers.fc import MLP, FC
7 | from layers.layer_norm import LayerNorm
8 |
9 | # ------------------------------------
10 | # ---------- Masking sequence --------
11 | # ------------------------------------
12 | def make_mask(feature):
13 | return (torch.sum(
14 | torch.abs(feature),
15 | dim=-1
16 | ) == 0).unsqueeze(1).unsqueeze(2)
17 |
18 |
19 |
20 | class AttFlat(nn.Module):
21 | def __init__(self, args):
22 | super(AttFlat, self).__init__()
23 | self.args = args
24 | self.flat_glimpse = 1
25 |
26 | self.mlp = MLP(
27 | in_size=args.hidden_size,
28 | mid_size=args.ff_size,
29 | out_size=self.flat_glimpse,
30 | dropout_r=args.dropout_r,
31 | use_relu=True
32 | )
33 |
34 | self.linear_merge = nn.Linear(
35 | args.hidden_size * self.flat_glimpse,
36 | args.hidden_size * 2
37 | )
38 |
39 | def forward(self, x, x_mask):
40 | att = self.mlp(x)
41 | att = att.masked_fill(
42 | x_mask.squeeze(1).squeeze(1).unsqueeze(2),
43 | -1e9
44 | )
45 | att = F.softmax(att, dim=1)
46 |
47 | att_list = []
48 | for i in range(self.flat_glimpse):
49 | att_list.append(
50 | torch.sum(att[:, :, i: i + 1] * x, dim=1)
51 | )
52 |
53 | x_atted = torch.cat(att_list, dim=1)
54 | x_atted = self.linear_merge(x_atted)
55 |
56 | return x_atted
57 |
58 |
59 |
60 | class MHAtt(nn.Module):
61 | def __init__(self, args):
62 | super(MHAtt, self).__init__()
63 | self.args = args
64 |
65 | self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
66 | self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
67 | self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
68 | self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)
69 |
70 | self.dropout = nn.Dropout(args.dropout_r)
71 |
72 | def forward(self, v, k, q, mask):
73 | n_batches = q.size(0)
74 |
75 | v = self.linear_v(v).view(
76 | n_batches,
77 | -1,
78 | self.args.multi_head,
79 | int(self.args.hidden_size / self.args.multi_head)
80 | ).transpose(1, 2)
81 |
82 | k = self.linear_k(k).view(
83 | n_batches,
84 | -1,
85 | self.args.multi_head,
86 | int(self.args.hidden_size / self.args.multi_head)
87 | ).transpose(1, 2)
88 |
89 | q = self.linear_q(q).view(
90 | n_batches,
91 | -1,
92 | self.args.multi_head,
93 | int(self.args.hidden_size / self.args.multi_head)
94 | ).transpose(1, 2)
95 |
96 | atted = self.att(v, k, q, mask)
97 | atted = atted.transpose(1, 2).contiguous().view(
98 | n_batches,
99 | -1,
100 | self.args.hidden_size
101 | )
102 |
103 | atted = self.linear_merge(atted)
104 |
105 | return atted
106 |
107 | def att(self, value, key, query, mask):
108 | d_k = query.size(-1)
109 |
110 | scores = torch.matmul(
111 | query, key.transpose(-2, -1)
112 | ) / math.sqrt(d_k)
113 |
114 | if mask is not None:
115 | scores = scores.masked_fill(mask, -1e9)
116 |
117 | att_map = F.softmax(scores, dim=-1)
118 | att_map = self.dropout(att_map)
119 |
120 | return torch.matmul(att_map, value)
121 |
122 |
123 | # ---------------------------
124 | # ---- Feed Forward Nets ----
125 | # ---------------------------
126 |
127 | class FFN(nn.Module):
128 | def __init__(self, args):
129 | super(FFN, self).__init__()
130 |
131 | self.mlp = MLP(
132 | in_size=args.hidden_size,
133 | mid_size=args.ff_size,
134 | out_size=args.hidden_size,
135 | dropout_r=args.dropout_r,
136 | use_relu=True
137 | )
138 |
139 | def forward(self, x):
140 | return self.mlp(x)
141 |
142 |
143 | # ------------------------
144 | # ---- Self Attention ----
145 | # ------------------------
146 |
147 | class SA(nn.Module):
148 | def __init__(self, args):
149 | super(SA, self).__init__()
150 |
151 | self.mhatt = MHAtt(args)
152 | self.ffn = FFN(args)
153 |
154 | self.dropout1 = nn.Dropout(args.dropout_r)
155 | self.norm1 = LayerNorm(args.hidden_size)
156 |
157 | self.dropout2 = nn.Dropout(args.dropout_r)
158 | self.norm2 = LayerNorm(args.hidden_size)
159 |
160 | def forward(self, y, y_mask):
161 | y = self.norm1(y + self.dropout1(
162 | self.mhatt(y, y, y, y_mask)
163 | ))
164 |
165 | y = self.norm2(y + self.dropout2(
166 | self.ffn(y)
167 | ))
168 |
169 | return y
170 |
171 |
172 | # -------------------------------
173 | # ---- Self Guided Attention ----
174 | # -------------------------------
175 |
176 | class SGA(nn.Module):
177 | def __init__(self, args):
178 | super(SGA, self).__init__()
179 |
180 | self.mhatt1 = MHAtt(args)
181 | self.mhatt2 = MHAtt(args)
182 | self.ffn = FFN(args)
183 |
184 | self.dropout1 = nn.Dropout(args.dropout_r)
185 | self.norm1 = LayerNorm(args.hidden_size)
186 |
187 | self.dropout2 = nn.Dropout(args.dropout_r)
188 | self.norm2 = LayerNorm(args.hidden_size)
189 |
190 | self.dropout3 = nn.Dropout(args.dropout_r)
191 | self.norm3 = LayerNorm(args.hidden_size)
192 |
193 | def forward(self, x, y, x_mask, y_mask):
194 | x = self.norm1(x + self.dropout1(
195 | self.mhatt1(v=x, k=x, q=x, mask=x_mask)
196 | ))
197 |
198 | x = self.norm2(x + self.dropout2(
199 | self.mhatt2(v=y, k=y, q=x, mask=y_mask)
200 | ))
201 |
202 | x = self.norm3(x + self.dropout3(
203 | self.ffn(x)
204 | ))
205 |
206 | return x
207 |
208 |
209 | # ------------------------------------------------
210 | # ---- MCA Layers Cascaded by Encoder-Decoder ----
211 | # ------------------------------------------------
212 |
213 | class MCA_ED(nn.Module):
214 | def __init__(self, args):
215 | super(MCA_ED, self).__init__()
216 |
217 | self.enc_list = nn.ModuleList([SA(args) for _ in range(args.layer)])
218 | self.dec_list = nn.ModuleList([SGA(args) for _ in range(args.layer)])
219 |
220 | def forward(self, y, x, y_mask, x_mask):
221 | # Get encoder last hidden vector
222 | for enc in self.enc_list:
223 | y = enc(y, y_mask)
224 |
225 | # Input encoder last hidden vector
226 | # And obtain decoder last hidden vectors
227 | for dec in self.dec_list:
228 | x = dec(x, y, x_mask, y_mask)
229 |
230 | return y, x
231 |
232 |
233 |
234 | # -------------------------
236 | # ---- Main MAT Model ----
236 | # -------------------------
237 |
238 | class Model_MAT(nn.Module):
239 | def __init__(self, args, vocab_size, pretrained_emb):
240 | super(Model_MAT, self).__init__()
241 | self.args = args
242 |
243 | # LSTM
244 | self.embedding = nn.Embedding(
245 | num_embeddings=vocab_size,
246 | embedding_dim=args.word_embed_size
247 | )
248 |
249 | # Loading the GloVe embedding weights
250 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
251 | self.input_drop = nn.Dropout(args.dropout_i)
252 |
253 | self.lstm_x = nn.LSTM(
254 | input_size=args.word_embed_size,
255 | hidden_size=args.hidden_size,
256 | num_layers=1,
257 | batch_first=True
258 | )
259 |
260 | self.lstm_y = nn.LSTM(
261 | input_size=args.audio_feat_size,
262 | hidden_size=args.hidden_size,
263 | num_layers=1,
264 | batch_first=True
265 | )
266 |
267 | # self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size)
268 | self.backbone = MCA_ED(args)
269 |
270 | # Flatten to vector
271 | self.attflat_img = AttFlat(args)
272 | self.attflat_lang = AttFlat(args)
273 |
274 | # Classification layers
275 |
276 | self.proj_norm = LayerNorm(2 * args.hidden_size)
277 | self.proj = nn.Linear(2 * args.hidden_size, args.ans_size)
278 | self.proj_drop = nn.Dropout(args.dropout_o)
279 |
280 | def forward(self, x, y, _):
281 | x_mask = make_mask(x.unsqueeze(2))
282 | y_mask = make_mask(y)
283 |
284 | embedding = self.embedding(x)
285 |
286 | x, _ = self.lstm_x(self.input_drop(embedding))
287 | y, _ = self.lstm_y(self.input_drop(y))
288 |
289 | # Backbone Framework
290 | lang_feat, img_feat = self.backbone(
291 | x,
292 | y,
293 | x_mask,
294 | y_mask
295 | )
296 |
297 | # Flatten to vector
298 | lang_feat = self.attflat_lang(
299 | lang_feat,
300 | x_mask
301 | )
302 |
303 | img_feat = self.attflat_img(
304 | img_feat,
305 | y_mask
306 | )
307 |
308 | # Classification layers
309 | proj_feat = lang_feat + img_feat
310 | proj_feat = self.proj_norm(proj_feat)
311 |
312 | proj_feat = self.proj_drop(proj_feat)
313 | proj_feat = self.proj(proj_feat)
314 |
315 | return proj_feat
--------------------------------------------------------------------------------
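
A minimal CPU smoke test for `Model_MAT`, useful to see the tensor flow (language indices and mel features in, class logits out). The hyper-parameter values and the random embedding matrix below are illustrative stand-ins for `args` and `train_dset.pretrained_emb`:

```
from argparse import Namespace

import numpy as np
import torch

from models import Model_MAT

args = Namespace(hidden_size=64, ff_size=128, multi_head=4, layer=2,
                 dropout_i=0.0, dropout_r=0.1, dropout_o=0.5,
                 word_embed_size=300, audio_feat_size=80, ans_size=4)

vocab_size = 100
pretrained_emb = np.random.randn(vocab_size, args.word_embed_size).astype(np.float32)
net = Model_MAT(args, vocab_size, pretrained_emb).eval()

x = torch.randint(1, vocab_size, (2, 15))     # token indices, (batch, l_max_len)
x[:, 12:] = 0                                 # simulate padding; make_mask marks these positions
y = torch.randn(2, 40, args.audio_feat_size)  # mel features, (batch, a_max_len, n_mels)
with torch.no_grad():
    logits = net(x, y, None)                  # third (video) input is unused
print(logits.shape)                           # torch.Size([2, 4])
```
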
/models/model_MNT.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from layers.fc import MLP, FC
7 |
8 | class LayerNorm(nn.Module):
9 | def __init__(self, size, eps=1e-6):
10 | super(LayerNorm, self).__init__()
11 | self.eps = eps
12 |
13 | self.a_2 = nn.Parameter(torch.ones(size))
14 | self.b_2 = nn.Parameter(torch.zeros(size))
15 |
16 | def forward(self, x, ab=None):
17 | mean = x.mean(-1, keepdim=True)
18 | std = x.std(-1, keepdim=True)
19 |
20 | if ab is not None:
21 | cond = torch.chunk(ab, 2, dim=-1)
22 | g = cond[0]
23 | b = cond[1]
24 | return (self.a_2 + torch.unsqueeze(g, 1)) * (x - mean) / (std + self.eps) + (self.b_2 + torch.unsqueeze(b, 1))
25 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
26 |
27 | # ------------------------------------
28 | # ---------- Masking sequence --------
29 | # ------------------------------------
30 | def make_mask(feature):
31 | return (torch.sum(
32 | torch.abs(feature),
33 | dim=-1
34 | ) == 0).unsqueeze(1).unsqueeze(2)
35 |
36 |
37 |
38 | class AttFlat(nn.Module):
39 | def __init__(self, args):
40 | super(AttFlat, self).__init__()
41 | self.args = args
42 | self.flat_glimpse = 1
43 |
44 | self.mlp = MLP(
45 | in_size=args.hidden_size,
46 | mid_size=args.ff_size,
47 | out_size=self.flat_glimpse,
48 | dropout_r=args.dropout_r,
49 | use_relu=True
50 | )
51 |
52 | self.linear_merge = nn.Linear(
53 | args.hidden_size * self.flat_glimpse,
54 | args.hidden_size
55 | )
56 |
57 | def forward(self, x, x_mask):
58 | att = self.mlp(x)
59 | att = att.masked_fill(
60 | x_mask.squeeze(1).squeeze(1).unsqueeze(2),
61 | -1e9
62 | )
63 | att = F.softmax(att, dim=1)
64 |
65 | att_list = []
66 | for i in range(self.flat_glimpse):
67 | att_list.append(
68 | torch.sum(att[:, :, i: i + 1] * x, dim=1)
69 | )
70 |
71 | x_atted = torch.cat(att_list, dim=1)
72 | x_atted = self.linear_merge(x_atted)
73 |
74 | return x_atted
75 |
76 |
77 |
78 | class MHAtt(nn.Module):
79 | def __init__(self, args):
80 | super(MHAtt, self).__init__()
81 | self.args = args
82 |
83 | self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
84 | self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
85 | self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
86 | self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)
87 |
88 | self.dropout = nn.Dropout(args.dropout_r)
89 |
90 | def forward(self, v, k, q, mask):
91 | n_batches = q.size(0)
92 |
93 | v = self.linear_v(v).view(
94 | n_batches,
95 | -1,
96 | self.args.multi_head,
97 | int(self.args.hidden_size / self.args.multi_head)
98 | ).transpose(1, 2)
99 |
100 | k = self.linear_k(k).view(
101 | n_batches,
102 | -1,
103 | self.args.multi_head,
104 | int(self.args.hidden_size / self.args.multi_head)
105 | ).transpose(1, 2)
106 |
107 | q = self.linear_q(q).view(
108 | n_batches,
109 | -1,
110 | self.args.multi_head,
111 | int(self.args.hidden_size / self.args.multi_head)
112 | ).transpose(1, 2)
113 |
114 | atted = self.att(v, k, q, mask)
115 | atted = atted.transpose(1, 2).contiguous().view(
116 | n_batches,
117 | -1,
118 | self.args.hidden_size
119 | )
120 |
121 | atted = self.linear_merge(atted)
122 |
123 | return atted
124 |
125 | def att(self, value, key, query, mask):
126 | d_k = query.size(-1)
127 |
128 | scores = torch.matmul(
129 | query, key.transpose(-2, -1)
130 | ) / math.sqrt(d_k)
131 |
132 | if mask is not None:
133 | scores = scores.masked_fill(mask, -1e9)
134 |
135 | att_map = F.softmax(scores, dim=-1)
136 | att_map = self.dropout(att_map)
137 |
138 | return torch.matmul(att_map, value)
139 |
140 |
141 | # ---------------------------
142 | # ---- Feed Forward Nets ----
143 | # ---------------------------
144 |
145 | class FFN(nn.Module):
146 | def __init__(self, args):
147 | super(FFN, self).__init__()
148 |
149 | self.mlp = MLP(
150 | in_size=args.hidden_size,
151 | mid_size=args.ff_size,
152 | out_size=args.hidden_size,
153 | dropout_r=args.dropout_r,
154 | use_relu=True
155 | )
156 |
157 | def forward(self, x):
158 | return self.mlp(x)
159 |
160 |
161 | # ------------------------
162 | # ---- Self Attention ----
163 | # ------------------------
164 |
165 | class SA(nn.Module):
166 | def __init__(self, args):
167 | super(SA, self).__init__()
168 |
169 | self.mhatt = MHAtt(args)
170 | self.ffn = FFN(args)
171 |
172 | self.dropout1 = nn.Dropout(args.dropout_r)
173 | self.norm1 = LayerNorm(args.hidden_size)
174 |
175 | self.dropout2 = nn.Dropout(args.dropout_r)
176 | self.norm2 = LayerNorm(args.hidden_size)
177 |
178 | def forward(self, y, y_mask):
179 | y = self.norm1(y + self.dropout1(
180 | self.mhatt(y, y, y, y_mask)
181 | ))
182 |
183 | y = self.norm2(y + self.dropout2(
184 | self.ffn(y)
185 | ))
186 |
187 | return y
188 |
189 |
190 | class SAG(nn.Module):
191 | def __init__(self, args):
192 | super(SAG, self).__init__()
193 |
194 | self.mhatt = MHAtt(args)
195 | self.ffn = FFN(args)
196 |
197 | self.dropout1 = nn.Dropout(args.dropout_r)
198 | self.norm1 = LayerNorm(args.hidden_size)
199 |
200 | self.dropout2 = nn.Dropout(args.dropout_r)
201 | self.norm2 = LayerNorm(args.hidden_size)
202 |
203 | def forward(self, y, y_mask, cond):
204 | cond = torch.chunk(cond, 2, dim=-1)
205 | y = self.norm1(y + self.dropout1(
206 | self.mhatt(y, y, y, y_mask)
207 | ), cond[0])
208 |
209 | y = self.norm2(y + self.dropout2(
210 | self.ffn(y)
211 | ), cond[1])
212 |
213 | return y
214 |
215 |
216 | # -------------------------
217 | # ---- Main MNT Model ----
218 | # -------------------------
219 |
220 | class Model_MNT(nn.Module):
221 | def __init__(self, args, vocab_size, pretrained_emb):
222 | super(Model_MNT, self).__init__()
223 | self.args = args
224 |
225 | # LSTM
226 | self.embedding = nn.Embedding(
227 | num_embeddings=vocab_size,
228 | embedding_dim=args.word_embed_size
229 | )
230 |
231 | # Loading the GloVe embedding weights
232 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
233 | self.input_drop = nn.Dropout(args.dropout_i)
234 |
235 | self.lstm_x = nn.LSTM(
236 | input_size=args.word_embed_size,
237 | hidden_size=args.hidden_size,
238 | num_layers=1,
239 | batch_first=True
240 | )
241 |
242 | self.lstm_y = nn.LSTM(
243 | input_size=args.audio_feat_size,
244 | hidden_size=args.hidden_size,
245 | num_layers=1,
246 | batch_first=True
247 | )
248 |
249 | # self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size)
250 | self.fc = FC(args.hidden_size, args.hidden_size * args.layer * 2 * 2)
251 | self.enc_list = nn.ModuleList([SA(args) for _ in range(args.layer)])
252 | self.dec_list = nn.ModuleList([SAG(args) for _ in range(args.layer)])
253 |
254 | # Flatten to vector
255 | self.attflat_img = AttFlat(args)
256 | self.attflat_lang = AttFlat(args)
257 |
258 | # Classification layers
259 |
260 | self.proj_norm = LayerNorm(args.hidden_size)
261 | self.proj = nn.Linear(args.hidden_size, args.ans_size)
262 | self.proj_drop = nn.Dropout(args.dropout_o)
263 |
264 | def forward(self, x, y, _):
265 | x_mask = make_mask(x.unsqueeze(2))
266 | y_mask = make_mask(y)
267 |
268 | embedding = self.embedding(x)
269 |
270 | x, _ = self.lstm_x(self.input_drop(embedding))
271 | y, _ = self.lstm_y(self.input_drop(y))
272 |
273 | # Backbone Framework
274 | for enc in self.enc_list:
275 | x = enc(x, x_mask)
276 |
277 | lang_feat = self.attflat_lang(
278 | x,
279 | x_mask
280 | )
281 |
282 | cond = torch.chunk(self.fc(lang_feat), self.args.layer, dim=-1)
283 |
284 | for i, dec in enumerate(self.dec_list):
285 | y = dec(y, y_mask, cond[i])
286 |
287 | img_feat = self.attflat_img(
288 | y,
289 | y_mask
290 | )
291 |
292 | # Classification layers
293 | proj_feat = lang_feat + img_feat
294 | proj_feat = self.proj_norm(proj_feat)
295 |
296 | proj_feat = self.proj_drop(proj_feat)
297 | proj_feat = self.proj(proj_feat)
298 |
299 | return proj_feat
--------------------------------------------------------------------------------
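
The modulation in `Model_MNT` comes from the conditional `LayerNorm` above: the flattened language vector is expanded by `fc` to `hidden_size * layer * 2 * 2` values and split into one (gain, bias) pair for each of the two modulated norms in every `SAG` block. A small shape-only sketch (a plain `nn.Linear` stands in for the repo's `FC`, and the sizes are illustrative):

```
import torch
import torch.nn as nn

hidden, n_layer, batch = 64, 2, 8
lang_feat = torch.randn(batch, hidden)                  # output of attflat_lang
expand = nn.Linear(hidden, hidden * n_layer * 2 * 2)    # plays the role of self.fc
cond_all = expand(lang_feat)                            # (8, 512)

per_block = torch.chunk(cond_all, n_layer, dim=-1)      # one conditioning chunk per SAG block
for cond in per_block:                                  # cond: (8, hidden * 2 * 2)
    ab_norm1, ab_norm2 = torch.chunk(cond, 2, dim=-1)   # one 'ab' vector per modulated norm
    g, b = torch.chunk(ab_norm1, 2, dim=-1)             # gain and bias, (8, 64) each
    print(g.shape, b.shape)                             # torch.Size([8, 64]) twice
```
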
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import time
4 | import os
5 | from utils.pred_func import *
6 |
7 | def train(net, train_loader, eval_loader, args):
8 |
9 | logfile = open(
10 | args.output + "/" + args.name +
11 | '/log_run_' + str(args.seed) + '.txt',
12 | 'w+'
13 | )
14 | logfile.write(str(args))
15 |
16 | best_eval_accuracy = 0.0
17 | early_stop = 0
18 | decay_count = 0
19 |
20 | # Create the optimizer over the network parameters
21 | optim = torch.optim.Adam(net.parameters(), lr=args.lr_base)
22 |
23 | loss_fn = args.loss_fn
24 | eval_accuracies = []
25 | for epoch in range(0, args.max_epoch):
26 | time_start = time.time()
27 | loss_sum = 0
28 | for step, (
29 | id,
30 | x,
31 | y,
32 | z,
33 | ans,
34 | ) in enumerate(train_loader):
35 | loss_tmp = 0
36 | optim.zero_grad()
37 |
38 | x = x.cuda()
39 | y = y.cuda()
40 | z = z.cuda()
41 | ans = ans.cuda()
42 |
43 | pred = net(x, y, z)
44 | loss = loss_fn(pred, ans)
45 | loss.backward()
46 |
47 | loss_sum += loss.cpu().data.numpy()
48 | loss_tmp += loss.cpu().data.numpy()
49 |
50 | print("\r[Epoch %2d][Step %4d/%4d] Loss: %.4f, Lr: %.2e, %4d m "
51 | "remaining" % (
52 | epoch + 1,
53 | step,
54 | int(len(train_loader.dataset) / args.batch_size),
55 | loss_tmp / args.batch_size,
56 | *[group['lr'] for group in optim.param_groups],
57 | ((time.time() - time_start) / (step + 1)) * ((len(train_loader.dataset) / args.batch_size) - step) / 60,
58 | ), end=' ')
59 |
60 | # Gradient norm clipping
61 | if args.grad_norm_clip > 0:
62 | nn.utils.clip_grad_norm_(
63 | net.parameters(),
64 | args.grad_norm_clip
65 | )
66 |
67 | optim.step()
68 |
69 | time_end = time.time()
70 | elapse_time = time_end-time_start
71 | print('Finished in {}s'.format(int(elapse_time)))
72 | epoch_finish = epoch + 1
73 |
74 | # Logging
75 | logfile.write(
76 | 'Epoch: ' + str(epoch_finish) +
77 | ', Loss: ' + str(loss_sum / len(train_loader.dataset)) +
78 | ', Lr: ' + str([group['lr'] for group in optim.param_groups]) + '\n' +
79 | 'Elapsed time: ' + str(int(elapse_time)) +
80 | ', Speed(s/batch): ' + str(elapse_time / step) +
81 | '\n\n'
82 | )
83 |
84 | # Eval
85 | if epoch_finish >= args.eval_start:
86 | print('Evaluation...')
87 | accuracy, _ = evaluate(net, eval_loader, args)
88 | print('Accuracy :'+str(accuracy))
89 | eval_accuracies.append(accuracy)
90 | if accuracy > best_eval_accuracy:
91 | # Best
92 | state = {
93 | 'state_dict': net.state_dict(),
94 | 'optimizer': optim.state_dict(),
95 | 'args': args,
96 | }
97 | torch.save(
98 | state,
99 | args.output + "/" + args.name +
100 | '/best'+str(args.seed)+'.pkl'
101 | )
102 | best_eval_accuracy = accuracy
103 | early_stop = 0
104 |
105 | elif decay_count < args.lr_decay_times:
106 | # Decay
107 | print('LR Decay...')
108 | decay_count += 1
109 |
110 | ckpt = torch.load(args.output + "/" + args.name +
111 | '/best'+str(args.seed)+'.pkl')
112 | net.load_state_dict(ckpt['state_dict'])
113 | optim.load_state_dict(ckpt['optimizer'])
114 |
115 | # adjust_lr(optim, args.lr_decay)
116 | for group in optim.param_groups:
117 | group['lr'] = (args.lr_base * args.lr_decay**decay_count)
118 | else:
119 | # Early stopping; the counter does not start before lr_decay_times is reached
120 | early_stop += 1
121 |
122 | if early_stop == args.early_stop:
123 | logfile.write('Early stop reached' + '\n')
124 | print('Early stop reached')
125 | break
126 |
127 | logfile.write('best_acc :' + str(best_eval_accuracy) + '\n\n')
128 | print('best_eval_acc :' + str(best_eval_accuracy) + '\n\n')
129 | os.rename(args.output + "/" + args.name +
130 | '/best' + str(args.seed) + '.pkl',
131 | args.output + "/" + args.name +
132 | '/best' + str(best_eval_accuracy) + "_" + str(args.seed) + '.pkl')
133 | logfile.close()
134 | return eval_accuracies
135 |
136 |
137 | def evaluate(net, eval_loader, args):
138 | accuracy = []
139 | net.train(False)
140 | preds = {}
141 | for step, (
142 | ids,
143 | x,
144 | y,
145 | z,
146 | ans,
147 | ) in enumerate(eval_loader):
148 | x = x.cuda()
149 | y = y.cuda()
150 | z = z.cuda()
151 | pred = net(x, y, z).cpu().data.numpy()
152 |
153 | if not eval_loader.dataset.private_set:
154 | ans = ans.cpu().data.numpy()
155 | accuracy += list(eval(args.pred_func)(pred) == ans)
156 |
157 | # Save preds
158 | for id, p in zip(ids, pred):
159 | preds[id] = p
160 |
161 | net.train(True)
162 | return 100*np.mean(np.array(accuracy)), preds
163 |
164 |
--------------------------------------------------------------------------------
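
The final rename in `train()` is what `ensembling.py` keys on: each run leaves a `best<accuracy>_<seed>.pkl` file, the reverse-sorted glob over `best*` lists the higher-scoring runs first (for the two-digit accuracies seen on these tasks), and `--index N` then ensembles the first N of them. A toy illustration with hypothetical file names:

```
ckpts = ['ckpt/mymodel/best70.1_17.pkl',
         'ckpt/mymodel/best73.4_3021.pkl',
         'ckpt/mymodel/best68.9_998.pkl']
print(sorted(ckpts, reverse=True)[:2])  # the two checkpoints kept by --index 2
# ['ckpt/mymodel/best73.4_3021.pkl', 'ckpt/mymodel/best70.1_17.pkl']
```
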
/utils/compute_args.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def compute_args(args):
5 |
6 | # DataLoader
7 | if args.dataset == "MOSEI": args.dataloader = 'Mosei_Dataset'
8 | if args.dataset == "MELD": args.dataloader = 'Meld_Dataset'
9 | if args.dataset == "MOSI": args.dataloader = 'Mosi_Dataset'
10 | if args.dataset == "IEMOCAP": args.dataloader = 'Iemocap_Dataset'
11 | if args.dataset == "VGAF": args.dataloader = 'Vgaf_Dataset'
12 |
13 | # Loss function to use
14 | if args.dataset == 'MOSEI' and args.task == 'sentiment': args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
15 | if args.dataset == 'MOSEI' and args.task == 'emotion': args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum").cuda()
16 | if args.dataset == 'MELD' and args.task == 'sentiment': args.loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 1.6]), reduction="sum").cuda()
17 | if args.dataset == 'MELD' and args.task == 'emotion': args.loss_fn = torch.nn.CrossEntropyLoss(
18 | weight=torch.tensor([1.0, 4709/1205, 4709/268, 4709/683, 4709/1743, 4709/271, 4709/1109]),
19 | reduction="sum").cuda()
20 |
21 | if args.dataset == 'MOSI': args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
22 | if args.dataset == "IEMOCAP": args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
23 | if args.dataset == "VGAF": args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
24 |
25 | # Answer size
26 | if args.dataset == 'MOSEI' and args.task == "sentiment": args.ans_size = 7
27 | if args.dataset == 'MOSEI' and args.task == "sentiment" and args.task_binary: args.ans_size = 2
28 | if args.dataset == 'MOSEI' and args.task == "emotion": args.ans_size = 6
29 | if args.dataset == 'MELD' and args.task == "emotion": args.ans_size = 7
30 | if args.dataset == 'MELD' and args.task == "sentiment": args.ans_size = 3
31 | if args.dataset == 'MOSI': args.ans_size = 2
32 | if args.dataset == "IEMOCAP": args.ans_size = 4
33 | if args.dataset == "VGAF": args.ans_size = 3
34 |
35 | # Pred function
36 | if args.dataset == 'MOSEI': args.pred_func = "amax"
37 | if args.dataset == 'MOSEI' and args.task == "emotion": args.pred_func = "multi_label"
38 | if args.dataset == 'MELD': args.pred_func = "amax"
39 | if args.dataset == 'MOSI': args.pred_func = "amax"
40 | if args.dataset == 'IEMOCAP': args.pred_func = "amax"
41 | if args.dataset == "VGAF": args.pred_func = "amax"
42 |
43 | return args
--------------------------------------------------------------------------------
/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch.optim as Optim
2 |
3 |
4 | class WarmupOptimizer(object):
5 | def __init__(self, lr_base, optimizer, data_size, batch_size, warmup_epoch):
6 | self.optimizer = optimizer
7 | self._step = 0
8 | self.lr_base = lr_base
9 | self._rate = 0
10 | self.data_size = data_size
11 | self.batch_size = batch_size
12 | self.warmup_epoch = warmup_epoch
13 |
14 |
15 | def step(self):
16 | self._step += 1
17 |
18 | rate = self.rate()
19 | for p in self.optimizer.param_groups:
20 | p['lr'] = rate
21 | self._rate = rate
22 |
23 | self.optimizer.step()
24 |
25 |
26 | def zero_grad(self):
27 | self.optimizer.zero_grad()
28 |
29 |
30 | def rate(self, step=None):
31 | if step is None:
32 | step = self._step
33 |
34 | if step <= int(self.data_size / self.batch_size * (self.warmup_epoch + 1) * 0.25):
35 | r = self.lr_base * 1/(self.warmup_epoch + 1)
36 | elif step <= int(self.data_size / self.batch_size * (self.warmup_epoch + 1) * 0.5):
37 | r = self.lr_base * 2/(self.warmup_epoch + 1)
38 | elif step <= int(self.data_size / self.batch_size * (self.warmup_epoch + 1) * 0.75):
39 | r = self.lr_base * 3/(self.warmup_epoch + 1)
40 | else:
41 | r = self.lr_base
42 |
43 | return r
44 |
45 | def __str__(self):
46 | return "optimizer -> " + str(self.optimizer)
47 |
48 |
49 |
50 | def get_optim(args, model, data_size, lr_base=None):
51 | if lr_base is None:
52 | lr_base = args.lr_base
53 |
54 | std_optim = getattr(Optim, args.opt)
55 | params = filter(lambda p: p.requires_grad, model.parameters())
56 | eval_str = 'params, lr=0'
57 | d_opt = eval(args.opt_params)
58 | for key in d_opt:
59 | eval_str += ' ,' + key + '=' + str(d_opt[key])
60 |
61 | optim = WarmupOptimizer(
62 | lr_base,
63 | eval('std_optim' + '(' + eval_str + ')'),
64 | data_size,
65 | args.batch_size,
66 | args.warmup_epoch
67 | )
68 |
69 | return optim
70 |
71 |
72 | def adjust_lr(optim, decay_r):
73 | optim.lr_base *= decay_r
--------------------------------------------------------------------------------
/utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import itertools
4 |
5 | def plot(d):
6 | # An "interface" to matplotlib.axes.Axes.hist() method
7 | n, bins, patches = plt.hist(x=d, bins='auto', color='#0504aa',
8 | alpha=0.7, rwidth=0.85)
9 | plt.grid(axis='y', alpha=0.75)
10 | plt.title('Temporal dimension for visual features')
11 | maxfreq = n.max()
12 | # Set a clean upper y-axis limit.
13 | plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10)
14 | axes = plt.gca()
15 | axes.set_xlim([0, 100])
16 | plt.xlabel("Temporal dimension")
17 | plt.ylabel("Number of samples")
18 | plt.savefig('oui')  # save before show(): with an interactive backend, show() closes the figure
19 | plt.show()
20 |
21 |
22 |
23 | def plot_confusion_matrix(cm, classes,
24 | normalize=False,
25 | title='Confusion matrix',
26 | cmap=plt.cm.Blues):
27 | """
28 | This function prints and plots the confusion matrix.
29 | Normalization can be applied by setting `normalize=True`.
30 | """
31 |
32 |
33 | plt.imshow(cm, interpolation='nearest', cmap=cmap)
34 | plt.title(title)
35 | plt.colorbar()
36 | tick_marks = np.arange(len(classes))
37 | plt.xticks(tick_marks, classes, rotation=45)
38 | plt.yticks(tick_marks, classes)
39 |
40 | if normalize:
41 | cm = np.round((cm.astype('float') / cm.sum(axis=1)[:, np.newaxis])*100, 2) # from 0.45555555 to 45.55
42 | print("Normalized confusion matrix")
43 | else:
44 | print('Confusion matrix, without normalization')
45 |
46 | print(cm)
47 |
48 | thresh = cm.max() / 2.
49 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
50 | plt.text(j, i, cm[i, j],
51 | horizontalalignment="center",
52 | color="white" if cm[i, j] > thresh else "black")
53 |
54 | plt.tight_layout()
55 | plt.ylabel('True class')
56 | plt.xlabel('Predicted class')
57 | plt.savefig("confusion_matrix")
58 | plt.clf()
--------------------------------------------------------------------------------
/utils/pred_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def amax(x):
5 | return np.argmax(x, axis=1)
6 |
7 |
8 | def multi_label(x):
9 | return (x > 0)
--------------------------------------------------------------------------------
/utils/tokenize.py:
--------------------------------------------------------------------------------
1 | # $ wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz
2 | # $ pip install en_vectors_web_lg-2.1.0.tar.gz
3 | import en_vectors_web_lg
4 | import re
5 | import numpy as np
6 | import os
7 | import pickle
8 |
9 | def clean(w):
10 | return re.sub(
11 | r"([.,'!?\"()*#:;])",
12 | '',
13 | w.lower()
14 | ).replace('-', ' ').replace('/', ' ')
15 |
16 |
17 | def tokenize(key_to_word):
18 | key_to_sentence = {}
19 | for k, v in key_to_word.items():
20 | key_to_sentence[k] = [clean(w) for w in v if clean(w) != '']
21 | return key_to_sentence
22 |
23 |
24 | def create_dict(key_to_sentence, dataroot, use_glove=True):
25 | token_file = dataroot+"/token_to_ix.pkl"
26 | glove_file = dataroot+"/train_glove.npy"
27 | if os.path.exists(glove_file) and os.path.exists(token_file):
28 | print("Loading train language files")
29 | return pickle.load(open(token_file, "rb")), np.load(glove_file)
30 |
31 | print("Creating train language files")
32 | token_to_ix = {
33 | 'UNK': 1,
34 | }
35 |
36 | spacy_tool = None
37 | pretrained_emb = []
38 | if use_glove:
39 | spacy_tool = en_vectors_web_lg.load()
40 | pretrained_emb.append(spacy_tool('UNK').vector)
41 |
42 | for k, v in key_to_sentence.items():
43 | for word in v:
44 | if word not in token_to_ix:
45 | token_to_ix[word] = len(token_to_ix)
46 | if use_glove:
47 | pretrained_emb.append(spacy_tool(word).vector)
48 |
49 | pretrained_emb = np.array(pretrained_emb)
50 | np.save(glove_file, pretrained_emb)
51 | pickle.dump(token_to_ix, open(token_file, "wb"))
52 | return token_to_ix, pretrained_emb
53 |
54 | def sent_to_ix(s, token_to_ix, max_token=100):
55 | ques_ix = np.zeros(max_token, np.int64)
56 |
57 | for ix, word in enumerate(s):
58 | if word in token_to_ix:
59 | ques_ix[ix] = token_to_ix[word]
60 | else:
61 | ques_ix[ix] = token_to_ix['UNK']
62 |
63 | if ix + 1 == max_token:
64 | break
65 |
66 | return ques_ix
67 |
68 |
69 | def cmumosei_7(a):
70 | if a < -2:
71 | res = 0
72 | if -2 <= a and a < -1:
73 | res = 1
74 | if -1 <= a and a < 0:
75 | res = 2
76 | if 0 <= a and a <= 0:
77 | res = 3
78 | if 0 < a and a <= 1:
79 | res = 4
80 | if 1 < a and a <= 2:
81 | res = 5
82 | if a > 2:
83 | res = 6
84 | return res
85 |
86 | def cmumosei_2(a):
87 | if a < 0:
88 | return 0
89 | if a >= 0:
90 | return 1
91 |
92 | def pad_feature(feat, max_len):
93 | if feat.shape[0] > max_len:
94 | feat = feat[:max_len]
95 | feat = np.pad(
96 | feat,
97 | ((0, max_len - feat.shape[0]), (0, 0)),
98 | mode='constant',
99 | constant_values=0
100 | )
101 | return feat
102 |
--------------------------------------------------------------------------------
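
A toy usage sketch for the helpers above (the sentence, vocabulary and feature sizes are made up): `clean` strips punctuation and lowercases, `sent_to_ix` maps unknown words to `UNK` and zero-pads, and `pad_feature` clips or pads the acoustic features to a fixed length.

```
import numpy as np

from utils.tokenize import clean, sent_to_ix, pad_feature

token_to_ix = {'UNK': 1, 'i': 2, 'feel': 3, 'great': 4}
words = [clean(w) for w in ['I', 'feel', 'great', 'today!']]
print(sent_to_ix(words, token_to_ix, max_token=6))  # [2 3 4 1 0 0]: 'today' -> UNK, rest zero-padded

mels = np.random.randn(50, 80)              # 50 frames of 80-dim mel features
print(pad_feature(mels, max_len=40).shape)  # (40, 80): clipped to max_len
```
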