├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── iemocap_dataset.py
│   ├── meld_dataset.py
│   ├── mosei_dataset.py
│   ├── mosi_dataset.py
│   └── vgaf_dataset.py
├── ensembling.py
├── layers
│   ├── fc.py
│   └── layer_norm.py
├── main.py
├── models
│   ├── __init__.py
│   ├── model_MAT.py
│   └── model_MNT.py
├── train.py
└── utils
    ├── compute_args.py
    ├── optim.py
    ├── plot.py
    ├── pred_func.py
    └── tokenize.py
/README.md: --------------------------------------------------------------------------------
1 |
2 |     EMNLP2020
3 | 4 | PyTorch implementation of the paper "Modulated Fusion using Transformer for Linguistic-Acoustic Emotion Recognition"
5 | ``` 6 | @inproceedings{delbrouck-etal-2020-modulated, 7 | title = "Modulated Fusion using Transformer for Linguistic-Acoustic Emotion Recognition", 8 | author = "Delbrouck, Jean-Benoit and 9 | Tits, No{\'e} and 10 | Dupont, St{\'e}phane", 11 | booktitle = "Proceedings of the First International Workshop on Natural Language Processing Beyond Text", 12 | month = nov, 13 | year = "2020", 14 | address = "Online", 15 | publisher = "Association for Computational Linguistics", 16 | url = "https://www.aclweb.org/anthology/2020.nlpbt-1.1", 17 | doi = "10.18653/v1/2020.nlpbt-1.1", 18 | pages = "1--10", 19 | abstract = "This paper aims to bring a new lightweight yet powerful solution for the task of Emotion Recognition and Sentiment Analysis. Our motivation is to propose two architectures based on Transformers and modulation that combine the linguistic and acoustic inputs from a wide range of datasets to challenge, and sometimes surpass, the state-of-the-art in the field. To demonstrate the efficiency of our models, we carefully evaluate their performances on the IEMOCAP, MOSI, MOSEI and MELD dataset. The experiments can be directly replicated and the code is fully open for future researches.", 20 | } 21 | ``` 22 | 23 | #### Environment 24 | 25 | Create a Python 3.6 environment with: 26 | ``` 27 | torch 1.2.0 28 | torchvision 0.4.0 29 | numpy 1.18.1 30 | ``` 31 | 32 | We use GloVe vectors from spaCy. These can be installed into your environment using the following commands: 33 | ``` 34 | wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz 35 | pip install en_vectors_web_lg-2.1.0.tar.gz 36 | ``` 37 | #### Data 38 | 39 | Create a data folder and get the data: 40 | ``` 41 | mkdir -p data 42 | cd data 43 | wget -O data.zip https://www.dropbox.com/s/tz25q3xxfraw2r3/data.zip?dl=1 44 | unzip data.zip 45 | ``` 46 | 47 | #### Training 48 | 49 | Here is an example of training a MAT model on IEMOCAP: 50 | 51 | ``` 52 | mkdir -p ckpt 53 | for i in {1..10} 54 | do 55 | python main.py --dataset IEMOCAP \ 56 | --model Model_MAT \ 57 | --multi_head 4 \ 58 | --ff_size 1024 \ 59 | --hidden_size 512 \ 60 | --layer 2 \ 61 | --batch_size 32 \ 62 | --lr_base 0.0001 \ 63 | --dropout_r 0.1 \ 64 | --dropout_o 0.5 \ 65 | --name mymodel 66 | done 67 | 68 | ``` 69 | Checkpoints will be stored in the folder `ckpt/mymodel`. 70 | 71 | #### Evaluation 72 | 73 | You can evaluate a model by typing: 74 | ``` 75 | python ensembling.py --name mymodel --sets test 76 | ``` 77 | The task settings are defined in the checkpoint state dict, so the evaluation will be carried out on the dataset you trained your model on. 78 | 79 | By default, the script globs all training checkpoints inside the folder and performs ensembling over them. 80 | To show further details of the evaluation for a specific ensemble size, you can use the `--index` argument: 81 | ``` 82 | python ensembling.py --name mymodel --sets test --index 5 83 | ``` 84 | 85 | #### Pre-trained models 86 | We release pre-trained models to replicate the results reported in the paper. Models should be placed in the `ckpt` folder: unzip each archive into `ckpt` and pass the resulting folder name (e.g. `IEMOCAP_pretrained`) as the `--name` argument of `ensembling.py`. 
87 | ``` 88 | mkdir -p ckpt 89 | ``` 90 | 91 | [IEMOCAP 4-class emotions](https://www.dropbox.com/s/wzoiwrtc9m3nb78/IEMOCAP_pretrained.zip?dl=1) 92 | ``` 93 | python ensembling.py --name IEMOCAP_pretrained --index 5 --sets test 94 | 95 | precision recall f1-score support 96 | 97 | 0 0.70 0.66 0.68 384 98 | 1 0.68 0.75 0.71 278 99 | 2 0.79 0.71 0.75 194 100 | 3 0.78 0.81 0.79 229 101 | 102 | accuracy 0.73 1085 103 | macro avg 0.74 0.73 0.73 1085 104 | weighted avg 0.73 0.73 0.73 1085 105 | 106 | Max ensemble w-accuracies for test : 72.53456221198157 107 | ``` 108 | 109 | [MOSEI 2-class sentiment](https://www.dropbox.com/s/t2p8soswt9t1ii4/MOSEI_pretrained.zip?dl=1) 110 | ``` 111 | python ensembling.py --name MOSEI_pretrained --index 9 --sets test 112 | 113 | precision recall f1-score support 114 | 115 | 0 0.75 0.57 0.65 1350 116 | 1 0.84 0.92 0.88 3312 117 | 118 | accuracy 0.82 4662 119 | macro avg 0.80 0.75 0.77 4662 120 | weighted avg 0.82 0.82 0.81 4662 121 | 122 | Max ensemble w-accuracies for test : 82.15358215358215 123 | ``` 124 | 125 | [MOSI 2-class sentiment](https://www.dropbox.com/s/zw4a9ukk1npzt9r/MOSI_pretrained.zip?dl=1) 126 | ``` 127 | python ensembling.py --name MOSI_pretrained --index 2 --sets test 128 | 129 | 130 | precision recall f1-score support 131 | 132 | 0 0.77 0.91 0.84 379 133 | 1 0.84 0.63 0.72 277 134 | 135 | accuracy 0.79 656 136 | macro avg 0.81 0.77 0.78 656 137 | weighted avg 0.80 0.79 0.79 656 138 | 139 | Max ensemble w-accuracies for test : 79.26829268292683 140 | ``` 141 | [MELD 7-class emotions](https://www.dropbox.com/s/458h1ze6cic3h1l/MELD_pretrained.zip?dl=1) 142 | ``` 143 | python ensembling.py --name MELD_pretrained --index 9 --sets test 144 | 145 | 146 | precision recall f1-score support 147 | 148 | 0 0.64 0.52 0.58 1256 149 | 1 0.36 0.58 0.45 281 150 | 2 0.08 0.18 0.11 50 151 | 3 0.23 0.25 0.24 208 152 | 4 0.44 0.47 0.46 402 153 | 5 0.23 0.24 0.23 68 154 | 6 0.31 0.27 0.29 345 155 | 156 | accuracy 0.45 2610 157 | macro avg 0.33 0.36 0.34 2610 158 | weighted avg 0.48 0.45 0.46 2610 159 | ``` 160 | 161 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .meld_dataset import Meld_Dataset 2 | from .mosei_dataset import Mosei_Dataset 3 | from .mosi_dataset import Mosi_Dataset 4 | from .iemocap_dataset import Iemocap_Dataset 5 | from .vgaf_dataset import Vgaf_Dataset -------------------------------------------------------------------------------- /dataloaders/iemocap_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch 6 | from utils.plot import plot 7 | from utils.tokenize import tokenize, create_dict, sent_to_ix, pad_feature 8 | from torch.utils.data import Dataset 9 | 10 | class Iemocap_Dataset(Dataset): 11 | def __init__(self, name, args, token_to_ix=None, dataroot='data'): 12 | super(Iemocap_Dataset, self).__init__() 13 | assert name in ['train', 'valid', 'test'] 14 | self.name = name 15 | self.args = args 16 | self.dataroot = os.path.join(dataroot, 'IEMOCAP') 17 | self.private_set = name == 'private' 18 | 19 | if name == 'train': 20 | name = 'traindev' 21 | if name == 'valid': 22 | name = 'test' 23 | 24 | word_file = os.path.join(self.dataroot, name + "_sentences.p") 25 | audio_file = os.path.join(self.dataroot, name + "_mels.p") 26 | y_s_file = 
os.path.join(self.dataroot, name + "_emotions.p") 27 | 28 | self.key_to_word = pickle.load(open(word_file, "rb")) 29 | self.key_to_audio = pickle.load(open(audio_file, "rb")) 30 | self.key_to_label = pickle.load(open(y_s_file, "rb")) 31 | self.set = list(self.key_to_label.keys()) 32 | 33 | for key in self.set: 34 | if not (key in self.key_to_word and 35 | key in self.key_to_audio and 36 | key in self.key_to_label): 37 | print("Not present everywhere, removing key ", key) 38 | self.set.remove(key) 39 | 40 | # Plot temporal dimension of feature 41 | # t = [] 42 | # for key in self.key_to_word.keys(): 43 | # x = np.array(self.key_to_word[key]).shape[0] 44 | # t.append(x) 45 | # plot(t) 46 | # sys.exit() 47 | 48 | # Creating embeddings and word indexes 49 | self.key_to_sentence = tokenize(self.key_to_word) 50 | if token_to_ix is not None: 51 | self.token_to_ix = token_to_ix 52 | else: # Train 53 | self.token_to_ix, self.pretrained_emb = create_dict(self.key_to_sentence, self.dataroot) 54 | self.vocab_size = len(self.token_to_ix) 55 | 56 | self.l_max_len = 15 57 | self.a_max_len = 40 58 | 59 | def __getitem__(self, idx): 60 | key = self.set[idx] 61 | L = sent_to_ix(self.key_to_sentence[key], self.token_to_ix, max_token=self.l_max_len) 62 | A = pad_feature(self.key_to_audio[key], self.a_max_len) 63 | V = np.zeros(1) # not using video, insert dummy 64 | 65 | y = self.key_to_label[key] 66 | y = np.array(y) 67 | return key, torch.from_numpy(L), torch.from_numpy(A), torch.from_numpy(V).float(), torch.from_numpy(y) 68 | 69 | def __len__(self): 70 | return len(self.set) -------------------------------------------------------------------------------- /dataloaders/mosi_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch 6 | from utils.plot import plot 7 | from utils.tokenize import tokenize, create_dict, sent_to_ix, pad_feature 8 | from torch.utils.data import Dataset 9 | 10 | class Mosi_Dataset(Dataset): 11 | def __init__(self, name, args, token_to_ix=None, dataroot='data'): 12 | super(Mosi_Dataset, self).__init__() 13 | assert name in ['train', 'valid', 'test'] 14 | self.name = name 15 | self.args = args 16 | self.dataroot = os.path.join(dataroot, 'MOSI') 17 | self.private_set = name == 'private' 18 | if name == 'train': 19 | name = 'trainval' 20 | if name == 'valid': 21 | name = 'test' 22 | word_file = os.path.join(self.dataroot, name + "_sentences.p") 23 | audio_file = os.path.join(self.dataroot, name + "_mels.p") 24 | y_s_file = os.path.join(self.dataroot, name + "_sentiment.p") 25 | 26 | self.key_to_word = pickle.load(open(word_file, "rb")) 27 | self.key_to_audio = pickle.load(open(audio_file, "rb")) 28 | self.key_to_label = pickle.load(open(y_s_file, "rb")) 29 | self.set = list(self.key_to_label.keys()) 30 | 31 | # filter y = 0 for binary task (https://github.com/A2Zadeh/CMU-MultimodalSDK/tree/master/mmsdk/mmdatasdk/dataset/standard_datasets/CMU_MOSI) 32 | if self.args.task_binary: 33 | for key in self.key_to_label.keys(): 34 | if self.key_to_label[key] == 0.0: 35 | print("2-class Sentiment, removing key ", key) 36 | self.set.remove(key) 37 | 38 | for key in self.set: 39 | if not (key in self.key_to_word and 40 | key in self.key_to_audio and 41 | key in self.key_to_label): 42 | print("Not present everywhere, removing key ", key) 43 | self.set.remove(key) 44 | 45 | # Plot temporal dimension of feature 46 | # t = [] 47 | # for key in 
self.key_to_word.keys(): 48 | # x = np.array(self.key_to_word[key]).shape[0] 49 | # t.append(x) 50 | # plot(t) 51 | # sys.exit() 52 | 53 | # Creating embeddings and word indexes 54 | self.key_to_sentence = tokenize(self.key_to_word) 55 | if token_to_ix is not None: 56 | self.token_to_ix = token_to_ix 57 | else: # Train 58 | self.token_to_ix, self.pretrained_emb = create_dict(self.key_to_sentence, self.dataroot) 59 | self.vocab_size = len(self.token_to_ix) 60 | 61 | self.l_max_len = 30 62 | self.a_max_len = 60 63 | 64 | def __getitem__(self, idx): 65 | key = self.set[idx] 66 | L = sent_to_ix(self.key_to_sentence[key], self.token_to_ix, max_token=self.l_max_len) 67 | A = pad_feature(self.key_to_audio[key], self.a_max_len) 68 | V = np.zeros(1) # not using video, insert dummy 69 | 70 | y = self.key_to_label[key] 71 | if self.args.task_binary: 72 | c = 0 if y < 0.0 else 1 73 | else: 74 | c = int(round(y)) + 3 # from -3;3 to 0;6 75 | y = np.array(c) 76 | return key, torch.from_numpy(L), torch.from_numpy(A), torch.from_numpy(V).float(), torch.from_numpy(y) 77 | 78 | def __len__(self): 79 | return len(self.set) -------------------------------------------------------------------------------- /dataloaders/vgaf_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch 6 | from utils.plot import plot 7 | from utils.tokenize import tokenize, create_dict, sent_to_ix, pad_feature 8 | from torch.utils.data import Dataset 9 | 10 | class Vgaf_Dataset(Dataset): 11 | def __init__(self, name, args, token_to_ix=None, dataroot='data'): 12 | super(Vgaf_Dataset, self).__init__() 13 | assert name in ['train', 'valid', 'test'] 14 | self.name = name 15 | self.args = args 16 | self.dataroot = os.path.join(dataroot, 'VGAF') 17 | self.private_set = name == 'private' 18 | 19 | if name == 'test': 20 | name = 'valid' 21 | 22 | word_file = os.path.join(self.dataroot, name + "_sentences.p") 23 | audio_file = os.path.join(self.dataroot, name + "_mels.p") 24 | y_file = os.path.join(self.dataroot, name + "_emotions.p") 25 | 26 | self.key_to_word = pickle.load(open(word_file, "rb")) 27 | self.key_to_audio = pickle.load(open(audio_file, "rb")) 28 | self.key_to_label = pickle.load(open(y_file, "rb")) 29 | self.set = list(self.key_to_label.keys()) 30 | 31 | for key in self.set: 32 | if not (key in self.key_to_word and 33 | key in self.key_to_audio and 34 | key in self.key_to_label): 35 | print("Not present everywhere, removing key ", key) 36 | self.set.remove(key) 37 | 38 | # Plot temporal dimension of feature 39 | # t = [] 40 | # for key in self.key_to_word.keys(): 41 | # x = np.array(self.key_to_word[key]).shape[0] 42 | # t.append(x) 43 | # print(max(t)) 44 | # plot(t) 45 | # sys.exit() 46 | 47 | # Creating embeddings and word indexes 48 | self.key_to_sentence = tokenize(self.key_to_word) 49 | if token_to_ix is not None: 50 | self.token_to_ix = token_to_ix 51 | else: # Train 52 | self.token_to_ix, self.pretrained_emb = create_dict(self.key_to_sentence, self.dataroot) 53 | self.vocab_size = len(self.token_to_ix) 54 | 55 | self.l_max_len = 30 56 | self.a_max_len = 26 57 | 58 | def __getitem__(self, idx): 59 | key = self.set[idx] 60 | L = sent_to_ix(self.key_to_sentence[key], self.token_to_ix, max_token=self.l_max_len) 61 | A = pad_feature(self.key_to_audio[key], self.a_max_len) 62 | V = np.zeros(1) # not using video, insert dummy 63 | 64 | y = self.key_to_label[key] 65 | y = 
np.array(int(y)-1) #from 1,3 to 0,2 66 | return key, torch.from_numpy(L), torch.from_numpy(A), torch.from_numpy(V).float(), torch.from_numpy(y) 67 | 68 | def __len__(self): 69 | return len(self.set) -------------------------------------------------------------------------------- /ensembling.py: -------------------------------------------------------------------------------- 1 | import argparse, os, glob, warnings, torch 2 | import numpy as np 3 | from utils.pred_func import * 4 | from sklearn.metrics import classification_report 5 | from torch.utils.data import DataLoader 6 | from dataloaders import * 7 | from models import * 8 | from utils.compute_args import compute_args 9 | from train import evaluate 10 | 11 | warnings.filterwarnings("ignore") 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--output', type=str, default='ckpt/') 17 | parser.add_argument('--name', type=str, default='exp0/') 18 | parser.add_argument('--sets', nargs='+', default=["valid", "test"]) 19 | 20 | parser.add_argument('--index', type=int, default=99) 21 | parser.add_argument('--private_set', type=str, default=None) 22 | 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | if __name__ == '__main__': 28 | args = parse_args() 29 | 30 | # Save vars 31 | private_set = args.private_set 32 | index = args.index 33 | sets = args.sets 34 | 35 | # Listing sorted checkpoints 36 | ckpts = sorted(glob.glob(os.path.join(args.output, args.name,'best*')), reverse=True) 37 | 38 | # Load original args 39 | args = torch.load(ckpts[0])['args'] 40 | args = compute_args(args) 41 | 42 | 43 | # Define the splits to be evaluated 44 | evaluation_sets = list(sets) + ([private_set] if private_set is not None else []) 45 | print("Evaluated sets: ", str(evaluation_sets)) 46 | # Creating dataloader 47 | train_dset = eval(args.dataloader)('train', args) 48 | loaders = {set: DataLoader(eval(args.dataloader)(set, args, train_dset.token_to_ix), 49 | args.batch_size, 50 | num_workers=8, 51 | pin_memory=True) for set in evaluation_sets} 52 | 53 | # Creating net 54 | net = eval(args.model)(args, train_dset.vocab_size, train_dset.pretrained_emb).cuda() 55 | 56 | # Ensembling sets 57 | ensemble_preds = {set: {} for set in evaluation_sets} 58 | ensemble_accuracies = {set: [] for set in evaluation_sets} 59 | 60 | # Iterating over checkpoints 61 | for i, ckpt in enumerate(ckpts): 62 | 63 | if i >= index: 64 | break 65 | 66 | print("###### Ensembling " + str(i+1)) 67 | state_dict = torch.load(ckpt)['state_dict'] 68 | net.load_state_dict(state_dict) 69 | 70 | # Evaluation per checkpoint predictions 71 | for set in evaluation_sets: 72 | accuracy, preds = evaluate(net, loaders[set], args) 73 | print('Accuracy for ' + set + ' for model ' + ckpt + ":", accuracy) 74 | for id, pred in preds.items(): 75 | if id not in ensemble_preds[set]: 76 | ensemble_preds[set][id] = [] 77 | ensemble_preds[set][id].append(pred) 78 | 79 | # Compute set ensembling accuracy 80 | # Get all ids and answers 81 | ids = [id for ids, _, _, _, _ in loaders[set] for id in ids] 82 | ans = [np.array(a) for _, _, _, _, ans in loaders[set] for a in ans] 83 | 84 | # for all id, get averaged probabilities 85 | avg_preds = np.array([np.mean(np.array(ensemble_preds[set][id]), axis=0) for id in ids]) 86 | # Compute accuracies 87 | if set != private_set: 88 | accuracy = np.mean(eval(args.pred_func)(avg_preds) == ans) * 100 89 | print("New " + set + " ens. 
Accuracy :", accuracy) 90 | ensemble_accuracies[set].append(accuracy) 91 | 92 | if i + 1 == index: 93 | print(classification_report(ans, eval(args.pred_func)(avg_preds))) 94 | 95 | # Printing overall results 96 | for set in sets: 97 | print("Max ensemble w-accuracies for " + set + " : " + str(max(ensemble_accuracies[set]))) -------------------------------------------------------------------------------- /layers/fc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class FC(nn.Module): 4 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): 5 | super(FC, self).__init__() 6 | self.dropout_r = dropout_r 7 | self.use_relu = use_relu 8 | 9 | self.linear = nn.Linear(in_size, out_size) 10 | 11 | if use_relu: 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | if dropout_r > 0: 15 | self.dropout = nn.Dropout(dropout_r) 16 | 17 | def forward(self, x): 18 | x = self.linear(x) 19 | 20 | if self.use_relu: 21 | x = self.relu(x) 22 | 23 | if self.dropout_r > 0: 24 | x = self.dropout(x) 25 | 26 | return x 27 | 28 | 29 | class MLP(nn.Module): 30 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): 31 | super(MLP, self).__init__() 32 | 33 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) 34 | self.linear = nn.Linear(mid_size, out_size) 35 | 36 | def forward(self, x): 37 | return self.linear(self.fc(x)) 38 | -------------------------------------------------------------------------------- /layers/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class LayerNorm(nn.Module): 5 | def __init__(self, size, eps=1e-6): 6 | super(LayerNorm, self).__init__() 7 | self.eps = eps 8 | 9 | self.a_2 = nn.Parameter(torch.ones(size)) 10 | self.b_2 = nn.Parameter(torch.zeros(size)) 11 | 12 | def forward(self, x): 13 | mean = x.mean(-1, keepdim=True) 14 | std = x.std(-1, keepdim=True) 15 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 16 | 17 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse, os, random 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from dataloaders import * 6 | from models import * 7 | from train import train 8 | from utils.compute_args import compute_args 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | # Model 13 | parser.add_argument('--model', type=str, default="Model_MCAN") 14 | parser.add_argument('--layer', type=int, default=4) 15 | parser.add_argument('--hidden_size', type=int, default=512) 16 | parser.add_argument('--dropout_i', type=float, default=0.0) 17 | parser.add_argument('--dropout_r', type=float, default=0.1) 18 | parser.add_argument('--dropout_o', type=float, default=0.5) 19 | parser.add_argument('--multi_head', type=int, default=8) 20 | parser.add_argument('--ff_size', type=int, default=2048) 21 | parser.add_argument('--word_embed_size', type=int, default=300) 22 | parser.add_argument('--bidirectional', type=bool, default=False) 23 | 24 | # Data 25 | parser.add_argument('--lang_seq_len', type=int, default=50) 26 | parser.add_argument('--audio_seq_len', type=int, default=50) 27 | parser.add_argument('--video_seq_len', type=int, default=60) 28 | parser.add_argument('--audio_feat_size', type=int, default=80) 29 | 
parser.add_argument('--video_feat_size', type=int, default=512) 30 | 31 | # Training 32 | parser.add_argument('--output', type=str, default='ckpt/') 33 | parser.add_argument('--name', type=str, default='exp0/') 34 | parser.add_argument('--batch_size', type=int, default=64) 35 | parser.add_argument('--max_epoch', type=int, default=99) 36 | parser.add_argument('--lr_base', type=float, default=0.00005) 37 | parser.add_argument('--lr_decay', type=float, default=0.2) 38 | parser.add_argument('--lr_decay_times', type=int, default=2) 39 | parser.add_argument('--grad_norm_clip', type=float, default=-1) 40 | parser.add_argument('--eval_start', type=int, default=0) 41 | parser.add_argument('--early_stop', type=int, default=3) 42 | parser.add_argument('--seed', type=int, default=random.randint(0, 9999999)) 43 | 44 | # Dataset and task 45 | parser.add_argument('--dataset', type=str, choices=['MELD', 'MOSEI', 'MOSI', 'IEMOCAP', 'VGAF'], default='MOSEI') 46 | parser.add_argument('--task', type=str, choices=['sentiment', 'emotion'], default='sentiment') 47 | parser.add_argument('--task_binary', type=bool, default=False) 48 | 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | if __name__ == '__main__': 54 | # Base on args given, compute new args 55 | args = compute_args(parse_args()) 56 | 57 | # Seed 58 | torch.manual_seed(args.seed) 59 | np.random.seed(args.seed) 60 | torch.backends.cudnn.deterministic = True 61 | torch.backends.cudnn.benchmark = False 62 | 63 | # DataLoader 64 | train_dset = eval(args.dataloader)('train', args) 65 | eval_dset = eval(args.dataloader)('valid', args, train_dset.token_to_ix) 66 | 67 | train_loader = DataLoader(train_dset, args.batch_size, shuffle=True, num_workers=8, pin_memory=True) 68 | eval_loader = DataLoader(eval_dset, args.batch_size, num_workers=8, pin_memory=True) 69 | 70 | # Net 71 | net = eval(args.model)(args, train_dset.vocab_size, train_dset.pretrained_emb).cuda() 72 | print("Total number of parameters : " + str(sum([p.numel() for p in net.parameters()]) / 1e6) + "M") 73 | net = net.cuda() 74 | 75 | # Create Checkpoint dir 76 | if not os.path.exists(os.path.join(args.output, args.name)): 77 | os.makedirs(os.path.join(args.output, args.name)) 78 | 79 | # Run training 80 | eval_accuracies = train(net, train_loader, eval_loader, args) 81 | open('best_scores.txt', 'a+').write(args.output + "/" + args.name + "," 82 | + str(max(eval_accuracies)) + "\n") -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_MAT import Model_MAT 2 | from .model_MNT import Model_MNT -------------------------------------------------------------------------------- /models/model_MAT.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from layers.fc import MLP, FC 7 | from layers.layer_norm import LayerNorm 8 | 9 | # ------------------------------------ 10 | # ---------- Masking sequence -------- 11 | # ------------------------------------ 12 | def make_mask(feature): 13 | return (torch.sum( 14 | torch.abs(feature), 15 | dim=-1 16 | ) == 0).unsqueeze(1).unsqueeze(2) 17 | 18 | 19 | 20 | class AttFlat(nn.Module): 21 | def __init__(self, args): 22 | super(AttFlat, self).__init__() 23 | self.args = args 24 | self.flat_glimpse = 1 25 | 26 | self.mlp = MLP( 27 | in_size=args.hidden_size, 28 | 
mid_size=args.ff_size, 29 | out_size=self.flat_glimpse, 30 | dropout_r=args.dropout_r, 31 | use_relu=True 32 | ) 33 | 34 | self.linear_merge = nn.Linear( 35 | args.hidden_size * self.flat_glimpse, 36 | args.hidden_size * 2 37 | ) 38 | 39 | def forward(self, x, x_mask): 40 | att = self.mlp(x) 41 | att = att.masked_fill( 42 | x_mask.squeeze(1).squeeze(1).unsqueeze(2), 43 | -1e9 44 | ) 45 | att = F.softmax(att, dim=1) 46 | 47 | att_list = [] 48 | for i in range(self.flat_glimpse): 49 | att_list.append( 50 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 51 | ) 52 | 53 | x_atted = torch.cat(att_list, dim=1) 54 | x_atted = self.linear_merge(x_atted) 55 | 56 | return x_atted 57 | 58 | 59 | 60 | class MHAtt(nn.Module): 61 | def __init__(self, args): 62 | super(MHAtt, self).__init__() 63 | self.args = args 64 | 65 | self.linear_v = nn.Linear(args.hidden_size, args.hidden_size) 66 | self.linear_k = nn.Linear(args.hidden_size, args.hidden_size) 67 | self.linear_q = nn.Linear(args.hidden_size, args.hidden_size) 68 | self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size) 69 | 70 | self.dropout = nn.Dropout(args.dropout_r) 71 | 72 | def forward(self, v, k, q, mask): 73 | n_batches = q.size(0) 74 | 75 | v = self.linear_v(v).view( 76 | n_batches, 77 | -1, 78 | self.args.multi_head, 79 | int(self.args.hidden_size / self.args.multi_head) 80 | ).transpose(1, 2) 81 | 82 | k = self.linear_k(k).view( 83 | n_batches, 84 | -1, 85 | self.args.multi_head, 86 | int(self.args.hidden_size / self.args.multi_head) 87 | ).transpose(1, 2) 88 | 89 | q = self.linear_q(q).view( 90 | n_batches, 91 | -1, 92 | self.args.multi_head, 93 | int(self.args.hidden_size / self.args.multi_head) 94 | ).transpose(1, 2) 95 | 96 | atted = self.att(v, k, q, mask) 97 | atted = atted.transpose(1, 2).contiguous().view( 98 | n_batches, 99 | -1, 100 | self.args.hidden_size 101 | ) 102 | 103 | atted = self.linear_merge(atted) 104 | 105 | return atted 106 | 107 | def att(self, value, key, query, mask): 108 | d_k = query.size(-1) 109 | 110 | scores = torch.matmul( 111 | query, key.transpose(-2, -1) 112 | ) / math.sqrt(d_k) 113 | 114 | if mask is not None: 115 | scores = scores.masked_fill(mask, -1e9) 116 | 117 | att_map = F.softmax(scores, dim=-1) 118 | att_map = self.dropout(att_map) 119 | 120 | return torch.matmul(att_map, value) 121 | 122 | 123 | # --------------------------- 124 | # ---- Feed Forward Nets ---- 125 | # --------------------------- 126 | 127 | class FFN(nn.Module): 128 | def __init__(self, args): 129 | super(FFN, self).__init__() 130 | 131 | self.mlp = MLP( 132 | in_size=args.hidden_size, 133 | mid_size=args.ff_size, 134 | out_size=args.hidden_size, 135 | dropout_r=args.dropout_r, 136 | use_relu=True 137 | ) 138 | 139 | def forward(self, x): 140 | return self.mlp(x) 141 | 142 | 143 | # ------------------------ 144 | # ---- Self Attention ---- 145 | # ------------------------ 146 | 147 | class SA(nn.Module): 148 | def __init__(self, args): 149 | super(SA, self).__init__() 150 | 151 | self.mhatt = MHAtt(args) 152 | self.ffn = FFN(args) 153 | 154 | self.dropout1 = nn.Dropout(args.dropout_r) 155 | self.norm1 = LayerNorm(args.hidden_size) 156 | 157 | self.dropout2 = nn.Dropout(args.dropout_r) 158 | self.norm2 = LayerNorm(args.hidden_size) 159 | 160 | def forward(self, y, y_mask): 161 | y = self.norm1(y + self.dropout1( 162 | self.mhatt(y, y, y, y_mask) 163 | )) 164 | 165 | y = self.norm2(y + self.dropout2( 166 | self.ffn(y) 167 | )) 168 | 169 | return y 170 | 171 | 172 | # ------------------------------- 173 | # ---- Self Guided 
Attention ---- 174 | # ------------------------------- 175 | 176 | class SGA(nn.Module): 177 | def __init__(self, args): 178 | super(SGA, self).__init__() 179 | 180 | self.mhatt1 = MHAtt(args) 181 | self.mhatt2 = MHAtt(args) 182 | self.ffn = FFN(args) 183 | 184 | self.dropout1 = nn.Dropout(args.dropout_r) 185 | self.norm1 = LayerNorm(args.hidden_size) 186 | 187 | self.dropout2 = nn.Dropout(args.dropout_r) 188 | self.norm2 = LayerNorm(args.hidden_size) 189 | 190 | self.dropout3 = nn.Dropout(args.dropout_r) 191 | self.norm3 = LayerNorm(args.hidden_size) 192 | 193 | def forward(self, x, y, x_mask, y_mask): 194 | x = self.norm1(x + self.dropout1( 195 | self.mhatt1(v=x, k=x, q=x, mask=x_mask) 196 | )) 197 | 198 | x = self.norm2(x + self.dropout2( 199 | self.mhatt2(v=y, k=y, q=x, mask=y_mask) 200 | )) 201 | 202 | x = self.norm3(x + self.dropout3( 203 | self.ffn(x) 204 | )) 205 | 206 | return x 207 | 208 | 209 | # ------------------------------------------------ 210 | # ---- MAC Layers Cascaded by Encoder-Decoder ---- 211 | # ------------------------------------------------ 212 | 213 | class MCA_ED(nn.Module): 214 | def __init__(self, args): 215 | super(MCA_ED, self).__init__() 216 | 217 | self.enc_list = nn.ModuleList([SA(args) for _ in range(args.layer)]) 218 | self.dec_list = nn.ModuleList([SGA(args) for _ in range(args.layer)]) 219 | 220 | def forward(self, y, x, y_mask, x_mask): 221 | # Get encoder last hidden vector 222 | for enc in self.enc_list: 223 | y = enc(y, y_mask) 224 | 225 | # Input encoder last hidden vector 226 | # And obtain decoder last hidden vectors 227 | for dec in self.dec_list: 228 | x = dec(x, y, x_mask, y_mask) 229 | 230 | return y, x 231 | 232 | 233 | 234 | # ------------------------- 235 | # ---- Main MCAN Model ---- 236 | # ------------------------- 237 | 238 | class Model_MAT(nn.Module): 239 | def __init__(self, args, vocab_size, pretrained_emb): 240 | super(Model_MAT, self).__init__() 241 | self.args = args 242 | 243 | # LSTM 244 | self.embedding = nn.Embedding( 245 | num_embeddings=vocab_size, 246 | embedding_dim=args.word_embed_size 247 | ) 248 | 249 | # Loading the GloVe embedding weights 250 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 251 | self.input_drop = nn.Dropout(args.dropout_i) 252 | 253 | self.lstm_x = nn.LSTM( 254 | input_size=args.word_embed_size, 255 | hidden_size=args.hidden_size, 256 | num_layers=1, 257 | batch_first=True 258 | ) 259 | 260 | self.lstm_y = nn.LSTM( 261 | input_size=args.audio_feat_size, 262 | hidden_size=args.hidden_size, 263 | num_layers=1, 264 | batch_first=True 265 | ) 266 | 267 | # self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size) 268 | self.backbone = MCA_ED(args) 269 | 270 | # Flatten to vector 271 | self.attflat_img = AttFlat(args) 272 | self.attflat_lang = AttFlat(args) 273 | 274 | # Classification layers 275 | 276 | self.proj_norm = LayerNorm(2 * args.hidden_size) 277 | self.proj = nn.Linear(2 * args.hidden_size, args.ans_size) 278 | self.proj_drop = nn.Dropout(args.dropout_o) 279 | 280 | def forward(self, x, y, _): 281 | x_mask = make_mask(x.unsqueeze(2)) 282 | y_mask = make_mask(y) 283 | 284 | embedding = self.embedding(x) 285 | 286 | x, _ = self.lstm_x(self.input_drop(embedding)) 287 | y, _ = self.lstm_y(self.input_drop(y)) 288 | 289 | # Backbone Framework 290 | lang_feat, img_feat = self.backbone( 291 | x, 292 | y, 293 | x_mask, 294 | y_mask 295 | ) 296 | 297 | # Flatten to vector 298 | lang_feat = self.attflat_lang( 299 | lang_feat, 300 | x_mask 301 | ) 302 | 303 | img_feat = 
self.attflat_img( 304 | img_feat, 305 | y_mask 306 | ) 307 | 308 | # Classification layers 309 | proj_feat = lang_feat + img_feat 310 | proj_feat = self.proj_norm(proj_feat) 311 | 312 | proj_feat = self.proj_drop(proj_feat) 313 | proj_feat = self.proj(proj_feat) 314 | 315 | return proj_feat -------------------------------------------------------------------------------- /models/model_MNT.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from layers.fc import MLP, FC 7 | 8 | class LayerNorm(nn.Module): 9 | def __init__(self, size, eps=1e-6): 10 | super(LayerNorm, self).__init__() 11 | self.eps = eps 12 | 13 | self.a_2 = nn.Parameter(torch.ones(size)) 14 | self.b_2 = nn.Parameter(torch.zeros(size)) 15 | 16 | def forward(self, x, ab=None): 17 | mean = x.mean(-1, keepdim=True) 18 | std = x.std(-1, keepdim=True) 19 | 20 | if ab is not None: 21 | cond = torch.chunk(ab, 2, dim=-1) 22 | g = cond[0] 23 | b = cond[1] 24 | return (self.a_2 + torch.unsqueeze(g, 1)) * (x - mean) / (std + self.eps) + (self.b_2 + torch.unsqueeze(b, 1)) 25 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 26 | 27 | # ------------------------------------ 28 | # ---------- Masking sequence -------- 29 | # ------------------------------------ 30 | def make_mask(feature): 31 | return (torch.sum( 32 | torch.abs(feature), 33 | dim=-1 34 | ) == 0).unsqueeze(1).unsqueeze(2) 35 | 36 | 37 | 38 | class AttFlat(nn.Module): 39 | def __init__(self, args): 40 | super(AttFlat, self).__init__() 41 | self.args = args 42 | self.flat_glimpse = 1 43 | 44 | self.mlp = MLP( 45 | in_size=args.hidden_size, 46 | mid_size=args.ff_size, 47 | out_size=self.flat_glimpse, 48 | dropout_r=args.dropout_r, 49 | use_relu=True 50 | ) 51 | 52 | self.linear_merge = nn.Linear( 53 | args.hidden_size * self.flat_glimpse, 54 | args.hidden_size 55 | ) 56 | 57 | def forward(self, x, x_mask): 58 | att = self.mlp(x) 59 | att = att.masked_fill( 60 | x_mask.squeeze(1).squeeze(1).unsqueeze(2), 61 | -1e9 62 | ) 63 | att = F.softmax(att, dim=1) 64 | 65 | att_list = [] 66 | for i in range(self.flat_glimpse): 67 | att_list.append( 68 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 69 | ) 70 | 71 | x_atted = torch.cat(att_list, dim=1) 72 | x_atted = self.linear_merge(x_atted) 73 | 74 | return x_atted 75 | 76 | 77 | 78 | class MHAtt(nn.Module): 79 | def __init__(self, args): 80 | super(MHAtt, self).__init__() 81 | self.args = args 82 | 83 | self.linear_v = nn.Linear(args.hidden_size, args.hidden_size) 84 | self.linear_k = nn.Linear(args.hidden_size, args.hidden_size) 85 | self.linear_q = nn.Linear(args.hidden_size, args.hidden_size) 86 | self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size) 87 | 88 | self.dropout = nn.Dropout(args.dropout_r) 89 | 90 | def forward(self, v, k, q, mask): 91 | n_batches = q.size(0) 92 | 93 | v = self.linear_v(v).view( 94 | n_batches, 95 | -1, 96 | self.args.multi_head, 97 | int(self.args.hidden_size / self.args.multi_head) 98 | ).transpose(1, 2) 99 | 100 | k = self.linear_k(k).view( 101 | n_batches, 102 | -1, 103 | self.args.multi_head, 104 | int(self.args.hidden_size / self.args.multi_head) 105 | ).transpose(1, 2) 106 | 107 | q = self.linear_q(q).view( 108 | n_batches, 109 | -1, 110 | self.args.multi_head, 111 | int(self.args.hidden_size / self.args.multi_head) 112 | ).transpose(1, 2) 113 | 114 | atted = self.att(v, k, q, mask) 115 | atted = atted.transpose(1, 2).contiguous().view( 116 | 
n_batches, 117 | -1, 118 | self.args.hidden_size 119 | ) 120 | 121 | atted = self.linear_merge(atted) 122 | 123 | return atted 124 | 125 | def att(self, value, key, query, mask): 126 | d_k = query.size(-1) 127 | 128 | scores = torch.matmul( 129 | query, key.transpose(-2, -1) 130 | ) / math.sqrt(d_k) 131 | 132 | if mask is not None: 133 | scores = scores.masked_fill(mask, -1e9) 134 | 135 | att_map = F.softmax(scores, dim=-1) 136 | att_map = self.dropout(att_map) 137 | 138 | return torch.matmul(att_map, value) 139 | 140 | 141 | # --------------------------- 142 | # ---- Feed Forward Nets ---- 143 | # --------------------------- 144 | 145 | class FFN(nn.Module): 146 | def __init__(self, args): 147 | super(FFN, self).__init__() 148 | 149 | self.mlp = MLP( 150 | in_size=args.hidden_size, 151 | mid_size=args.ff_size, 152 | out_size=args.hidden_size, 153 | dropout_r=args.dropout_r, 154 | use_relu=True 155 | ) 156 | 157 | def forward(self, x): 158 | return self.mlp(x) 159 | 160 | 161 | # ------------------------ 162 | # ---- Self Attention ---- 163 | # ------------------------ 164 | 165 | class SA(nn.Module): 166 | def __init__(self, args): 167 | super(SA, self).__init__() 168 | 169 | self.mhatt = MHAtt(args) 170 | self.ffn = FFN(args) 171 | 172 | self.dropout1 = nn.Dropout(args.dropout_r) 173 | self.norm1 = LayerNorm(args.hidden_size) 174 | 175 | self.dropout2 = nn.Dropout(args.dropout_r) 176 | self.norm2 = LayerNorm(args.hidden_size) 177 | 178 | def forward(self, y, y_mask): 179 | y = self.norm1(y + self.dropout1( 180 | self.mhatt(y, y, y, y_mask) 181 | )) 182 | 183 | y = self.norm2(y + self.dropout2( 184 | self.ffn(y) 185 | )) 186 | 187 | return y 188 | 189 | 190 | class SAG(nn.Module): 191 | def __init__(self, args): 192 | super(SAG, self).__init__() 193 | 194 | self.mhatt = MHAtt(args) 195 | self.ffn = FFN(args) 196 | 197 | self.dropout1 = nn.Dropout(args.dropout_r) 198 | self.norm1 = LayerNorm(args.hidden_size) 199 | 200 | self.dropout2 = nn.Dropout(args.dropout_r) 201 | self.norm2 = LayerNorm(args.hidden_size) 202 | 203 | def forward(self, y, y_mask, cond): 204 | cond = torch.chunk(cond, 2, dim=-1) 205 | y = self.norm1(y + self.dropout1( 206 | self.mhatt(y, y, y, y_mask) 207 | ), cond[0]) 208 | 209 | y = self.norm2(y + self.dropout2( 210 | self.ffn(y) 211 | ), cond[1]) 212 | 213 | return y 214 | 215 | 216 | # ------------------------- 217 | # ---- Main MCAN Model ---- 218 | # ------------------------- 219 | 220 | class Model_MNT(nn.Module): 221 | def __init__(self, args, vocab_size, pretrained_emb): 222 | super(Model_MNT, self).__init__() 223 | self.args = args 224 | 225 | # LSTM 226 | self.embedding = nn.Embedding( 227 | num_embeddings=vocab_size, 228 | embedding_dim=args.word_embed_size 229 | ) 230 | 231 | # Loading the GloVe embedding weights 232 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 233 | self.input_drop = nn.Dropout(args.dropout_i) 234 | 235 | self.lstm_x = nn.LSTM( 236 | input_size=args.word_embed_size, 237 | hidden_size=args.hidden_size, 238 | num_layers=1, 239 | batch_first=True 240 | ) 241 | 242 | self.lstm_y = nn.LSTM( 243 | input_size=args.audio_feat_size, 244 | hidden_size=args.hidden_size, 245 | num_layers=1, 246 | batch_first=True 247 | ) 248 | 249 | # self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size) 250 | self.fc = FC(args.hidden_size, args.hidden_size * args.layer * 2 * 2) 251 | self.enc_list = nn.ModuleList([SA(args) for _ in range(args.layer)]) 252 | self.dec_list = nn.ModuleList([SAG(args) for _ in range(args.layer)]) 253 | 
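# Note on the modulation path: self.fc above maps the flattened language vector to
# hidden_size * layer * 2 * 2 values. In forward() this output is chunked into `layer`
# conditioning vectors, one per SAG block; each SAG block splits its chunk in two and
# hands one half to each of its two LayerNorms, which in turn split their half into a
# gain/bias pair added to the learned a_2/b_2 parameters. The acoustic stream is thus
# normalized with language-conditioned gains and biases (the modulation described in the paper).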
254 | # Flatten to vector 255 | self.attflat_img = AttFlat(args) 256 | self.attflat_lang = AttFlat(args) 257 | 258 | # Classification layers 259 | 260 | self.proj_norm = LayerNorm(args.hidden_size) 261 | self.proj = nn.Linear(args.hidden_size, args.ans_size) 262 | self.proj_drop = nn.Dropout(args.dropout_o) 263 | 264 | def forward(self, x, y, _): 265 | x_mask = make_mask(x.unsqueeze(2)) 266 | y_mask = make_mask(y) 267 | 268 | embedding = self.embedding(x) 269 | 270 | x, _ = self.lstm_x(self.input_drop(embedding)) 271 | y, _ = self.lstm_y(self.input_drop(y)) 272 | 273 | # Backbone Framework 274 | for enc in self.enc_list: 275 | x = enc(x, x_mask) 276 | 277 | lang_feat = self.attflat_lang( 278 | x, 279 | x_mask 280 | ) 281 | 282 | cond = torch.chunk(self.fc(lang_feat), self.args.layer, dim=-1) 283 | 284 | for i, dec in enumerate(self.dec_list): 285 | y = dec(y, y_mask, cond[i]) 286 | 287 | img_feat = self.attflat_img( 288 | y, 289 | y_mask 290 | ) 291 | 292 | # Classification layers 293 | proj_feat = lang_feat + img_feat 294 | proj_feat = self.proj_norm(proj_feat) 295 | 296 | proj_feat = self.proj_drop(proj_feat) 297 | proj_feat = self.proj(proj_feat) 298 | 299 | return proj_feat -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | import os 5 | from utils.pred_func import * 6 | 7 | def train(net, train_loader, eval_loader, args): 8 | 9 | logfile = open( 10 | args.output + "/" + args.name + 11 | '/log_run_' + str(args.seed) + '.txt', 12 | 'w+' 13 | ) 14 | logfile.write(str(args)) 15 | 16 | best_eval_accuracy = 0.0 17 | early_stop = 0 18 | decay_count = 0 19 | 20 | # Load the optimizer paramters 21 | optim = torch.optim.Adam(net.parameters(), lr=args.lr_base) 22 | 23 | loss_fn = args.loss_fn 24 | eval_accuracies = [] 25 | for epoch in range(0, args.max_epoch): 26 | time_start = time.time() 27 | loss_sum = 0 28 | for step, ( 29 | id, 30 | x, 31 | y, 32 | z, 33 | ans, 34 | ) in enumerate(train_loader): 35 | loss_tmp = 0 36 | optim.zero_grad() 37 | 38 | x = x.cuda() 39 | y = y.cuda() 40 | z = z.cuda() 41 | ans = ans.cuda() 42 | 43 | pred = net(x, y, z) 44 | loss = loss_fn(pred, ans) 45 | loss.backward() 46 | 47 | loss_sum += loss.cpu().data.numpy() 48 | loss_tmp += loss.cpu().data.numpy() 49 | 50 | print("\r[Epoch %2d][Step %4d/%4d] Loss: %.4f, Lr: %.2e, %4d m " 51 | "remaining" % ( 52 | epoch + 1, 53 | step, 54 | int(len(train_loader.dataset) / args.batch_size), 55 | loss_tmp / args.batch_size, 56 | *[group['lr'] for group in optim.param_groups], 57 | ((time.time() - time_start) / (step + 1)) * ((len(train_loader.dataset) / args.batch_size) - step) / 60, 58 | ), end=' ') 59 | 60 | # Gradient norm clipping 61 | if args.grad_norm_clip > 0: 62 | nn.utils.clip_grad_norm_( 63 | net.parameters(), 64 | args.grad_norm_clip 65 | ) 66 | 67 | optim.step() 68 | 69 | time_end = time.time() 70 | elapse_time = time_end-time_start 71 | print('Finished in {}s'.format(int(elapse_time))) 72 | epoch_finish = epoch + 1 73 | 74 | # Logging 75 | logfile.write( 76 | 'Epoch: ' + str(epoch_finish) + 77 | ', Loss: ' + str(loss_sum / len(train_loader.dataset)) + 78 | ', Lr: ' + str([group['lr'] for group in optim.param_groups]) + '\n' + 79 | 'Elapsed time: ' + str(int(elapse_time)) + 80 | ', Speed(s/batch): ' + str(elapse_time / step) + 81 | '\n\n' 82 | ) 83 | 84 | # Eval 85 | if epoch_finish >= args.eval_start: 86 | print('Evaluation...') 87 
| accuracy, _ = evaluate(net, eval_loader, args) 88 | print('Accuracy :'+str(accuracy)) 89 | eval_accuracies.append(accuracy) 90 | if accuracy > best_eval_accuracy: 91 | # Best 92 | state = { 93 | 'state_dict': net.state_dict(), 94 | 'optimizer': optim.state_dict(), 95 | 'args': args, 96 | } 97 | torch.save( 98 | state, 99 | args.output + "/" + args.name + 100 | '/best'+str(args.seed)+'.pkl' 101 | ) 102 | best_eval_accuracy = accuracy 103 | early_stop = 0 104 | 105 | elif decay_count < args.lr_decay_times: 106 | # Decay 107 | print('LR Decay...') 108 | decay_count += 1 109 | 110 | ckpt = torch.load(args.output + "/" + args.name + 111 | '/best'+str(args.seed)+'.pkl') 112 | net.load_state_dict(ckpt['state_dict']) 113 | optim.load_state_dict(ckpt['optimizer']) 114 | 115 | # adjust_lr(optim, args.lr_decay) 116 | for group in optim.param_groups: 117 | group['lr'] = (args.lr_base * args.lr_decay**decay_count) 118 | else: 119 | # Early stop, does not start before lr_decay_times reached 120 | early_stop += 1 121 | 122 | if early_stop == args.early_stop: 123 | logfile.write('Early stop reached' + '\n') 124 | print('Early stop reached') 125 | break 126 | 127 | logfile.write('best_acc :' + str(best_eval_accuracy) + '\n\n') 128 | print('best_eval_acc :' + str(best_eval_accuracy) + '\n\n') 129 | os.rename(args.output + "/" + args.name + 130 | '/best' + str(args.seed) + '.pkl', 131 | args.output + "/" + args.name + 132 | '/best' + str(best_eval_accuracy) + "_" + str(args.seed) + '.pkl') 133 | logfile.close() 134 | return eval_accuracies 135 | 136 | 137 | def evaluate(net, eval_loader, args): 138 | accuracy = [] 139 | net.train(False) 140 | preds = {} 141 | for step, ( 142 | ids, 143 | x, 144 | y, 145 | z, 146 | ans, 147 | ) in enumerate(eval_loader): 148 | x = x.cuda() 149 | y = y.cuda() 150 | z = z.cuda() 151 | pred = net(x, y, z).cpu().data.numpy() 152 | 153 | if not eval_loader.dataset.private_set: 154 | ans = ans.cpu().data.numpy() 155 | accuracy += list(eval(args.pred_func)(pred) == ans) 156 | 157 | # Save preds 158 | for id, p in zip(ids, pred): 159 | preds[id] = p 160 | 161 | net.train(True) 162 | return 100*np.mean(np.array(accuracy)), preds 163 | 164 | -------------------------------------------------------------------------------- /utils/compute_args.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def compute_args(args): 5 | 6 | # DataLoader 7 | if args.dataset == "MOSEI": args.dataloader = 'Mosei_Dataset' 8 | if args.dataset == "MELD": args.dataloader = 'Meld_Dataset' 9 | if args.dataset == "MOSI": args.dataloader = 'Mosi_Dataset' 10 | if args.dataset == "IEMOCAP": args.dataloader = 'Iemocap_Dataset' 11 | if args.dataset == "VGAF": args.dataloader = 'Vgaf_Dataset' 12 | 13 | # Loss function to use 14 | if args.dataset == 'MOSEI' and args.task == 'sentiment': args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda() 15 | if args.dataset == 'MOSEI' and args.task == 'emotion': args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum").cuda() 16 | if args.dataset == 'MELD' and args.task == 'sentiment': args.loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 1.6]), reduction="sum").cuda() 17 | if args.dataset == 'MELD' and args.task == 'emotion': args.loss_fn = torch.nn.CrossEntropyLoss( 18 | weight=torch.tensor([1.0, 4709/1205, 4709/268, 4709/683, 4709/1743, 4709/271, 4709/1109]), 19 | reduction="sum").cuda() 20 | 21 | if args.dataset == 'MOSI': args.loss_fn = 
torch.nn.CrossEntropyLoss(reduction="sum").cuda() 22 | if args.dataset == "IEMOCAP": args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda() 23 | if args.dataset == "VGAF": args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum").cuda() 24 | 25 | # Answer size 26 | if args.dataset == 'MOSEI' and args.task == "sentiment": args.ans_size = 7 27 | if args.dataset == 'MOSEI' and args.task == "sentiment" and args.task_binary: args.ans_size = 2 28 | if args.dataset == 'MOSEI' and args.task == "emotion": args.ans_size = 6 29 | if args.dataset == 'MELD' and args.task == "emotion": args.ans_size = 7 30 | if args.dataset == 'MELD' and args.task == "sentiment": args.ans_size = 3 31 | if args.dataset == 'MOSI': args.ans_size = 2 32 | if args.dataset == "IEMOCAP": args.ans_size = 4 33 | if args.dataset == "VGAF": args.ans_size = 3 34 | 35 | # Pred function 36 | if args.dataset == 'MOSEI': args.pred_func = "amax" 37 | if args.dataset == 'MOSEI' and args.task == "emotion": args.pred_func = "multi_label" 38 | if args.dataset == 'MELD': args.pred_func = "amax" 39 | if args.dataset == 'MOSI': args.pred_func = "amax" 40 | if args.dataset == 'IEMOCAP': args.pred_func = "amax" 41 | if args.dataset == "VGAF": args.pred_func = "amax" 42 | 43 | return args -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim as Optim 2 | 3 | 4 | class WarmupOptimizer(object): 5 | def __init__(self, lr_base, optimizer, data_size, batch_size, warmup_epoch): 6 | self.optimizer = optimizer 7 | self._step = 0 8 | self.lr_base = lr_base 9 | self._rate = 0 10 | self.data_size = data_size 11 | self.batch_size = batch_size 12 | self.warmup_epoch = warmup_epoch 13 | 14 | 15 | def step(self): 16 | self._step += 1 17 | 18 | rate = self.rate() 19 | for p in self.optimizer.param_groups: 20 | p['lr'] = rate 21 | self._rate = rate 22 | 23 | self.optimizer.step() 24 | 25 | 26 | def zero_grad(self): 27 | self.optimizer.zero_grad() 28 | 29 | 30 | def rate(self, step=None): 31 | if step is None: 32 | step = self._step 33 | 34 | if step <= int(self.data_size / self.batch_size * (self.warmup_epoch + 1) * 0.25): 35 | r = self.lr_base * 1/(self.warmup_epoch + 1) 36 | elif step <= int(self.data_size / self.batch_size * (self.warmup_epoch + 1) * 0.5): 37 | r = self.lr_base * 2/(self.warmup_epoch + 1) 38 | elif step <= int(self.data_size / self.batch_size * (self.warmup_epoch + 1) * 0.75): 39 | r = self.lr_base * 3/(self.warmup_epoch + 1) 40 | else: 41 | r = self.lr_base 42 | 43 | return r 44 | 45 | def __str__(self): 46 | return "optimizer -> " + str(self.optimizer) 47 | 48 | 49 | 50 | def get_optim(args, model, data_size, lr_base=None): 51 | if lr_base is None: 52 | lr_base = args.lr_base 53 | 54 | std_optim = getattr(Optim, args.opt) 55 | params = filter(lambda p: p.requires_grad, model.parameters()) 56 | eval_str = 'params, lr=0' 57 | d_opt = eval(args.opt_params) 58 | for key in d_opt: 59 | eval_str += ' ,' + key + '=' + str(d_opt[key]) 60 | 61 | optim = WarmupOptimizer( 62 | lr_base, 63 | eval('std_optim' + '(' + eval_str + ')'), 64 | data_size, 65 | args.batch_size, 66 | args.warmup_epoch 67 | ) 68 | 69 | return optim 70 | 71 | 72 | def adjust_lr(optim, decay_r): 73 | optim.lr_base *= decay_r -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import 
matplotlib.pyplot as plt 2 | import numpy as np 3 | import itertools 4 | 5 | def plot(d): 6 | # An "interface" to matplotlib.axes.Axes.hist() method 7 | n, bins, patches = plt.hist(x=d, bins='auto', color='#0504aa', 8 | alpha=0.7, rwidth=0.85) 9 | plt.grid(axis='y', alpha=0.75) 10 | plt.title('Temporal dimension for visual features') 11 | maxfreq = n.max() 12 | # Set a clean upper y-axis limit. 13 | plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10) 14 | axes = plt.gca() 15 | axes.set_xlim([0, 100]) 16 | plt.xlabel("Temporal dimension") 17 | plt.ylabel("Number of samples") 18 | plt.show() 19 | plt.savefig('oui') 20 | 21 | 22 | 23 | def plot_confusion_matrix(cm, classes, 24 | normalize=False, 25 | title='Confusion matrix', 26 | cmap=plt.cm.Blues): 27 | """ 28 | This function prints and plots the confusion matrix. 29 | Normalization can be applied by setting `normalize=True`. 30 | """ 31 | 32 | 33 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 34 | plt.title(title) 35 | plt.colorbar() 36 | tick_marks = np.arange(len(classes)) 37 | plt.xticks(tick_marks, classes, rotation=45) 38 | plt.yticks(tick_marks, classes) 39 | 40 | if normalize: 41 | cm = np.round((cm.astype('float') / cm.sum(axis=1)[:, np.newaxis])*100, 2) # from 0.45555555 to 45.55 42 | print("Normalized confusion matrix") 43 | else: 44 | print('Confusion matrix, without normalization') 45 | 46 | print(cm) 47 | 48 | thresh = cm.max() / 2. 49 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 50 | plt.text(j, i, cm[i, j], 51 | horizontalalignment="center", 52 | color="white" if cm[i, j] > thresh else "black") 53 | 54 | plt.tight_layout() 55 | plt.ylabel('True class') 56 | plt.xlabel('Predicted class') 57 | plt.savefig("confusion_matrix") 58 | plt.clf() -------------------------------------------------------------------------------- /utils/pred_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def amax(x): 5 | return np.argmax(x, axis=1) 6 | 7 | 8 | def multi_label(x): 9 | return (x > 0) -------------------------------------------------------------------------------- /utils/tokenize.py: -------------------------------------------------------------------------------- 1 | # $ wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz 2 | # $ pip install en_vectors_web_lg-2.1.0.tar.gz 3 | import en_vectors_web_lg 4 | import re 5 | import numpy as np 6 | import os 7 | import pickle 8 | 9 | def clean(w): 10 | return re.sub( 11 | r"([.,'!?\"()*#:;])", 12 | '', 13 | w.lower() 14 | ).replace('-', ' ').replace('/', ' ') 15 | 16 | 17 | def tokenize(key_to_word): 18 | key_to_sentence = {} 19 | for k, v in key_to_word.items(): 20 | key_to_sentence[k] = [clean(w) for w in v if clean(w) != ''] 21 | return key_to_sentence 22 | 23 | 24 | def create_dict(key_to_sentence, dataroot, use_glove=True): 25 | token_file = dataroot+"/token_to_ix.pkl" 26 | glove_file = dataroot+"/train_glove.npy" 27 | if os.path.exists(glove_file) and os.path.exists(token_file): 28 | print("Loading train language files") 29 | return pickle.load(open(token_file, "rb")), np.load(glove_file) 30 | 31 | print("Creating train language files") 32 | token_to_ix = { 33 | 'UNK': 1, 34 | } 35 | 36 | spacy_tool = None 37 | pretrained_emb = [] 38 | if use_glove: 39 | spacy_tool = en_vectors_web_lg.load() 40 | pretrained_emb.append(spacy_tool('UNK').vector) 41 | 42 | for 
k, v in key_to_sentence.items(): 43 | for word in v: 44 | if word not in token_to_ix: 45 | token_to_ix[word] = len(token_to_ix) 46 | if use_glove: 47 | pretrained_emb.append(spacy_tool(word).vector) 48 | 49 | pretrained_emb = np.array(pretrained_emb) 50 | np.save(glove_file, pretrained_emb) 51 | pickle.dump(token_to_ix, open(token_file, "wb")) 52 | return token_to_ix, pretrained_emb 53 | 54 | def sent_to_ix(s, token_to_ix, max_token=100): 55 | ques_ix = np.zeros(max_token, np.int64) 56 | 57 | for ix, word in enumerate(s): 58 | if word in token_to_ix: 59 | ques_ix[ix] = token_to_ix[word] 60 | else: 61 | ques_ix[ix] = token_to_ix['UNK'] 62 | 63 | if ix + 1 == max_token: 64 | break 65 | 66 | return ques_ix 67 | 68 | 69 | def cmumosei_7(a): 70 | if a < -2: 71 | res = 0 72 | if -2 <= a and a < -1: 73 | res = 1 74 | if -1 <= a and a < 0: 75 | res = 2 76 | if 0 <= a and a <= 0: 77 | res = 3 78 | if 0 < a and a <= 1: 79 | res = 4 80 | if 1 < a and a <= 2: 81 | res = 5 82 | if a > 2: 83 | res = 6 84 | return res 85 | 86 | def cmumosei_2(a): 87 | if a < 0: 88 | return 0 89 | if a >= 0: 90 | return 1 91 | 92 | def pad_feature(feat, max_len): 93 | if feat.shape[0] > max_len: 94 | feat = feat[:max_len] 95 | feat = np.pad( 96 | feat, 97 | ((0, max_len - feat.shape[0]), (0, 0)), 98 | mode='constant', 99 | constant_values=0 100 | ) 101 | return feat 102 | --------------------------------------------------------------------------------
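
As a quick illustration of how the preprocessing utilities above fit together, here is a minimal sketch that runs one toy utterance through `clean`, `sent_to_ix` and `pad_feature`. The vocabulary and the mel features are made up for the example (in the real pipeline the vocabulary and GloVe matrix come from `create_dict`), and importing `utils.tokenize` assumes the spaCy vectors package from the Environment section is installed, since the module imports `en_vectors_web_lg` at load time.

```
import numpy as np
from utils.tokenize import clean, sent_to_ix, pad_feature

# Toy vocabulary; in the real pipeline this is built by create_dict ('UNK' maps to 1).
token_to_ix = {'UNK': 1, 'i': 2, 'really': 3, 'liked': 4, 'it': 5}

words = ["I", "really", "liked", "it!"]
sentence = [clean(w) for w in words if clean(w) != '']  # what tokenize() does per utterance

L = sent_to_ix(sentence, token_to_ix, max_token=15)     # (15,) int64 word ids, zero-padded
mels = np.random.randn(53, 80).astype(np.float32)       # fake (time, n_mels) acoustic features
A = pad_feature(mels, 40)                               # truncated/padded to shape (40, 80)

print(L.shape, A.shape)  # (15,) (40, 80)
```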