├── README.md
├── dataloader.py
├── dataset.py
├── environment.yml
├── main.py
├── model.py
├── model_utils.py
└── trainer.py

/README.md:
--------------------------------------------------------------------------------
The code for the ACL 2023 paper: "DualGATs: Dual Graph Attention Networks for Emotion Recognition in Conversations"


## Requirements

- Python 3.6.13
- PyTorch 1.7.1+cu110


With Anaconda, the environment can be created from the provided `environment.yml`:

```bash
conda env create --file environment.yml
conda activate MMERC
```

The code has been tested on Ubuntu 16.04 using a single GPU.
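As a quick sanity check (a minimal sketch; it only assumes the `MMERC` environment is active), you can verify that the expected PyTorch build and a GPU are visible:

```python
import torch

print(torch.__version__)          # expected: 1.7.1+cu110
print(torch.cuda.is_available())  # expected: True on a CUDA-capable machine
```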

## Run Steps

1. Please download the four ERC datasets (including the pre-processed discourse graphs and RoBERTa utterance features) and put them in the `data` folder. We use the data and code from [here](https://github.com/shizhouxing/DialogueDiscourseParsing) to pre-train a conversational discourse parser, and apply that parser to extract discourse graphs for the four ERC datasets. We use the code from [here](https://github.com/declare-lab/conv-emotion/tree/master/COSMIC) to extract the utterance features.
2. Run our model:

```bash
# For IEMOCAP:
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name IEMOCAP --lr 1e-4 --dropout 0.2 --batch_size 16 --gnn_layers 2
# For MELD:
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name MELD --lr 1e-4 --dropout 0.3 --batch_size 32 --gnn_layers 2
# For EmoryNLP:
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name EmoryNLP --lr 1e-4 --dropout 0.1 --batch_size 32 --gnn_layers 2
# For DailyDialog:
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name DailyDialog --lr 5e-5 --dropout 0.4 --batch_size 64 --gnn_layers 3
```

## Citation

```
@inproceedings{zhang2023dualgats,
  title={DualGATs: Dual Graph Attention Networks for Emotion Recognition in Conversations},
  author={Zhang, Duzhen and Chen, Feilong and Chen, Xiuyi},
  booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
  pages={7395--7408},
  year={2023}
}
```

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from dataset import *
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader


def get_train_valid_sampler(trainset):
    size = len(trainset)     # number of dialogues n
    idx = list(range(size))  # [0, 1, 2, ..., n-1]
    # SubsetRandomSampler draws from the given index list without replacement,
    # which amounts to sampling a random permutation of the whole set.
    return SubsetRandomSampler(idx)


def get_data_loaders(dataset_name='IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args=None):

    print('building datasets..')
    trainset = MyDataset(dataset_name, 'train', args)
    devset = MyDataset(dataset_name, 'dev', args)

    train_sampler = get_train_valid_sampler(trainset)
    valid_sampler = get_train_valid_sampler(devset)

    train_loader = DataLoader(trainset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              collate_fn=trainset.collate_fn,
                              num_workers=num_workers,
                              pin_memory=pin_memory)  # equivalent to shuffle=True without a sampler

    valid_loader = DataLoader(devset,
                              batch_size=batch_size,
                              sampler=valid_sampler,
                              collate_fn=devset.collate_fn,
                              num_workers=num_workers,
                              pin_memory=pin_memory)

    testset = MyDataset(dataset_name, 'test', args)
    test_loader = DataLoader(testset,
                             batch_size=batch_size,
                             collate_fn=testset.collate_fn,
                             num_workers=num_workers,
                             pin_memory=pin_memory)

    return train_loader, valid_loader, test_loader
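
# Minimal usage sketch (illustrative; assumes the pickled feature files
# described in the README are present under ./data):
if __name__ == '__main__':
    train_loader, valid_loader, test_loader = get_data_loaders('IEMOCAP', batch_size=16)
    features, labels, semantic_adj, structure_adj, lengths, speakers, utterances, ids = next(iter(train_loader))
    print(features.shape)       # (B, N, 1024) padded RoBERTa utterance features
    print(semantic_adj.shape)   # (B, N, N) speaker-relation codes in {0..5}
    print(structure_adj.shape)  # (B, N, N) discourse-relation codes in {0..17}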
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import random

import pickle


def read_pickle(filename):
    try:
        with open(filename, 'rb') as f:
            data = pickle.load(f)
    except Exception:
        with open(filename, 'rb') as f:
            data = pickle.load(f, encoding='latin1')
    return data


class MyDataset(Dataset):

    def __init__(self, dataset_name='IEMOCAP', split='train', args=None):
        self.args = args
        self.dataset_name = dataset_name

        if dataset_name == 'IEMOCAP':
            self.videoSpeakers, self.videoLabels, roberta_feature, \
            self.links, self.relations, self.videoSentence, self.trainVid, self.testVid, self.validVid = read_pickle('./data/IEMOCAP_Features.pkl')

            self.utterance_feature = roberta_feature

        elif dataset_name == 'MELD':
            self.videoSpeakers, self.videoLabels, roberta_feature, \
            self.links, self.relations, self.videoSentence, self.trainVid, self.testVid, self.validVid = read_pickle('./data/MELD_Features.pkl')

            self.utterance_feature = roberta_feature

        elif dataset_name == 'DailyDialog':
            self.videoSpeakers, self.videoLabels, roberta_feature, \
            self.links, self.relations, self.videoSentence, self.trainVid, self.testVid, self.validVid = read_pickle(
                './data/DailyDialog_Features.pkl')

            self.utterance_feature = roberta_feature

        elif dataset_name == 'EmoryNLP':
            self.videoSpeakers, self.videoLabels, roberta_feature, \
            self.links, self.relations, self.videoSentence, self.trainVid, self.testVid, self.validVid = read_pickle(
                './data/EmoryNLP_Features.pkl')

            self.utterance_feature = roberta_feature

        self.data = self.read(split)
        print(split + ' dialogue num:')
        print(len(self.data))  # number of dialogues

        self.len = len(self.data)

    def read(self, split):

        # process dialogues
        if split == 'train':
            dialog_ids = self.trainVid
        elif split == 'dev':
            dialog_ids = self.validVid
        elif split == 'test':
            dialog_ids = self.testVid

        dialogs = []
        for dialog_id in dialog_ids:
            utterances = self.videoSentence[dialog_id]
            labels = self.videoLabels[dialog_id]
            if self.dataset_name == 'IEMOCAP':
                speakers = self.videoSpeakers[dialog_id]
            elif self.dataset_name == 'MELD':
                speakers = [speaker.index(1) for speaker in self.videoSpeakers[dialog_id]]
            elif self.dataset_name == 'DailyDialog':
                speakers = [int(speaker) for speaker in self.videoSpeakers[dialog_id]]
            elif self.dataset_name in ('EmoryNLP', 'EmoryNLP_small', 'EmoryNLP_big'):
                speakers = self.videoSpeakers[dialog_id]

            utterance_features = [item.tolist() for item in self.utterance_feature[dialog_id]]
            utterance_links = self.links[dialog_id]
            utterance_relations = self.relations[dialog_id]

            dialogs.append({
                'id': dialog_id,
                'utterances': utterances,
                'labels': labels,
                'speakers': speakers,
                'utterance_features': utterance_features,
                'utterance_links': utterance_links,
                'utterance_relations': utterance_relations
            })

        random.shuffle(dialogs)  # shuffle the dialogues
        return dialogs

    def __getitem__(self, index):  # fetch one sample (one dialogue)
        '''
        :param index:
        :return:
            feature, label, speaker, link, relation, length, text, id
        '''
        return torch.FloatTensor(self.data[index]['utterance_features']), \
               torch.LongTensor(self.data[index]['labels']), \
               self.data[index]['speakers'], \
               self.data[index]['utterance_links'], \
               self.data[index]['utterance_relations'], \
               len(self.data[index]['labels']), \
               self.data[index]['utterances'], \
               self.data[index]['id']

    def __len__(self):
        return self.len
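
    # Worked example (illustrative): for speakers ['A', 'B', 'A'] and
    # max_dialog_len = 4, get_semantic_adj below yields
    #   [[1, 4, 2, 0],
    #    [5, 1, 4, 0],
    #    [3, 5, 1, 0],
    #    [0, 0, 0, 0]]
    # with 1 = self, 2 = self-future, 3 = self-past, 4 = inter-future,
    # 5 = inter-past, and 0 = padding.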
    def get_semantic_adj(self, speakers, max_dialog_len):

        semantic_adj = []
        for speaker in speakers:  # iterate over each dialogue's speaker list (not deduplicated)
            s = torch.zeros(max_dialog_len, max_dialog_len, dtype=torch.long)  # (N, N); 0 marks padding, i.e. no semantic relation
            for i in range(len(speaker)):  # check whether utterance i and every other utterance j share a speaker
                for j in range(len(speaker)):
                    if speaker[i] == speaker[j]:
                        if i == j:
                            s[i, j] = 1  # diagonal: self
                        elif i < j:
                            s[i, j] = 2  # self-future
                        else:
                            s[i, j] = 3  # self-past
                    else:
                        if i < j:
                            s[i, j] = 4  # inter-future
                        elif i > j:
                            s[i, j] = 5  # inter-past

            semantic_adj.append(s)

        return torch.stack(semantic_adj)

    def get_structure_adj(self, links, relations, lengths, max_dialog_len):
        '''
        map_relations = {'Comment': 0, 'Contrast': 1, 'Correction': 2, 'Question-answer_pair': 3, 'QAP': 3, 'Parallel': 4, 'Acknowledgement': 5,
                         'Elaboration': 6, 'Clarification_question': 7, 'Conditional': 8, 'Continuation': 9, 'Result': 10, 'Explanation': 11,
                         'Q-Elab': 12, 'Alternation': 13, 'Narration': 14, 'Background': 15}
        '''
        structure_adj = []

        for link, relation, length in zip(links, relations, lengths):
            s = torch.zeros(max_dialog_len, max_dialog_len, dtype=torch.long)  # (N, N); 0 marks padding or the absence of a relation
            assert len(link) == len(relation)

            for index, (i, j) in enumerate(link):
                s[i, j] = relation[index] + 1
                s[j, i] = s[i, j]  # make the matrix symmetric

            for i in range(length):  # fill the diagonal
                s[i, i] = 17

            structure_adj.append(s)

        return torch.stack(structure_adj)
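
    # Worked example (illustrative): for links [(0, 1), (1, 2)], relations
    # [3, 6] and a dialogue of length 3, get_structure_adj sets
    # s[0, 1] = s[1, 0] = 4 (Question-answer_pair + 1),
    # s[1, 2] = s[2, 1] = 7 (Elaboration + 1), and s[i, i] = 17 on the
    # diagonal for the real (non-padded) utterances.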
    def collate_fn(self, data):  # `data` is one batch of dialogues; fetch the samples and pad them
        '''
        :param data:
            utterance_features, labels, speakers, utterance_links, utterance_relations, length, texts, id
        :return:
            utterance_features: (B, N, D) padded
            labels: (B, N) padded with -1
            semantic_adj: (B, N, N) speaker-relation codes
            structure_adj: (B, N, N) discourse-relation codes
            lengths: (B, )
            speakers: (B, N) padded with -1
            utterances, ids: not tensors
        '''
        max_dialog_len = max([d[5] for d in data])  # max dialogue length N in the batch

        utterance_features = pad_sequence([d[0] for d in data], batch_first=True)  # (B, N, D)

        labels = pad_sequence([d[1] for d in data], batch_first=True, padding_value=-1)  # (B, N); labels padded with -1

        semantic_adj = self.get_semantic_adj([d[2] for d in data], max_dialog_len)

        structure_adj = self.get_structure_adj([d[3] for d in data], [d[4] for d in data], [d[5] for d in data], max_dialog_len)

        lengths = torch.LongTensor([d[5] for d in data])  # length of each dialogue in the batch

        speakers = pad_sequence([torch.LongTensor(d[2]) for d in data], batch_first=True, padding_value=-1)  # (B, N); speakers padded with -1
        utterances = [d[6] for d in data]  # utterance texts of each dialogue in the batch
        ids = [d[7] for d in data]         # id of each dialogue in the batch

        return utterance_features, labels, semantic_adj, structure_adj, lengths, speakers, utterances, ids
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: MMERC
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - ca-certificates=2021.7.5=h06a4308_1
  - certifi=2021.5.30=py36h06a4308_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgomp=9.3.0=h5101ec6_17
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - ncurses=6.2=he6710b0_1
  - openssl=1.1.1k=h27cfd23_0
  - pip=21.2.2=py36h06a4308_0
  - python=3.6.13=h12debd9_1
  - readline=8.1=h27cfd23_0
  - setuptools=52.0.0=py36h06a4308_0
  - sqlite=3.36.0=hc218d9a_0
  - tk=8.6.10=hbc83047_0
  - wheel=0.36.2=pyhd3eb1b0_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - pip:
    - absl-py==0.13.0
    - anykeystore==0.2
    - apex==0.1
    - astunparse==1.6.3
    - boto3==1.18.14
    - botocore==1.21.14
    - cached-property==1.5.2
    - cachetools==4.2.2
    - cffi==1.15.0
    - chardet==3.0.4
    - click==8.0.3
    - cloudpickle==2.0.0
    - cryptacular==1.6.2
    - cycler==0.10.0
    - cython==0.29.24
    - dataclasses==0.8
    - decorator==4.4.2
    - defusedxml==0.7.1
    - easydict==1.9
    - fasteners==0.17.3
    - filelock==3.4.0
    - flatbuffers==1.12
    - gast==0.4.0
    - glfw==2.5.3
    - google-auth==1.34.0
    - google-auth-oauthlib==0.4.5
    - google-pasta==0.2.0
    - googledrivedownloader==0.4
    - greenlet==1.1.2
    - grpcio==1.34.1
    - gym==0.21.0
    - h5py==3.1.0
    - huggingface-hub==0.2.0
    - hupper==1.10.3
    - idna==2.8
    - imageio==2.9.0
    - importlib-metadata==4.8.3
    - inplace-abn==1.1.1.dev6+gd2728c8
    - isodate==0.6.1
    - jinja2==3.0.3
    - jmespath==0.10.0
    - joblib==1.1.0
    - jsonlines==2.0.0
    - keras-nightly==2.5.0.dev2021032900
    - keras-preprocessing==1.1.2
    - kiwisolver==1.3.1
    - llvmlite==0.36.0
    - lmdb==1.2.1
    - markdown==3.3.4
    - markupsafe==2.0.1
    - matplotlib==3.3.4
    - mujoco-py==2.1.2.14
    - networkx==2.5.1
    - nltk==3.6.7
    - numba==0.53.1
    - numpy==1.19.5
    - oauthlib==3.1.1
    - opencv-python==4.5.3.56
    - opt-einsum==3.3.0
    - packaging==21.3
    - pandas==1.1.5
    - pastedeploy==2.1.1
    - pbkdf2==1.3
    - pillow==8.3.1
    - plaster==1.0
    - plaster-pastedeploy==0.7
    - plyfile==0.7.4
    - protobuf==3.10.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pycocotools==2.0.2
    - pycparser==2.21
    - pyparsing==2.4.7
    - pyramid==2.0
    - pyramid-mailer==0.15.1
    - python-dateutil==2.8.2
    - python3-openid==3.2.0
    - pytz==2021.1
    - pywavelets==1.1.1
    - pyyaml==5.4.1
    - rdflib==5.0.0
    - regex==2022.6.2
    - repoze-sendmail==4.4.1
    - requests==2.22.0
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - s3transfer==0.5.0
    - sacremoses==0.0.46
    - scikit-image==0.17.2
    - scikit-learn==0.24.2
    - scipy==1.5.4
    - setuptools-scm==6.4.2
    - six==1.15.0
    - sqlalchemy==1.4.39
    - tensorboard==2.5.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorboardx==2.4
    - tensorflow==2.5.0
    - tensorflow-estimator==2.5.0
    - termcolor==1.1.0
    - threadpoolctl==3.0.0
    - tifffile==2020.9.3
    - tokenizers==0.10.3
    - tomli==1.2.3
    - torch==1.7.1+cu110
    - torch-geometric==2.0.3
    - torch-scatter==2.0.5
    - torch-sparse==0.6.8
    - torchaudio==0.7.2
    - torchvision==0.8.2+cu110
    - tqdm==4.62.0
    - transaction==3.0.1
    - transformers==4.12.5
    - translationstring==1.4
    - typing-extensions==3.7.4.3
    - urllib3==1.25.11
    - velruse==1.1.1
    - venusian==3.0.0
    - webob==1.8.7
    - werkzeug==2.0.1
    - wrapt==1.12.1
    - wtforms==3.0.0
    - wtforms-recaptcha==0.3.2
    - yacs==0.1.8
    - zipp==3.5.0
    - zope-deprecation==4.4.0
    - zope-interface==5.4.0
    - zope-sqlalchemy==1.6
prefix: /home/cfl/anaconda3/envs/MMERC
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np, argparse, time, random

from model import *

from trainer import train_or_eval_model

from dataloader import get_data_loaders
from transformers import AdamW


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def str2bool(v):
    """ Usage:
        parser.add_argument('--pretrained', type=str2bool, nargs='?', const=True,
                            dest='pretrained', help='Whether to use pretrained models.')
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')


if __name__ == '__main__':

    path = './saved_models/'  # path for saving logs and models

    parser = argparse.ArgumentParser()

    parser.add_argument('--hidden_dim', type=int, default=300)
    parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')

    parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')
    parser.add_argument('--dataset_name', default='MELD', type=str, help='dataset name: IEMOCAP, MELD, DailyDialog or EmoryNLP')
    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
    parser.add_argument('--epochs', type=int, default=60, metavar='E', help='number of epochs')

    parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')

    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR', help='learning rate')
    parser.add_argument('--dropout', type=float, default=0.4, metavar='dropout', help='dropout rate')
    parser.add_argument('--batch_size', type=int, default=64, metavar='BS', help='batch size')

    parser.add_argument('--seed', type=int, default=100, help='random seed')

    args = parser.parse_args()

    print(args)

    # fix the random seed
    seed_everything(args.seed)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", args.device)

    device = args.device
    n_epochs = args.epochs
    batch_size = args.batch_size

    train_loader, valid_loader, test_loader = get_data_loaders(
        dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args=args)

    if 'IEMOCAP' in args.dataset_name:
        n_classes = 6
    else:
        n_classes = 7
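    # IEMOCAP is annotated with 6 emotion classes; MELD, EmoryNLP and
    # DailyDialog each use 7.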

    print('building model..')

    model = DualGATs(args, n_classes)

    if torch.cuda.device_count() > 1:
        print('Multi-GPU...........')
        model = nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    model.to(device)

    loss_function = nn.CrossEntropyLoss(ignore_index=-1)  # ignore the padding label -1

    optimizer = AdamW(model.parameters(), lr=args.lr)

    best_fscore, best_acc, best_loss, best_label, best_pred, best_mask = None, None, None, None, None, None
    all_fscore, all_acc, all_loss = [], [], []
    best_acc = 0.
    best_fscore = 0.

    best_model = None
    for e in range(n_epochs):  # iterate over the epochs
        start_time = time.time()

        train_loss, train_acc, _, _, train_fscore = train_or_eval_model(model, loss_function,
                                                                        train_loader, device,
                                                                        args, optimizer, True)
        valid_loss, valid_acc, _, _, valid_fscore = train_or_eval_model(model, loss_function,
                                                                        valid_loader, device, args)
        test_loss, test_acc, test_label, test_pred, test_fscore = train_or_eval_model(model, loss_function,
                                                                                      test_loader, device, args)

        all_fscore.append([valid_fscore, test_fscore])

        print(
            'Epoch: {}, train_loss: {}, train_acc: {}, train_fscore: {}, valid_loss: {}, valid_acc: {}, valid_fscore: {}, test_loss: {}, test_acc: {}, test_fscore: {}, time: {} sec'. \
            format(e + 1, train_loss, train_acc, train_fscore, valid_loss, valid_acc, valid_fscore, test_loss,
                   test_acc,
                   test_fscore, round(time.time() - start_time, 2)))

    print('finish training!')

    all_fscore = sorted(all_fscore, key=lambda x: (x[0], x[1]), reverse=True)  # sort primarily by validation F1

    print('Best val F-Score:{}'.format(all_fscore[0][0]))  # best validation performance
    print('Best test F-Score based on validation:{}'.format(all_fscore[0][1]))  # test performance at the epoch with the best validation score
    print('Best test F-Score based on test:{}'.format(max([f[1] for f in all_fscore])))  # best test performance overall

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
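# DualGATs couples two graph-attention branches: SpkGAT propagates over the
# speaker-relation graph (the 6 relation codes from get_semantic_adj) and
# DisGAT over the discourse graph (the 18 relation codes from
# get_structure_adj). After each layer a BiAffine interaction exchanges
# information between the branches, while DiffLoss keeps their
# representations complementary.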
from model_utils import RGAT, DiffLoss
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualGATs(nn.Module):

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        SpkGAT = []
        DisGAT = []
        for _ in range(args.gnn_layers):
            SpkGAT.append(RGAT(args, args.hidden_dim, args.hidden_dim, dropout=args.dropout, num_relation=6))
            DisGAT.append(RGAT(args, args.hidden_dim, args.hidden_dim, dropout=args.dropout, num_relation=18))

        self.SpkGAT = nn.ModuleList(SpkGAT)
        self.DisGAT = nn.ModuleList(DisGAT)

        self.affine1 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        self.diff_loss = DiffLoss(args)
        self.beta = 0.3

        in_dim = args.hidden_dim * 2 + args.emb_dim
        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

        self.drop = nn.Dropout(args.dropout)

    def forward(self, utterance_features, semantic_adj, structure_adj):
        '''
        :param utterance_features: (B, N, emb_dim)
        :param semantic_adj, structure_adj: (B, N, N)
        :return: logits (B, N, num_class) and the scaled diff loss
        '''
        batch_size = utterance_features.size(0)
        H0 = F.relu(self.fc1(utterance_features))  # (B, N, hidden_dim)
        H = [H0]
        diff_loss = 0
        for l in range(self.args.gnn_layers):
            if l == 0:
                H1_semantic = self.SpkGAT[l](H[l], semantic_adj)
                H1_structure = self.DisGAT[l](H[l], structure_adj)
            else:
                H1_semantic = self.SpkGAT[l](H[2 * l - 1], semantic_adj)
                H1_structure = self.DisGAT[l](H[2 * l], structure_adj)

            diff_loss = diff_loss + self.diff_loss(H1_semantic, H1_structure)
            # BiAffine
            A1 = F.softmax(torch.bmm(torch.matmul(H1_semantic, self.affine1), torch.transpose(H1_structure, 1, 2)), dim=-1)
            A2 = F.softmax(torch.bmm(torch.matmul(H1_structure, self.affine2), torch.transpose(H1_semantic, 1, 2)), dim=-1)

            H1_semantic_new = torch.bmm(A1, H1_structure)
            H1_structure_new = torch.bmm(A2, H1_semantic)

            H1_semantic_out = self.drop(H1_semantic_new) if l < self.args.gnn_layers - 1 else H1_semantic_new
            H1_structure_out = self.drop(H1_structure_new) if l < self.args.gnn_layers - 1 else H1_structure_new

            H.append(H1_semantic_out)
            H.append(H1_structure_out)

        H.append(utterance_features)

        # concatenate the final semantic view, structure view and raw features
        final_feature = torch.cat([H[-3], H[-2], H[-1]], dim=-1)  # (B, N, hidden_dim*2 + emb_dim)

        logits = self.out_mlp(final_feature)  # (B, N, num_class)
        return logits, self.beta * diff_loss
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm


class DiffLoss(nn.Module):
    # Orthogonality constraint between the two branches: zero-center and
    # L2-normalize each view, then penalize the squared correlation between
    # them. (The exact normalization in the original release may differ.)

    def __init__(self, args):
        super(DiffLoss, self).__init__()
        self.args = args

    def forward(self, input1, input2):
        # flatten to (B*N, D): each utterance representation is one sample
        input1 = input1.contiguous().view(-1, input1.size(-1))
        input2 = input2.contiguous().view(-1, input2.size(-1))

        input1_mean = torch.mean(input1, dim=0, keepdim=True)
        input2_mean = torch.mean(input2, dim=0, keepdim=True)
        input1 = input1 - input1_mean
        input2 = input2 - input2_mean

        input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True).detach()
        input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6)

        input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True).detach()
        input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6)

        return torch.mean((input1_l2.t().mm(input2_l2)).pow(2))


class GraphAttentionLayer(nn.Module):
    """
    Dense relation-aware graph attention layer, adapted from
    https://github.com/Diego999/pyGAT.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True, relation=False, num_relation=-1):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.relation = relation

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        if relation:
            # relation ids (the entries of the adjacency matrix) are embedded
            # and concatenated to the pairwise node features before scoring
            self.relation_embedding = nn.Embedding(num_relation, out_features)
            self.a = nn.Parameter(torch.empty(size=(3 * out_features, 1)))
        else:
            self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.layer_norm = LayerNorm(out_features)

    def forward(self, h, adj):
        # h: (B, N, in_features); adj: (B, N, N) relation codes, 0 = padding / no edge
        Wh = torch.matmul(h, self.W)  # (B, N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)  # (B, N, N, 2 * out_features)
        if self.relation:
            relation_emb = self.relation_embedding(adj)  # (B, N, N, out_features)
            a_input = torch.cat([a_input, relation_emb], dim=-1)  # (B, N, N, 3 * out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(-1))  # (B, N, N)

        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # keep e where adj is non-zero; padded or unconnected positions get a very large negative value
        attention = F.softmax(attention, dim=2)  # (B, N, N)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)  # (B, N, out_features)

        h_prime = self.layer_norm(h_prime)

        if self.concat:
            return F.gelu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[1]  # N
        B = Wh.size()[0]  # B
        # Below, two matrices are created that contain embeddings in their rows in different orders.
        # (e stands for embedding)
        # These are the rows of the first matrix (Wh_repeated_in_chunks):
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        #
        # These are the rows of the second matrix (Wh_repeated_alternating):
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
        # '----------------------------------------------------' -> N times
        #
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1)
        Wh_repeated_alternating = Wh.repeat(1, N, 1)
        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (B, N * N, out_features)

        # The all_combinations_matrix, created below, will look like this (|| denotes concatenation):
        # e1 || e1
        # e1 || e2
        # e1 || e3
        # ...
        # e1 || eN
        # e2 || e1
        # e2 || e2
        # e2 || e3
        # ...
        # e2 || eN
        # ...
        # eN || e1
        # eN || e2
        # eN || e3
        # ...
        # eN || eN

        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2)  # (B, N*N, 2*D_out)
        # all_combinations_matrix.shape == (B, N * N, 2 * out_features)

        return all_combinations_matrix.view(B, N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
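
# Shape walk-through (illustrative): with B = 2 dialogues, N = 4 utterances
# and out_features = 300, forward() maps h (2, 4, in_features) to
# Wh (2, 4, 300), builds pairwise features (2, 4, 4, 600) (900 with relation
# embeddings), scores them into e (2, 4, 4), masks positions where adj == 0,
# and returns h_prime (2, 4, 300).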

class RGAT(nn.Module):
    def __init__(self, args, nfeat, nhid, dropout=0.2, alpha=0.2, nheads=2, num_relation=-1):
        """Dense version of GAT."""
        super(RGAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True, relation=True, num_relation=num_relation) for _ in range(nheads)]  # multi-head attention

        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        self.out_att = GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=False, relation=True, num_relation=num_relation)  # project back to the original hidden dimension

        self.fc = nn.Linear(nhid, nhid)
        self.layer_norm = LayerNorm(nhid)

    def forward(self, x, adj):
        residual = x
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=-1)  # (B, N, nheads * nhid)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.gelu(self.out_att(x, adj))  # (B, N, nhid)
        x = self.fc(x)  # (B, N, nhid)
        x = x + residual
        x = self.layer_norm(x)
        return x
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

from sklearn.metrics import f1_score, accuracy_score

import json


def train_or_eval_model(model, loss_function, dataloader, device, args, optimizer=None, train=False):
    losses, preds, labels = [], [], []

    assert not train or optimizer is not None
    if train:  # training mode
        model.train()
    else:  # evaluation mode
        model.eval()

    for data in dataloader:  # iterate over the batches
        if train:
            optimizer.zero_grad()

        utterance_features, label, semantic_adj, structure_adj, lengths, speakers, utterances, ids = data

        utterance_features = utterance_features.to(device)
        label = label.to(device)  # (B, N)
        semantic_adj = semantic_adj.to(device)
        structure_adj = structure_adj.to(device)

        log_prob, diff_loss = model(utterance_features, semantic_adj, structure_adj)  # (B, N, C)

        loss = loss_function(log_prob.permute(0, 2, 1), label)

        loss = loss + diff_loss
        label = label.cpu().numpy().tolist()
        pred = torch.argmax(log_prob, dim=2).cpu().numpy().tolist()  # (B, N)
        preds += pred
        labels += label
        losses.append(loss.item())

        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

    if preds != []:
        new_preds = []
        new_labels = []
        for i, label in enumerate(labels):  # iterate over the dialogues
            for j, l in enumerate(label):  # iterate over the utterances
                if l != -1:  # drop padded labels (IEMOCAP also pads utterances within dialogues)
                    new_labels.append(l)
                    new_preds.append(preds[i][j])
    else:
        return float('nan'), float('nan'), [], [], float('nan')
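
    # Example (illustrative): with labels [[2, 0, -1], [5, -1, -1]] and
    # preds [[2, 1, 3], [4, 0, 0]], the -1 padding positions are dropped, so
    # new_labels == [2, 0, 5] and new_preds == [2, 1, 4].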
    avg_loss = round(np.sum(losses) / len(losses), 4)
    avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)

    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP', 'EmoryNLP_small', 'EmoryNLP_big']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
    elif args.dataset_name == 'DailyDialog':
        avg_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=[0, 2, 3, 4, 5, 6]) * 100, 2)  # label 1 is neutral and is excluded

    return avg_loss, avg_accuracy, labels, preds, avg_fscore
--------------------------------------------------------------------------------