├── README.md ├── dataloader.py ├── model.py ├── module.py ├── requirements.txt ├── run.py ├── trainer.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # GraphSmile 2 | The official implementation of the paper "[Tracing intricate cues in dialogue: Joint graph structure and sentiment dynamics for multimodal emotion recognition](https://arxiv.org/abs/2407.21536)", which has been accepted for publication. 3 | Authors: Jiang Li, Xiaoping Wang, Zhigang Zeng 4 | Affiliation: Huazhong University of Science and Technology (HUST) 5 | 6 | ## Citation 7 | ```bibtex 8 | @article{li2025tracing, 9 | title={Tracing intricate cues in dialogue: Joint graph structure and sentiment dynamics for multimodal emotion recognition}, 10 | author={Jiang Li and Xiaoping Wang and Zhigang Zeng}, 11 | year={2025}, 12 | journal = {}, 13 | volume = {}, 14 | number={}, 15 | pages = {1-18}, 16 | doi={} 17 | } 18 | ``` 19 | 20 | ## Requirements 21 | Check and install the required dependencies: 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | ## Datasets 26 | Baidu Netdisk: https://pan.baidu.com/s/1u1efdbBV3HP8FLj3Gy1bvQ 27 | Extraction code: ipnv 28 | Google Drive: https://drive.google.com/drive/folders/1l_ex1wnAAMpEtO71rjjM1MKC7W_olEVi?usp=drive_link 29 | 30 | Add the dataset paths to the corresponding locations in run.py, e.g., IEMOCAP_path = "". 31 | 32 | ## Run 33 | ### IEMOCAP-6 34 | ```bash 35 | python -u run.py --gpu 2 --port 1530 --classify emotion \ 36 | --dataset IEMOCAP --epochs 120 --textf_mode textf0 \ 37 | --loss_type emo_sen_sft --lr 1e-04 --batch_size 16 --hidden_dim 512 \ 38 | --win 17 17 --heter_n_layers 7 7 7 --drop 0.2 --shift_win 19 --lambd 1.0 1.0 0.7 39 | ``` 40 | 41 | ### IEMOCAP-4 42 | ```bash 43 | python -u run.py --gpu 2 --port 1531 --classify emotion \ 44 | --dataset IEMOCAP4 --epochs 120 --textf_mode textf0 \ 45 | --loss_type emo_sen_sft --lr 3e-04 --batch_size 16 --hidden_dim 256 \ 46 | --win 5 5 --heter_n_layers 4 4 4 --drop 0.2 --shift_win 10 --lambd 1.0 0.6 0.6 47 | ``` 48 | 49 | ### MELD 50 | ```bash 51 | python -u run.py --gpu 2 --port 1532 --classify emotion \ 52 | --dataset MELD --epochs 50 --textf_mode textf0 \ 53 | --loss_type emo_sen_sft --lr 7e-05 --batch_size 16 --hidden_dim 384 \ 54 | --win 3 3 --heter_n_layers 5 5 5 --drop 0.2 --shift_win 3 --lambd 1.0 0.5 0.2 55 | ``` 56 | 57 | ### CMUMOSEI7 58 | ```bash 59 | python -u run.py --gpu 3 --port 1534 --classify emotion \ 60 | --dataset CMUMOSEI7 --epochs 60 --textf_mode textf0 \ 61 | --loss_type emo_sen_sft --lr 8e-05 --batch_size 32 --hidden_dim 256 \ 62 | --win 5 5 --heter_n_layers 2 2 2 --drop 0.4 --shift_win 2 --lambd 1.0 0.8 1.0 63 | ``` -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.nn.utils.rnn import pad_sequence 4 | import pickle, pandas as pd 5 | import numpy as np 6 | 7 | 8 | class IEMOCAPDataset_BERT4(Dataset): 9 | 10 | def __init__(self, path, train=True): 11 | 12 | ( 13 | self.videoIDs, 14 | self.videoSpeakers, 15 | self.videoLabels, 16 | self.videoText, 17 | self.videoAudio, 18 | self.videoVisual, 19 | self.videoSentence, 20 | self.trainVid, 21 | self.testVid, 22 | ) = pickle.load(open(path, "rb"), encoding="latin1") 23 | self.keys = self.trainVid if train else self.testVid 24 | 25 | self.len = len(self.keys) 26 | 27 | self.labels_emotion = self.videoLabels 28 |
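        # Derive 3-class sentiment labels (0: negative, 1: neutral, 2: positive) from the emotion labels;
        # the mapping below assumes the usual IEMOCAP-4 label order 0=happy, 1=sad, 2=neutral, 3=angry.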
29 | labels_sentiment = {} 30 | for item in self.videoLabels: 31 | array = [] 32 | for e in self.videoLabels[item]: 33 | if e in [1, 3]: 34 | array.append(0) 35 | elif e == 2: 36 | array.append(1) 37 | elif e in [0]: 38 | array.append(2) 39 | labels_sentiment[item] = array 40 | self.labels_sentiment = labels_sentiment 41 | 42 | def __getitem__(self, index): 43 | vid = self.keys[index] 44 | return ( 45 | torch.FloatTensor(np.array(self.videoText[vid])), 46 | torch.FloatTensor(np.array(self.videoText[vid])), 47 | torch.FloatTensor(np.array(self.videoText[vid])), 48 | torch.FloatTensor(np.array(self.videoText[vid])), 49 | torch.FloatTensor(np.array(self.videoVisual[vid])), 50 | torch.FloatTensor(np.array(self.videoAudio[vid])), 51 | torch.FloatTensor( 52 | [ 53 | [1, 0] if x == "M" else [0, 1] 54 | for x in np.array(self.videoSpeakers[vid]) 55 | ] 56 | ), 57 | torch.FloatTensor([1] * len(np.array(self.labels_emotion[vid]))), 58 | torch.LongTensor(np.array(self.labels_emotion[vid])), 59 | torch.LongTensor(np.array(self.labels_sentiment[vid])), 60 | vid, 61 | ) 62 | 63 | def __len__(self): 64 | return self.len 65 | 66 | def collate_fn(self, data): 67 | dat = pd.DataFrame(data) 68 | 69 | return [ 70 | ( 71 | pad_sequence(dat[i]) 72 | if i < 7 73 | else pad_sequence(dat[i]) if i < 10 else dat[i].tolist() 74 | ) 75 | for i in dat 76 | ] 77 | 78 | 79 | class IEMOCAPDataset_BERT(Dataset): 80 | 81 | def __init__(self, path, train=True): 82 | 83 | ( 84 | self.videoIDs, 85 | self.videoSpeakers, 86 | self.videoLabels, 87 | self.videoText0, 88 | self.videoText1, 89 | self.videoText2, 90 | self.videoText3, 91 | self.videoAudio, 92 | self.videoVisual, 93 | self.videoSentence, 94 | self.trainVid, 95 | self.testVid, 96 | ) = pickle.load(open(path, "rb"), encoding="latin1") 97 | 98 | self.keys = self.trainVid if train else self.testVid 99 | 100 | self.len = len(self.keys) 101 | 102 | self.labels_emotion = self.videoLabels 103 | 104 | labels_sentiment = {} 105 | for item in self.videoLabels: 106 | array = [] 107 | for e in self.videoLabels[item]: 108 | if e in [1, 3, 5]: 109 | array.append(0) 110 | elif e == 2: 111 | array.append(1) 112 | elif e in [0, 4]: 113 | array.append(2) 114 | labels_sentiment[item] = array 115 | self.labels_sentiment = labels_sentiment 116 | 117 | def __getitem__(self, index): 118 | vid = self.keys[index] 119 | return ( 120 | torch.FloatTensor(np.array(self.videoText0[vid])), 121 | torch.FloatTensor(np.array(self.videoText1[vid])), 122 | torch.FloatTensor(np.array(self.videoText2[vid])), 123 | torch.FloatTensor(np.array(self.videoText3[vid])), 124 | torch.FloatTensor(np.array(self.videoVisual[vid])), 125 | torch.FloatTensor(np.array(self.videoAudio[vid])), 126 | torch.FloatTensor( 127 | [ 128 | [1, 0] if x == "M" else [0, 1] 129 | for x in np.array(self.videoSpeakers[vid]) 130 | ] 131 | ), 132 | torch.FloatTensor([1] * len(np.array(self.labels_emotion[vid]))), 133 | torch.LongTensor(np.array(self.labels_emotion[vid])), 134 | torch.LongTensor(np.array(self.labels_sentiment[vid])), 135 | vid, 136 | ) 137 | 138 | def __len__(self): 139 | return self.len 140 | 141 | def collate_fn(self, data): 142 | dat = pd.DataFrame(data) 143 | 144 | return [ 145 | ( 146 | pad_sequence(dat[i]) 147 | if i < 7 148 | else pad_sequence(dat[i]) if i < 10 else dat[i].tolist() 149 | ) 150 | for i in dat 151 | ] 152 | 153 | 154 | class MELDDataset_BERT(Dataset): 155 | 156 | def __init__(self, path, train=True): 157 | """ 158 | label index mapping = {0:neutral, 1:surprise, 2:fear, 3:sadness, 4:joy, 5:disgust, 
6:anger} 159 | """ 160 | ( 161 | self.videoIDs, 162 | self.videoSpeakers, 163 | self.videoLabels, 164 | self.videoSentiments, 165 | self.videoText0, 166 | self.videoText1, 167 | self.videoText2, 168 | self.videoText3, 169 | self.videoAudio, 170 | self.videoVisual, 171 | self.videoSentence, 172 | self.trainVid, 173 | self.testVid, 174 | _, 175 | ) = pickle.load(open(path, "rb")) 176 | 177 | self.keys = [x for x in (self.trainVid if train else self.testVid)] 178 | 179 | self.len = len(self.keys) 180 | 181 | self.labels_emotion = self.videoLabels 182 | 183 | self.labels_sentiment = self.videoSentiments 184 | 185 | def __getitem__(self, index): 186 | vid = self.keys[index] 187 | return ( 188 | torch.FloatTensor(np.array(self.videoText0[vid])), 189 | torch.FloatTensor(np.array(self.videoText1[vid])), 190 | torch.FloatTensor(np.array(self.videoText2[vid])), 191 | torch.FloatTensor(np.array(self.videoText3[vid])), 192 | torch.FloatTensor(np.array(self.videoVisual[vid])), 193 | torch.FloatTensor(np.array(self.videoAudio[vid])), 194 | torch.FloatTensor(np.array(self.videoSpeakers[vid])), 195 | torch.FloatTensor([1] * len(np.array(self.labels_emotion[vid]))), 196 | torch.LongTensor(np.array(self.labels_emotion[vid])), 197 | torch.LongTensor(np.array(self.labels_sentiment[vid])), 198 | vid, 199 | ) 200 | 201 | def __len__(self): 202 | return self.len 203 | 204 | def return_labels(self): 205 | return_label = [] 206 | for key in self.keys: 207 | return_label += self.videoLabels[key] 208 | return return_label 209 | 210 | def collate_fn(self, data): 211 | dat = pd.DataFrame(data) 212 | 213 | return [ 214 | ( 215 | pad_sequence(dat[i]) 216 | if i < 7 217 | else pad_sequence(dat[i]) if i < 10 else dat[i].tolist() 218 | ) 219 | for i in dat 220 | ] 221 | 222 | 223 | class CMUMOSEIDataset7(Dataset): 224 | 225 | def __init__(self, path, train=True): 226 | 227 | ( 228 | self.videoIDs, 229 | self.videoSpeakers, 230 | self.videoLabels, 231 | self.videoText, 232 | self.videoAudio, 233 | self.videoVisual, 234 | self.videoSentence, 235 | self.trainVid, 236 | self.testVid, 237 | ) = pickle.load(open(path, "rb"), encoding="latin1") 238 | 239 | self.keys = self.trainVid if train else self.testVid 240 | 241 | self.len = len(self.keys) 242 | 243 | labels_emotion = {} 244 | for item in self.videoLabels: 245 | array = [] 246 | for a in self.videoLabels[item]: 247 | if a < -2: 248 | array.append(0) 249 | elif -2 <= a and a < -1: 250 | array.append(1) 251 | elif -1 <= a and a < 0: 252 | array.append(2) 253 | elif 0 <= a and a <= 0: 254 | array.append(3) 255 | elif 0 < a and a <= 1: 256 | array.append(4) 257 | elif 1 < a and a <= 2: 258 | array.append(5) 259 | elif a > 2: 260 | array.append(6) 261 | labels_emotion[item] = array 262 | self.labels_emotion = labels_emotion 263 | 264 | labels_sentiment = {} 265 | for item in self.videoLabels: 266 | array = [] 267 | for a in self.videoLabels[item]: 268 | if a < 0: 269 | array.append(0) 270 | elif 0 <= a and a <= 0: 271 | array.append(1) 272 | elif a > 0: 273 | array.append(2) 274 | labels_sentiment[item] = array 275 | self.labels_sentiment = labels_sentiment 276 | 277 | def __getitem__(self, index): 278 | vid = self.keys[index] 279 | return ( 280 | torch.FloatTensor(np.array(self.videoText[vid])), 281 | torch.FloatTensor(np.array(self.videoText[vid])), 282 | torch.FloatTensor(np.array(self.videoText[vid])), 283 | torch.FloatTensor(np.array(self.videoText[vid])), 284 | torch.FloatTensor(np.array(self.videoVisual[vid])), 285 | torch.FloatTensor(np.array(self.videoAudio[vid])), 286 
| torch.FloatTensor( 287 | [ 288 | [1, 0] if x == "M" else [0, 1] 289 | for x in np.array(self.videoSpeakers[vid]) 290 | ] 291 | ), 292 | torch.FloatTensor([1] * len(np.array(self.labels_emotion[vid]))), 293 | torch.LongTensor(np.array(self.labels_emotion[vid])), 294 | torch.LongTensor(np.array(self.labels_sentiment[vid])), 295 | vid, 296 | ) 297 | 298 | def __len__(self): 299 | return self.len 300 | 301 | def collate_fn(self, data): 302 | dat = pd.DataFrame(data) 303 | 304 | return [ 305 | ( 306 | pad_sequence(dat[i]) 307 | if i < 7 308 | else pad_sequence(dat[i]) if i < 10 else dat[i].tolist() 309 | ) 310 | for i in dat 311 | ] 312 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from module import HeterGConv_Edge, HeterGConvLayer, SenShift_Feat 2 | import torch.nn as nn 3 | import torch 4 | from utils import batch_to_all_tva 5 | 6 | 7 | class GraphSmile(nn.Module): 8 | 9 | def __init__(self, args, embedding_dims, n_classes_emo): 10 | super(GraphSmile, self).__init__() 11 | self.textf_mode = args.textf_mode 12 | self.no_cuda = args.no_cuda 13 | self.win_p = args.win[0] 14 | self.win_f = args.win[1] 15 | self.modals = args.modals 16 | self.shift_win = args.shift_win 17 | 18 | self.batchnorms_t = nn.ModuleList( 19 | nn.BatchNorm1d(embedding_dims[0]) for _ in range(4)) 20 | 21 | in_dims_t = (4 * embedding_dims[0] if args.textf_mode == "concat4" else 22 | (2 * embedding_dims[0] 23 | if args.textf_mode == "concat2" else embedding_dims[0])) 24 | self.dim_layer_t = nn.Sequential(nn.Linear(in_dims_t, args.hidden_dim), 25 | nn.LeakyReLU(), nn.Dropout(args.drop)) 26 | self.dim_layer_v = nn.Sequential( 27 | nn.Linear(embedding_dims[1], args.hidden_dim), 28 | nn.LeakyReLU(), 29 | nn.Dropout(args.drop), 30 | ) 31 | self.dim_layer_a = nn.Sequential( 32 | nn.Linear(embedding_dims[2], args.hidden_dim), 33 | nn.LeakyReLU(), 34 | nn.Dropout(args.drop), 35 | ) 36 | 37 | # Heter 38 | hetergconvLayer_tv = HeterGConvLayer(args.hidden_dim, args.drop, 39 | args.no_cuda) 40 | self.hetergconv_tv = HeterGConv_Edge( 41 | args.hidden_dim, 42 | hetergconvLayer_tv, 43 | args.heter_n_layers[0], 44 | args.drop, 45 | args.no_cuda, 46 | ) 47 | hetergconvLayer_ta = HeterGConvLayer(args.hidden_dim, args.drop, 48 | args.no_cuda) 49 | self.hetergconv_ta = HeterGConv_Edge( 50 | args.hidden_dim, 51 | hetergconvLayer_ta, 52 | args.heter_n_layers[1], 53 | args.drop, 54 | args.no_cuda, 55 | ) 56 | hetergconvLayer_va = HeterGConvLayer(args.hidden_dim, args.drop, 57 | args.no_cuda) 58 | self.hetergconv_va = HeterGConv_Edge( 59 | args.hidden_dim, 60 | hetergconvLayer_va, 61 | args.heter_n_layers[2], 62 | args.drop, 63 | args.no_cuda, 64 | ) 65 | 66 | self.modal_fusion = nn.Sequential( 67 | nn.Linear(args.hidden_dim, args.hidden_dim), 68 | nn.LeakyReLU(), 69 | ) 70 | 71 | self.emo_output = nn.Linear(args.hidden_dim, n_classes_emo) 72 | self.sen_output = nn.Linear(args.hidden_dim, 3) 73 | self.senshift = SenShift_Feat(args.hidden_dim, args.drop, 74 | args.shift_win) 75 | 76 | def forward(self, feature_t0, feature_t1, feature_t2, feature_t3, 77 | feature_v, feature_a, umask, qmask, dia_lengths): 78 | 79 | ( 80 | (seq_len_t, batch_size_t, feature_dim_t), 81 | (seq_len_v, batch_size_v, feature_dim_v), 82 | (seq_len_a, batch_size_a, feature_dim_a), 83 | ) = [feature_t0.shape, feature_v.shape, feature_a.shape] 84 | features_t = [ 85 | batchnorm_t(feature_t.transpose(0, 1).reshape( 86 | -1, 
feature_dim_t)).reshape(-1, seq_len_t, 87 | feature_dim_t).transpose(1, 0) 88 | for batchnorm_t, feature_t in 89 | zip(self.batchnorms_t, 90 | [feature_t0, feature_t1, feature_t2, feature_t3]) 91 | ] 92 | feature_t0, feature_t1, feature_t2, feature_t3 = features_t 93 | 94 | dim_layer_dict_t = { 95 | "concat4": 96 | lambda: self.dim_layer_t( 97 | torch.cat([feature_t0, feature_t1, feature_t2, feature_t3], 98 | dim=-1)), 99 | "sum4": 100 | lambda: 101 | (self.dim_layer_t(feature_t0) + self.dim_layer_t(feature_t1) + self 102 | .dim_layer_t(feature_t2) + self.dim_layer_t(feature_t3)) / 4, 103 | "concat2": 104 | lambda: self.dim_layer_t( 105 | torch.cat([feature_t0, feature_t1], dim=-1)), 106 | "sum2": 107 | lambda: 108 | (self.dim_layer_t(feature_t0) + self.dim_layer_t(feature_t1)) / 2, 109 | "textf0": 110 | lambda: self.dim_layer_t(feature_t0), 111 | "textf1": 112 | lambda: self.dim_layer_t(feature_t1), 113 | "textf2": 114 | lambda: self.dim_layer_t(feature_t2), 115 | "textf3": 116 | lambda: self.dim_layer_t(feature_t3), 117 | } 118 | featdim_t = dim_layer_dict_t[self.textf_mode]() 119 | featdim_v, featdim_a = self.dim_layer_v(feature_v), self.dim_layer_a( 120 | feature_a) 121 | 122 | emo_t, emo_v, emo_a = featdim_t, featdim_v, featdim_a 123 | 124 | emo_t, emo_v, emo_a = batch_to_all_tva(emo_t, emo_v, emo_a, 125 | dia_lengths, self.no_cuda) 126 | 127 | featheter_tv, heter_edge_index = self.hetergconv_tv( 128 | (emo_t, emo_v), dia_lengths, self.win_p, self.win_f) 129 | featheter_ta, heter_edge_index = self.hetergconv_ta( 130 | (emo_t, emo_a), dia_lengths, self.win_p, self.win_f, 131 | heter_edge_index) 132 | featheter_va, heter_edge_index = self.hetergconv_va( 133 | (emo_v, emo_a), dia_lengths, self.win_p, self.win_f, 134 | heter_edge_index) 135 | 136 | feat_fusion = (self.modal_fusion(featheter_tv[0]) + self.modal_fusion( 137 | featheter_ta[0]) + self.modal_fusion(featheter_tv[1]) + 138 | self.modal_fusion(featheter_va[0]) + 139 | self.modal_fusion(featheter_ta[1]) + 140 | self.modal_fusion(featheter_va[1])) / 6 141 | 142 | logit_emo = self.emo_output(feat_fusion) 143 | logit_sen = self.sen_output(feat_fusion) 144 | 145 | logit_shift = self.senshift(feat_fusion, feat_fusion, dia_lengths) 146 | 147 | return logit_emo, logit_sen, logit_shift, feat_fusion 148 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | from itertools import permutations, product 2 | import math 3 | import torch 4 | import copy 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | 9 | def _get_clones(module, N): 10 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 11 | 12 | 13 | class HeterGConv_Edge(torch.nn.Module): 14 | 15 | def __init__(self, feature_size, encoder_layer, num_layers, dropout, 16 | no_cuda): 17 | super(HeterGConv_Edge, self).__init__() 18 | self.num_layers = num_layers 19 | self.no_cuda = no_cuda 20 | 21 | self.edge_weight = nn.Parameter(torch.ones(500000)) 22 | 23 | self.hetergcn_layers = _get_clones(encoder_layer, num_layers) 24 | self.fc_layer = nn.Sequential(nn.Linear(feature_size, feature_size), 25 | nn.LeakyReLU(), nn.Dropout(dropout)) 26 | self.fc_layers = _get_clones(self.fc_layer, num_layers) 27 | 28 | def forward(self, feature_tuple, dia_lens, win_p, win_f, edge_index=None): 29 | 30 | num_modal = len(feature_tuple) 31 | feature = torch.cat(feature_tuple, dim=0) 32 | 33 | if edge_index is None: 34 | edge_index = 
self._heter_no_weight_edge(feature, num_modal, 35 | dia_lens, win_p, win_f) 36 | edge_weight = self.edge_weight[0:edge_index.size(1)] 37 | 38 | adj_weight = self._edge_index_to_adjacency_matrix( 39 | edge_index, 40 | edge_weight, 41 | num_nodes=feature.size(0), 42 | no_cuda=self.no_cuda) 43 | feature_sum = feature 44 | for i in range(self.num_layers): 45 | feature = self.hetergcn_layers[i](feature, num_modal, adj_weight) 46 | feature_sum = feature_sum + self.fc_layers[i](feature) 47 | feat_tuple = torch.chunk(feature_sum, num_modal, dim=0) 48 | 49 | return feat_tuple, edge_index 50 | 51 | def _edge_index_to_adjacency_matrix(self, 52 | edge_index, 53 | edge_weight=None, 54 | num_nodes=100, 55 | no_cuda=False): 56 | 57 | if edge_weight is not None: 58 | edge_weight = edge_weight.squeeze() 59 | else: 60 | edge_weight = torch.ones( 61 | edge_index.size(1)).cuda() if not no_cuda else torch.ones( 62 | edge_index.size(1)) 63 | adj_sparse = torch.sparse_coo_tensor(edge_index, 64 | edge_weight, 65 | size=(num_nodes, num_nodes)) 66 | adj = adj_sparse.to_dense() 67 | row_sum = torch.sum(adj, dim=1) 68 | d_inv_sqrt = torch.pow(row_sum, -0.5) 69 | d_inv_sqrt[d_inv_sqrt == float("inf")] = 0 70 | d_inv_sqrt_mat = torch.diag_embed(d_inv_sqrt) 71 | gcn_fact = torch.matmul(d_inv_sqrt_mat, 72 | torch.matmul(adj, d_inv_sqrt_mat)) 73 | 74 | if not no_cuda and torch.cuda.is_available(): 75 | gcn_fact = gcn_fact.cuda() 76 | 77 | return gcn_fact 78 | 79 | def _heter_no_weight_edge(self, feature, num_modal, dia_lens, win_p, 80 | win_f): 81 | index_inter = [] 82 | all_dia_len = sum(dia_lens) 83 | all_nodes = list(range(all_dia_len * num_modal)) 84 | nodes_uni = [None] * num_modal 85 | 86 | for m in range(num_modal): 87 | nodes_uni[m] = all_nodes[m * all_dia_len:(m + 1) * all_dia_len] 88 | 89 | start = 0 90 | for dia_len in dia_lens: 91 | for m, n in permutations(range(num_modal), 2): 92 | 93 | for j, node_m in enumerate(nodes_uni[m][start:start + 94 | dia_len]): 95 | if win_p == -1 and win_f == -1: 96 | nodes_n = nodes_uni[n][start:start + dia_len] 97 | elif win_p == -1: 98 | nodes_n = nodes_uni[n][ 99 | start:min(start + dia_len, start + j + win_f + 1)] 100 | elif win_f == -1: 101 | nodes_n = nodes_uni[n][max(start, start + j - 102 | win_p):start + dia_len] 103 | else: 104 | nodes_n = nodes_uni[n][ 105 | max(start, start + j - 106 | win_p):min(start + dia_len, start + j + win_f + 107 | 1)] 108 | index_inter.extend(list(product([node_m], nodes_n))) 109 | start += dia_len 110 | edge_index = (torch.tensor(index_inter).permute(1, 0).cuda() if 111 | not self.no_cuda else torch.tensor(index_inter).permute( 112 | 1, 0)) 113 | 114 | return edge_index 115 | 116 | 117 | class HeterGConvLayer(torch.nn.Module): 118 | 119 | def __init__(self, feature_size, dropout=0.3, no_cuda=False): 120 | super(HeterGConvLayer, self).__init__() 121 | self.no_cuda = no_cuda 122 | self.hetergconv = SGConv_Our(feature_size, feature_size) 123 | 124 | def forward(self, feature, num_modal, adj_weight): 125 | 126 | if num_modal > 1: 127 | feature_heter = self.hetergconv(feature, adj_weight) 128 | else: 129 | print("Unable to construct heterogeneous graph!") 130 | feature_heter = feature 131 | 132 | return feature_heter 133 | 134 | 135 | class SGConv_Our(torch.nn.Module): 136 | """ 137 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 138 | """ 139 | 140 | def __init__(self, in_features, out_features, bias=True): 141 | super(SGConv_Our, self).__init__() 142 | self.in_features = in_features 143 | self.out_features = out_features 144 | 
self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 145 | if bias: 146 | self.bias = Parameter(torch.FloatTensor(out_features)) 147 | else: 148 | self.register_parameter("bias", None) 149 | self.reset_parameters() 150 | 151 | def reset_parameters(self): 152 | stdv = 1.0 / math.sqrt(self.weight.size(1)) 153 | self.weight.data.uniform_(-stdv, stdv) 154 | if self.bias is not None: 155 | self.bias.data.uniform_(-stdv, stdv) 156 | 157 | def forward(self, input, adj): 158 | 159 | try: 160 | input = input.float() 161 | except: 162 | pass 163 | support = torch.mm(input, self.weight) 164 | output = torch.spmm(adj, support) 165 | if self.bias is not None: 166 | return output + self.bias 167 | else: 168 | return output 169 | 170 | 171 | class SenShift_Feat(nn.Module): 172 | 173 | def __init__(self, hidden_dim, dropout, shift_win): 174 | super().__init__() 175 | 176 | self.shift_win = shift_win 177 | 178 | hidden_dim_shift = 2 * hidden_dim 179 | 180 | self.shift_output_layer = nn.Sequential(nn.Linear(hidden_dim_shift, 181 | 2), ) 182 | 183 | def forward(self, embeds, embeds_temp=None, dia_lens=[]): 184 | 185 | if embeds_temp == None: 186 | embeds_temp = embeds 187 | embeds_shift = self._build_match_sample(embeds, embeds_temp, dia_lens, 188 | self.shift_win) 189 | logits = self.shift_output_layer(embeds_shift) 190 | 191 | return logits 192 | 193 | def _build_match_sample(self, embeds, embeds_temp, dia_lens, shift_win): 194 | 195 | start = 0 196 | embeds_shifts = [] 197 | if shift_win == -1: 198 | for dia_len in dia_lens: 199 | embeds_shifts.append( 200 | torch.cat( 201 | [ 202 | embeds[start:start + dia_len, None, :].repeat( 203 | 1, dia_len, 1), 204 | embeds_temp[None, start:start + dia_len, :].repeat( 205 | dia_len, 1, 1), 206 | ], 207 | dim=-1, 208 | ).view(-1, 2 * embeds.size(-1))) 209 | start += dia_len 210 | embeds_shift = torch.cat(embeds_shifts, dim=0) 211 | 212 | elif shift_win > 0: 213 | for dia_len in dia_lens: 214 | win_start = 0 215 | for i in range(math.ceil(dia_len / shift_win)): 216 | if (i == math.ceil(dia_len / shift_win) - 1 217 | and dia_len % shift_win != 0): 218 | win = dia_len % shift_win 219 | else: 220 | win = shift_win 221 | embeds_shifts.append( 222 | torch.cat( 223 | [ 224 | embeds[ 225 | start + win_start : start + win_start + win, None, : 226 | ].repeat(1, win, 1), 227 | embeds_temp[ 228 | None, start + win_start : start + win_start + win, : 229 | ].repeat(win, 1, 1), 230 | ], 231 | dim=-1, 232 | ).view(-1, 2 * embeds.size(-1)) 233 | ) 234 | win_start += shift_win 235 | start += dia_len 236 | embeds_shift = torch.cat(embeds_shifts, dim=0) 237 | else: 238 | print("Window must be greater than 0 or equal to -1") 239 | raise NotImplementedError 240 | 241 | return embeds_shift 242 | 243 | 244 | def build_match_sen_shift_label(shift_win, dia_lengths, label_sen): 245 | start = 0 246 | label_shifts = [] 247 | if shift_win == -1: 248 | for dia_len in dia_lengths: 249 | dia_label_shift = ((label_sen[start:start + dia_len, None] 250 | != label_sen[None, start:start + 251 | dia_len]).long().view(-1)) 252 | label_shifts.append(dia_label_shift) 253 | start += dia_len 254 | label_shift = torch.cat(label_shifts, dim=0) 255 | elif shift_win > 0: 256 | for dia_len in dia_lengths: 257 | win_start = 0 258 | for i in range(math.ceil(dia_len / shift_win)): 259 | if i == math.ceil( 260 | dia_len / shift_win) - 1 and dia_len % shift_win != 0: 261 | win = dia_len % shift_win 262 | else: 263 | win = shift_win 264 | dia_label_shift = ( 265 | ( 266 | label_sen[start + win_start : 
start + win_start + win, None] 267 | != label_sen[None, start + win_start : start + win_start + win] 268 | ) 269 | .long() 270 | .view(-1) 271 | ) 272 | label_shifts.append(dia_label_shift) 273 | win_start += shift_win 274 | start += dia_len 275 | label_shift = torch.cat(label_shifts, dim=0) 276 | else: 277 | print("Window must be greater than 0 or equal to -1") 278 | raise NotImplementedError 279 | 280 | return label_shift 281 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==2.2.2 2 | pandas==2.2.3 3 | scikit_learn==1.6.1 4 | tensorboardX==2.6.2.2 5 | tensorboardX==2.6.2.2 6 | torch==2.0.0+cu117 7 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | import pickle as pk 5 | import datetime 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | import torch.multiprocessing as mp 12 | from torch.utils.data.distributed import DistributedSampler 13 | import time 14 | from utils import AutomaticWeightedLoss 15 | from model import GraphSmile 16 | from sklearn.metrics import confusion_matrix, classification_report 17 | from trainer import train_or_eval_model, seed_everything 18 | from dataloader import ( 19 | IEMOCAPDataset_BERT, 20 | IEMOCAPDataset_BERT4, 21 | MELDDataset_BERT, 22 | CMUMOSEIDataset7, 23 | ) 24 | from torch.utils.data import DataLoader 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument("--no_cuda", 30 | action="store_true", 31 | default=False, 32 | help="does not use GPU") 33 | parser.add_argument("--gpu", default="2", type=str, help="GPU ids") 34 | parser.add_argument("--port", default="15301", help="MASTER_PORT") 35 | parser.add_argument("--classify", default="emotion", help="sentiment, emotion") 36 | parser.add_argument("--lr", 37 | type=float, 38 | default=0.00001, 39 | metavar="LR", 40 | help="learning rate") 41 | parser.add_argument("--l2", 42 | type=float, 43 | default=0.0001, 44 | metavar="L2", 45 | help="L2 regularization weight") 46 | parser.add_argument("--batch_size", 47 | type=int, 48 | default=32, 49 | metavar="BS", 50 | help="batch size") 51 | parser.add_argument("--epochs", 52 | type=int, 53 | default=100, 54 | metavar="E", 55 | help="number of epochs") 56 | parser.add_argument("--tensorboard", 57 | action="store_true", 58 | default=False, 59 | help="Enables tensorboard log") 60 | parser.add_argument("--modals", default="avl", help="modals") 61 | parser.add_argument( 62 | "--dataset", 63 | default="IEMOCAP", 64 | help="dataset to train and test.MELD/IEMOCAP/IEMOCAP4/CMUMOSEI7", 65 | ) 66 | parser.add_argument( 67 | "--textf_mode", 68 | default="textf0", 69 | help="concat4/concat2/textf0/textf1/textf2/textf3/sum2/sum4", 70 | ) 71 | 72 | parser.add_argument( 73 | "--conv_fpo", 74 | nargs="+", 75 | type=int, 76 | default=[3, 1, 1], 77 | help="n_filter,n_padding; n_out = (n_in + 2*n_padding -n_filter)/stride +1", 78 | ) 79 | 80 | parser.add_argument("--hidden_dim", type=int, default=256, help="hidden_dim") 81 | parser.add_argument( 82 | "--win", 83 | nargs="+", 84 | type=int, 85 | default=[17, 17], 86 | help="[win_p, win_f], -1 denotes all nodes", 87 | ) 88 | 
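# The three values give the number of heterogeneous graph-convolution layers for the
# text-visual, text-audio, and visual-audio pairs, respectively (consumed as
# args.heter_n_layers[0], [1], [2] in model.py).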
parser.add_argument("--heter_n_layers", 89 | nargs="+", 90 | type=int, 91 | default=[6, 6, 6], 92 | help="heter_n_layers") 93 | 94 | parser.add_argument("--drop", 95 | type=float, 96 | default=0.3, 97 | metavar="dropout", 98 | help="dropout rate") 99 | 100 | parser.add_argument("--shift_win", 101 | type=int, 102 | default=12, 103 | help="windows of sentiment shift") 104 | 105 | parser.add_argument( 106 | "--loss_type", 107 | default="emo_sen_sft", 108 | help="auto/epoch/emo_sen_sft/emo_sen/emo_sft/emo/sen_sft/sen", 109 | ) 110 | parser.add_argument( 111 | "--lambd", 112 | nargs="+", 113 | type=float, 114 | default=[1.0, 1.0, 1.0], 115 | help="[loss_emotion, loss_sentiment, loss_shift]", 116 | ) 117 | 118 | args = parser.parse_args() 119 | 120 | os.environ["MASTER_ADDR"] = "localhost" 121 | os.environ["MASTER_PORT"] = args.port 122 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 123 | world_size = torch.cuda.device_count() 124 | os.environ["WORLD_SIZE"] = str(world_size) 125 | 126 | MELD_path = "" 127 | IEMOCAP_path = "" 128 | IEMOCAP4_path = "" 129 | CMUMOSEI7_path = "" 130 | 131 | logging.basicConfig(level=logging.INFO) 132 | logger = logging.getLogger(__name__) 133 | 134 | 135 | def init_ddp(local_rank): 136 | try: 137 | if not dist.is_initialized(): 138 | torch.cuda.set_device(local_rank) 139 | os.environ["RANK"] = str(local_rank) 140 | dist.init_process_group(backend="nccl", init_method="env://") 141 | else: 142 | logger.info("Distributed process group already initialized.") 143 | except Exception as e: 144 | logger.error(f"Failed to initialize distributed process group: {e}") 145 | raise 146 | 147 | 148 | def reduce_tensor(tensor: torch.Tensor): 149 | rt = tensor.clone() 150 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 151 | rt /= dist.get_world_size() 152 | 153 | return rt 154 | 155 | 156 | def get_ddp_generator(seed=3407): 157 | local_rank = dist.get_rank() 158 | g = torch.Generator() 159 | g.manual_seed(seed + local_rank) 160 | 161 | return g 162 | 163 | 164 | def get_train_valid_sampler(trainset, valid_ratio): 165 | size = len(trainset) 166 | idx = list(range(size)) 167 | split = int(valid_ratio * size) 168 | 169 | return DistributedSampler(idx[split:]), DistributedSampler(idx[:split]) 170 | 171 | 172 | def get_data_loaders(path, dataset_class, batch_size, valid_ratio, num_workers, 173 | pin_memory): 174 | 175 | trainset = dataset_class(path) 176 | train_sampler, valid_sampler = get_train_valid_sampler( 177 | trainset, valid_ratio) 178 | train_loader = DataLoader( 179 | trainset, 180 | batch_size=batch_size, 181 | sampler=train_sampler, 182 | collate_fn=trainset.collate_fn, 183 | num_workers=num_workers, 184 | pin_memory=pin_memory, 185 | ) 186 | valid_loader = DataLoader( 187 | trainset, 188 | batch_size=batch_size, 189 | sampler=valid_sampler, 190 | collate_fn=trainset.collate_fn, 191 | num_workers=num_workers, 192 | pin_memory=pin_memory, 193 | ) 194 | testset = dataset_class(path, train=False) 195 | test_loader = DataLoader( 196 | testset, 197 | batch_size=batch_size, 198 | collate_fn=testset.collate_fn, 199 | num_workers=num_workers, 200 | pin_memory=pin_memory, 201 | ) 202 | 203 | return train_loader, valid_loader, test_loader 204 | 205 | 206 | def setup_samplers(trainset, valid_ratio, epoch): 207 | train_sampler, valid_sampler = get_train_valid_sampler( 208 | trainset, valid_ratio=valid_ratio) 209 | train_sampler.set_epoch(epoch) 210 | valid_sampler.set_epoch(epoch) 211 | 212 | 213 | def main(local_rank): 214 | print(f"Running main(**args) on rank {local_rank}.") 215 | 
init_ddp(local_rank) # 初始化 216 | 217 | today = datetime.datetime.now() 218 | name_ = args.modals + "_" + args.dataset 219 | 220 | cuda = torch.cuda.is_available() and not args.no_cuda 221 | if args.tensorboard: 222 | from tensorboardX import SummaryWriter 223 | 224 | writer = SummaryWriter() 225 | 226 | n_epochs = args.epochs 227 | batch_size = args.batch_size 228 | modals = args.modals 229 | 230 | if args.dataset == "IEMOCAP": 231 | embedding_dims = [1024, 342, 1582] 232 | elif args.dataset == "IEMOCAP4": 233 | embedding_dims = [1024, 512, 100] 234 | elif args.dataset == "MELD": 235 | embedding_dims = [1024, 342, 300] 236 | elif args.dataset == "CMUMOSEI7": 237 | embedding_dims = [1024, 35, 384] 238 | 239 | if args.dataset == "MELD" or args.dataset == "CMUMOSEI7": 240 | n_classes_emo = 7 241 | elif args.dataset == "IEMOCAP": 242 | n_classes_emo = 6 243 | elif args.dataset == "IEMOCAP4": 244 | n_classes_emo = 4 245 | 246 | seed_everything() 247 | model = GraphSmile(args, embedding_dims, n_classes_emo) 248 | 249 | model = model.to(local_rank) 250 | model = DDP( 251 | model, 252 | device_ids=[local_rank], 253 | output_device=local_rank, 254 | find_unused_parameters=True, 255 | ) 256 | 257 | loss_function_emo = nn.NLLLoss() 258 | loss_function_sen = nn.NLLLoss() 259 | loss_function_shift = nn.NLLLoss() 260 | 261 | if args.loss_type == "auto_loss": 262 | awl = AutomaticWeightedLoss(3) 263 | optimizer = optim.AdamW( 264 | [ 265 | { 266 | "params": model.parameters() 267 | }, 268 | { 269 | "params": awl.parameters(), 270 | "weight_decay": 0 271 | }, 272 | ], 273 | lr=args.lr, 274 | weight_decay=args.l2, 275 | amsgrad=True, 276 | ) 277 | else: 278 | optimizer = optim.AdamW(model.parameters(), 279 | lr=args.lr, 280 | weight_decay=args.l2, 281 | amsgrad=True) 282 | 283 | if args.dataset == "MELD": 284 | train_loader, valid_loader, test_loader = get_data_loaders( 285 | path=MELD_path, 286 | dataset_class=MELDDataset_BERT, 287 | valid_ratio=0.1, 288 | batch_size=batch_size, 289 | num_workers=0, 290 | pin_memory=False, 291 | ) 292 | elif args.dataset == "IEMOCAP": 293 | train_loader, valid_loader, test_loader = get_data_loaders( 294 | path=IEMOCAP_path, 295 | dataset_class=IEMOCAPDataset_BERT, 296 | valid_ratio=0.1, 297 | batch_size=batch_size, 298 | num_workers=0, 299 | pin_memory=False, 300 | ) 301 | elif args.dataset == "IEMOCAP4": 302 | train_loader, valid_loader, test_loader = get_data_loaders( 303 | path=IEMOCAP4_path, 304 | dataset_class=IEMOCAPDataset_BERT4, 305 | valid_ratio=0.1, 306 | batch_size=batch_size, 307 | num_workers=0, 308 | pin_memory=False, 309 | ) 310 | elif args.dataset == "CMUMOSEI7": 311 | train_loader, valid_loader, test_loader = get_data_loaders( 312 | path=CMUMOSEI7_path, 313 | dataset_class=CMUMOSEIDataset7, 314 | valid_ratio=0.1, 315 | batch_size=batch_size, 316 | num_workers=0, 317 | pin_memory=False, 318 | ) 319 | else: 320 | print("There is no such dataset") 321 | 322 | best_f1_emo, best_f1_sen, best_loss = None, None, None 323 | best_label_emo, best_pred_emo = None, None 324 | best_label_sen, best_pred_sen = None, None 325 | best_extracted_feats = None 326 | all_f1_emo, all_acc_emo, all_loss = [], [], [] 327 | all_f1_sen, all_acc_sen = [], [] 328 | all_f1_sft, all_acc_sft = [], [] 329 | 330 | for epoch in range(n_epochs): 331 | if args.dataset == "MELD": 332 | trainset = MELDDataset_BERT(MELD_path) 333 | elif args.dataset == "IEMOCAP": 334 | trainset = IEMOCAPDataset_BERT(IEMOCAP_path) 335 | elif args.dataset == "IEMOCAP4": 336 | trainset = 
IEMOCAPDataset_BERT4(IEMOCAP4_path) 337 | elif args.dataset == "CMUMOSEI7": 338 | trainset = CMUMOSEIDataset7(CMUMOSEI7_path) 339 | 340 | setup_samplers(trainset, valid_ratio=0.1, epoch=epoch) 341 | 342 | start_time = time.time() 343 | 344 | train_loss, _, _, train_acc_emo, train_f1_emo, _, _, train_acc_sen, train_f1_sen, train_acc_sft, train_f1_sft, _, _, _ = train_or_eval_model( 345 | model, 346 | loss_function_emo, 347 | loss_function_sen, 348 | loss_function_shift, 349 | train_loader, 350 | epoch, 351 | cuda, 352 | args.modals, 353 | optimizer, 354 | True, 355 | args.dataset, 356 | args.loss_type, 357 | args.lambd, 358 | args.epochs, 359 | args.classify, 360 | args.shift_win, 361 | ) 362 | 363 | valid_loss, _, _, valid_acc_emo, valid_f1_emo, _, _, valid_acc_sen, valid_f1_sen, valid_acc_sft, valid_f1_sft, _, _, _ = train_or_eval_model( 364 | model, 365 | loss_function_emo, 366 | loss_function_sen, 367 | loss_function_shift, 368 | valid_loader, 369 | epoch, 370 | cuda, 371 | args.modals, 372 | None, 373 | False, 374 | args.dataset, 375 | args.loss_type, 376 | args.lambd, 377 | args.epochs, 378 | args.classify, 379 | args.shift_win, 380 | ) 381 | 382 | print( 383 | "epoch: {}, train_loss: {}, train_acc_emo: {}, train_f1_emo: {}, valid_loss: {}, valid_acc_emo: {}, valid_f1_emo: {}" 384 | .format( 385 | epoch + 1, 386 | train_loss, 387 | train_acc_emo, 388 | train_f1_emo, 389 | valid_loss, 390 | valid_acc_emo, 391 | valid_f1_emo, 392 | )) 393 | 394 | if local_rank == 0: 395 | test_loss, test_label_emo, test_pred_emo, test_acc_emo, test_f1_emo, test_label_sen, test_pred_sen, test_acc_sen, test_f1_sen, test_acc_sft, test_f1_sft, _, test_initial_feats, test_extracted_feats = train_or_eval_model( 396 | model, 397 | loss_function_emo, 398 | loss_function_sen, 399 | loss_function_shift, 400 | test_loader, 401 | epoch, 402 | cuda, 403 | args.modals, 404 | None, 405 | False, 406 | args.dataset, 407 | args.loss_type, 408 | args.lambd, 409 | args.epochs, 410 | args.classify, 411 | args.shift_win, 412 | ) 413 | 414 | all_f1_emo.append(test_f1_emo) 415 | all_acc_emo.append(test_acc_emo) 416 | all_f1_sft.append(test_f1_sft) 417 | all_acc_sft.append(test_acc_sft) 418 | 419 | print( 420 | "test_loss: {}, test_acc_emo: {}, test_f1_emo: {}, test_acc_sen: {}, test_f1_sen: {}, test_acc_sft: {}, test_f1_sft: {}, total time: {} sec, {}" 421 | .format( 422 | test_loss, 423 | test_acc_emo, 424 | test_f1_emo, 425 | test_acc_sen, 426 | test_f1_sen, 427 | test_acc_sft, 428 | test_f1_sft, 429 | round(time.time() - start_time, 2), 430 | time.strftime("%Y-%m-%d %H:%M:%S", 431 | time.localtime(time.time())), 432 | )) 433 | print("-" * 100) 434 | 435 | if args.classify == "emotion": 436 | if best_f1_emo == None or best_f1_emo < test_f1_emo: 437 | best_f1_emo = test_f1_emo 438 | best_f1_sen = test_f1_sen 439 | best_label_emo, best_pred_emo = test_label_emo, test_pred_emo 440 | best_label_sen, best_pred_sen = test_label_sen, test_pred_sen 441 | 442 | elif args.classify == "sentiment": 443 | if best_f1_sen == None or best_f1_sen < test_f1_sen: 444 | best_f1_emo = test_f1_emo 445 | best_f1_sen = test_f1_sen 446 | best_label_emo, best_pred_emo = test_label_emo, test_pred_emo 447 | best_label_sen, best_pred_sen = test_label_sen, test_pred_sen 448 | 449 | if (epoch + 1) % 10 == 0: 450 | np.set_printoptions(suppress=True) 451 | print( 452 | classification_report(best_label_emo, 453 | best_pred_emo, 454 | digits=4, 455 | zero_division=0)) 456 | print(confusion_matrix(best_label_emo, best_pred_emo)) 457 | print( 458 | 
classification_report(best_label_sen, 459 | best_pred_sen, 460 | digits=4, 461 | zero_division=0)) 462 | print(confusion_matrix(best_label_sen, best_pred_sen)) 463 | print("-" * 100) 464 | 465 | dist.barrier() 466 | 467 | if args.tensorboard: 468 | writer.add_scalar("test: accuracy", test_acc_emo, epoch) 469 | writer.add_scalar("test: fscore", test_f1_emo, epoch) 470 | writer.add_scalar("train: accuracy", train_acc_emo, epoch) 471 | writer.add_scalar("train: fscore", train_f1_emo, epoch) 472 | 473 | if epoch == 1: 474 | allocated_memory = torch.cuda.memory_allocated() 475 | reserved_memory = torch.cuda.memory_reserved() 476 | print(f"Allocated Memory: {allocated_memory / 1024**2:.2f} MB") 477 | print(f"Reserved Memory: {reserved_memory / 1024**2:.2f} MB") 478 | print( 479 | f"All Memory: {(allocated_memory + reserved_memory) / 1024**2:.2f} MB" 480 | ) 481 | 482 | if args.tensorboard: 483 | writer.close() 484 | if local_rank == 0: 485 | print("Test performance..") 486 | print("Acc: {}, F-Score: {}".format(max(all_acc_emo), max(all_f1_emo))) 487 | if not os.path.exists("results/record_{}_{}_{}.pk".format( 488 | today.year, today.month, today.day)): 489 | with open( 490 | "results/record_{}_{}_{}.pk".format( 491 | today.year, today.month, today.day), 492 | "wb", 493 | ) as f: 494 | pk.dump({}, f) 495 | with open( 496 | "results/record_{}_{}_{}.pk".format(today.year, today.month, 497 | today.day), 498 | "rb", 499 | ) as f: 500 | record = pk.load(f) 501 | key_ = name_ 502 | if record.get(key_, False): 503 | record[key_].append(max(all_f1_emo)) 504 | else: 505 | record[key_] = [max(all_f1_emo)] 506 | if record.get(key_ + "record", False): 507 | record[key_ + "record"].append( 508 | classification_report(best_label_emo, 509 | best_pred_emo, 510 | digits=4, 511 | zero_division=0)) 512 | else: 513 | record[key_ + "record"] = [ 514 | classification_report(best_label_emo, 515 | best_pred_emo, 516 | digits=4, 517 | zero_division=0) 518 | ] 519 | with open( 520 | "results/record_{}_{}_{}.pk".format(today.year, today.month, 521 | today.day), 522 | "wb", 523 | ) as f: 524 | pk.dump(record, f) 525 | 526 | print( 527 | classification_report(best_label_emo, 528 | best_pred_emo, 529 | digits=4, 530 | zero_division=0)) 531 | print(confusion_matrix(best_label_emo, best_pred_emo)) 532 | 533 | dist.destroy_process_group() 534 | 535 | 536 | if __name__ == "__main__": 537 | print(args) 538 | print("torch.cuda.is_available():", torch.cuda.is_available()) 539 | print("not args.no_cuda:", not args.no_cuda) 540 | n_gpus = torch.cuda.device_count() 541 | print(f"Use {n_gpus} GPUs") 542 | mp.spawn(fn=main, args=(), nprocs=n_gpus) 543 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np, random 2 | import torch 3 | from sklearn.metrics import f1_score, accuracy_score 4 | import torch.nn.functional as F 5 | from module import build_match_sen_shift_label 6 | from utils import AutomaticWeightedLoss 7 | 8 | seed = 2024 9 | 10 | 11 | def seed_everything(seed=seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | torch.backends.cudnn.benchmark = False 18 | torch.backends.cudnn.deterministic = True 19 | 20 | 21 | def train_or_eval_model( 22 | model, 23 | loss_function_emo, 24 | loss_function_sen, 25 | loss_function_shift, 26 | dataloader, 27 | epoch, 28 | cuda, 29 | modals, 30 | 
optimizer=None, 31 | train=False, 32 | dataset="IEMOCAP", 33 | loss_type="", 34 | lambd=[1.0, 1.0, 1.0], 35 | epochs=100, 36 | classify="", 37 | shift_win=5, 38 | ): 39 | losses, preds_emo, labels_emo = [], [], [] 40 | preds_sft, labels_sft = [], [] 41 | preds_sen, labels_sen = [], [] 42 | vids = [] 43 | initial_feats, extracted_feats = [], [] 44 | 45 | assert not train or optimizer != None 46 | if train: 47 | model.train() 48 | else: 49 | model.eval() 50 | 51 | seed_everything() 52 | for iter, data in enumerate(dataloader): 53 | 54 | if train: 55 | optimizer.zero_grad() 56 | 57 | textf0, textf1, textf2, textf3, visuf, acouf, qmask, umask, label_emotion, label_sentiment = ( 58 | [d.cuda() for d in data[:-1]] if cuda else data[:-1]) 59 | 60 | dia_lengths, label_emotions, label_sentiments = [], [], [] 61 | for j in range(umask.size(1)): 62 | dia_lengths.append((umask[:, j] == 1).nonzero().tolist()[-1][0] + 63 | 1) 64 | label_emotions.append(label_emotion[:dia_lengths[j], j]) 65 | label_sentiments.append(label_sentiment[:dia_lengths[j], j]) 66 | label_emo = torch.cat(label_emotions) 67 | label_sen = torch.cat(label_sentiments) 68 | 69 | logit_emo, logit_sen, logit_sft, extracted_feature = model( 70 | textf0, textf1, textf2, textf3, visuf, acouf, umask, qmask, 71 | dia_lengths) 72 | 73 | prob_emo = F.log_softmax(logit_emo, -1) 74 | loss_emo = loss_function_emo(prob_emo, label_emo) 75 | prob_sen = F.log_softmax(logit_sen, -1) 76 | loss_sen = loss_function_sen(prob_sen, label_sen) 77 | prob_sft = F.log_softmax(logit_sft, -1) 78 | label_sft = build_match_sen_shift_label(shift_win, dia_lengths, 79 | label_sen) 80 | loss_sft = loss_function_shift(prob_sft, label_sft) 81 | 82 | if loss_type == "auto": 83 | awl = AutomaticWeightedLoss(3) 84 | loss = awl(loss_emo, loss_sen, loss_sft) 85 | elif loss_type == "epoch": 86 | loss = (epoch / epochs) * (lambd[0] * loss_emo) + ( 87 | 1 - epoch / epochs) * (lambd[1] * loss_sen + 88 | lambd[2] * loss_sft) 89 | elif loss_type == "emo_sen_sft": 90 | loss = lambd[0] * loss_emo + lambd[1] * loss_sen + lambd[ 91 | 2] * loss_sft 92 | elif loss_type == "emo_sen": 93 | loss = lambd[0] * loss_emo + lambd[1] * loss_sen 94 | elif loss_type == "emo_sft": 95 | loss = lambd[0] * loss_emo + lambd[2] * loss_sft 96 | elif loss_type == "emo": 97 | loss = loss_emo 98 | elif loss_type == "sen_sft": 99 | loss = lambd[1] * loss_sen + lambd[2] * loss_sft 100 | elif loss_type == "sen": 101 | loss = loss_sen 102 | else: 103 | NotImplementedError 104 | 105 | preds_emo.append(torch.argmax(prob_emo, 1).cpu().numpy()) 106 | labels_emo.append(label_emo.cpu().numpy()) 107 | preds_sen.append(torch.argmax(prob_sen, 1).cpu().numpy()) 108 | labels_sen.append(label_sen.cpu().numpy()) 109 | preds_sft.append(torch.argmax(prob_sft, 1).cpu().numpy()) 110 | labels_sft.append(label_sft.cpu().numpy()) 111 | losses.append(loss.item()) 112 | 113 | if train: 114 | loss.backward() 115 | optimizer.step() 116 | 117 | extracted_feats.append(extracted_feature.cpu().detach().numpy()) 118 | 119 | if preds_emo != []: 120 | preds_emo = np.concatenate(preds_emo) 121 | labels_emo = np.concatenate(labels_emo) 122 | preds_sen = np.concatenate(preds_sen) 123 | labels_sen = np.concatenate(labels_sen) 124 | preds_sft = np.concatenate(preds_sft) 125 | labels_sft = np.concatenate(labels_sft) 126 | 127 | extracted_feats = np.concatenate(extracted_feats) 128 | 129 | vids += data[-1] 130 | labels_emo = np.array(labels_emo) 131 | preds_emo = np.array(preds_emo) 132 | labels_sen = np.array(labels_sen) 133 | preds_sen = 
np.array(preds_sen) 134 | labels_sft = np.array(labels_sft) 135 | preds_sft = np.array(preds_sft) 136 | vids = np.array(vids) 137 | 138 | extracted_feats = np.array(extracted_feats) 139 | 140 | avg_loss = round(np.sum(losses) / len(losses), 4) 141 | avg_acc_emo = round(accuracy_score(labels_emo, preds_emo) * 100, 2) 142 | avg_f1_emo = round( 143 | f1_score(labels_emo, preds_emo, average="weighted") * 100, 2) 144 | avg_acc_sen = round(accuracy_score(labels_sen, preds_sen) * 100, 2) 145 | avg_f1_sen = round( 146 | f1_score(labels_sen, preds_sen, average="weighted") * 100, 2) 147 | avg_acc_sft = round(accuracy_score(labels_sft, preds_sft) * 100, 2) 148 | avg_f1_sft = round( 149 | f1_score(labels_sft, preds_sft, average="weighted") * 100, 2) 150 | 151 | return avg_loss, labels_emo, preds_emo, avg_acc_emo, avg_f1_emo, labels_sen, preds_sen, avg_acc_sen, avg_f1_sen, avg_acc_sft, avg_f1_sft, vids, initial_feats, extracted_feats 152 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def batch_to_all_tva(feature_t, feature_v, feature_a, lengths, no_cuda): 6 | 7 | node_feature_t, node_feature_v, node_feature_a = [], [], [] 8 | batch_size = feature_t.size(1) 9 | 10 | for j in range(batch_size): 11 | node_feature_t.append(feature_t[:lengths[j], j, :]) 12 | node_feature_v.append(feature_v[:lengths[j], j, :]) 13 | node_feature_a.append(feature_a[:lengths[j], j, :]) 14 | 15 | node_feature_t = torch.cat(node_feature_t, dim=0) 16 | node_feature_v = torch.cat(node_feature_v, dim=0) 17 | node_feature_a = torch.cat(node_feature_a, dim=0) 18 | 19 | if not no_cuda: 20 | node_feature_t = node_feature_t.cuda() 21 | node_feature_v = node_feature_v.cuda() 22 | node_feature_a = node_feature_a.cuda() 23 | 24 | return node_feature_t, node_feature_v, node_feature_a 25 | 26 | 27 | class AutomaticWeightedLoss(nn.Module): 28 | """automatically weighted multi-task loss 29 | Params: 30 | num: int, the number of loss 31 | x: multi-task loss 32 | Examples: 33 | loss1=1 34 | loss2=2 35 | awl = AutomaticWeightedLoss(2) 36 | loss_sum = awl(loss1, loss2) 37 | """ 38 | 39 | def __init__(self, num=2): 40 | super(AutomaticWeightedLoss, self).__init__() 41 | params = torch.ones(num, requires_grad=True) 42 | self.params = torch.nn.Parameter(params) 43 | 44 | def forward(self, *x): 45 | loss_sum = 0 46 | for i, loss in enumerate(x): 47 | loss_sum += 0.5 / (self.params[i]** 48 | 2) * loss + torch.log(1 + self.params[i]**2) 49 | return loss_sum 50 | --------------------------------------------------------------------------------
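
As a quick sanity check before filling in the dataset paths at the top of run.py, the IEMOCAP pickle consumed by IEMOCAPDataset_BERT should unpack into the twelve objects shown below. This is a minimal sketch rather than part of the repository, and the file name is only a placeholder for the pickle downloaded from the links in the README.

```python
import pickle

# Placeholder path; point this at the downloaded IEMOCAP feature pickle,
# then copy the same path into IEMOCAP_path in run.py.
path = "IEMOCAP_features.pkl"

with open(path, "rb") as f:
    (videoIDs, videoSpeakers, videoLabels,
     videoText0, videoText1, videoText2, videoText3,
     videoAudio, videoVisual, videoSentence,
     trainVid, testVid) = pickle.load(f, encoding="latin1")

# trainVid / testVid hold the dialogue IDs of the train and test splits;
# the remaining dicts are keyed by those IDs, one entry per dialogue.
print(len(trainVid), "training dialogues,", len(testVid), "test dialogues")
print("utterances in one dialogue:", len(videoLabels[list(trainVid)[0]]))
```

The MELD pickle follows the analogous fourteen-object layout unpacked in MELDDataset_BERT.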