├── README.md
├── dataloader.py
├── dataset.py
├── environment.yml
├── main.py
├── model.py
├── model_utils.py
└── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | The code for the ACL 2023 paper: "DualGATs: Dual Graph Attention Networks for Emotion Recognition in Conversations"
2 |
3 |
4 | ## Requirements
5 |
6 | - Python 3.6.13
7 | - PyTorch 1.7.1+cu110
8 |
9 |
10 | With Anaconda, we can create the environment with the provided `environment.yml`:
11 |
12 | ```bash
13 | conda env create --file environment.yml
14 | conda activate MMERC
15 | ```
16 |
17 | The code has been tested on Ubuntu 16.04 using a single GPU.
18 |
19 |
20 | ## Run Steps
21 |
22 | 1. Download the four ERC datasets (including the pre-processed discourse graphs and RoBERTa utterance features) and put them in the `data` folder; the expected layout is shown below the run commands. We use the data and code from [here](https://github.com/shizhouxing/DialogueDiscourseParsing) to pre-train a conversational discourse parser and apply it to extract the discourse graphs for the four ERC datasets, and the code from [here](https://github.com/declare-lab/conv-emotion/tree/master/COSMIC) to extract the utterance features.
23 | 2. Run our model:
24 |
25 | ```bash
26 | # For IEMOCAP:
27 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name IEMOCAP --lr 1e-4 --dropout 0.2 --batch_size 16 --gnn_layers 2
28 | # For MELD:
29 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name MELD --lr 1e-4 --dropout 0.3 --batch_size 32 --gnn_layers 2
30 | # For EmoryNLP:
31 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name EmoryNLP --lr 1e-4 --dropout 0.1 --batch_size 32 --gnn_layers 2
32 | # For DailyDialog:
33 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_name DailyDialog --lr 5e-5 --dropout 0.4 --batch_size 64 --gnn_layers 3
34 | ```
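35 |
36 | `dataset.py` reads the features from fixed paths, so after step 1 the `data` folder should look like this (file names taken from the `read_pickle` calls in `dataset.py`):
37 |
38 | ```
39 | data/
40 | ├── IEMOCAP_Features.pkl
41 | ├── MELD_Features.pkl
42 | ├── EmoryNLP_Features.pkl
43 | └── DailyDialog_Features.pkl
44 | ```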
35 |
36 | ## Citation
37 |
38 | ```
39 | @inproceedings{zhang2023dualgats,
40 | title={DualGATs: Dual Graph Attention Networks for Emotion Recognition in Conversations},
41 | author={Zhang, Duzhen and Chen, Feilong and Chen, Xiuyi},
42 | booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
43 | pages={7395--7408},
44 | year={2023}
45 | }
46 | ```
47 |
48 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from dataset import *
2 | from torch.utils.data.sampler import SubsetRandomSampler
3 | from torch.utils.data import DataLoader
4 |
5 |
6 | def get_train_valid_sampler(trainset):
7 |     size = len(trainset)  # number of dialogues n
8 |     idx = list(range(size))  # [0, 1, 2, ..., n-1]
9 |     return SubsetRandomSampler(idx)  # samples the given indices without replacement, i.e. draws a random permutation of all dialogues
10 |
11 |
12 |
13 | def get_data_loaders(dataset_name = 'IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args = None):
14 |
15 | print('building datasets..')
16 | trainset = MyDataset(dataset_name, 'train', args)
17 | devset = MyDataset(dataset_name, 'dev', args)
18 |
19 | train_sampler = get_train_valid_sampler(trainset)
20 | valid_sampler = get_train_valid_sampler(devset)
21 |
22 | train_loader = DataLoader(trainset,
23 | batch_size=batch_size,
24 | sampler=train_sampler,
25 | collate_fn=trainset.collate_fn,
26 | num_workers=num_workers,
27 |                               pin_memory=pin_memory)  # equivalent to shuffle=True without a custom sampler
28 |
29 | valid_loader = DataLoader(devset,
30 | batch_size=batch_size,
31 | sampler=valid_sampler,
32 | collate_fn=devset.collate_fn,
33 | num_workers=num_workers,
34 | pin_memory=pin_memory)
35 |
36 | testset = MyDataset(dataset_name, 'test',args)
37 | test_loader = DataLoader(testset,
38 | batch_size=batch_size,
39 | collate_fn=testset.collate_fn,
40 | num_workers=num_workers,
41 | pin_memory=pin_memory)
42 |
43 | return train_loader, valid_loader, test_loader
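44 |
45 |
46 | # Each batch yielded by these loaders is the tuple built in MyDataset.collate_fn:
47 | # (utterance_features, labels, semantic_adj, structure_adj, lengths, speakers, utterances, ids),
48 | # with tensor shapes (B, N, D), (B, N), (B, N, N), (B, N, N), (B,), (B, N).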
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.utils.data import Dataset
4 | from torch.nn.utils.rnn import pad_sequence
5 | import random
6 |
7 | import pickle
8 | def read_pickle(filename):
9 | try:
10 | with open(filename,'rb') as f:
11 | data = pickle.load(f)
12 |     except UnicodeDecodeError:  # pickles written by Python 2 need a latin1 decoding
13 | with open(filename,'rb') as f:
14 | data = pickle.load(f,encoding='latin1')
15 | return data
16 |
17 | class MyDataset(Dataset):
18 |
19 | def __init__(self, dataset_name = 'IEMOCAP', split = 'train', args = None):
20 | self.args = args
21 | self.dataset_name = dataset_name
22 |
23 |
24 |         if dataset_name in ('IEMOCAP', 'MELD', 'DailyDialog', 'EmoryNLP'):
25 |             self.videoSpeakers, self.videoLabels, roberta_feature, \
26 |             self.links, self.relations, self.videoSentence, self.trainVid, self.testVid, self.validVid = read_pickle(
27 |                 './data/{}_Features.pkl'.format(dataset_name))
28 |             self.utterance_feature = roberta_feature
29 |         else:
30 |             raise ValueError('Unknown dataset: {}'.format(dataset_name))
49 |
50 |
51 |
52 | self.data = self.read(split)
53 |         print(split + ' dialogue num:')
54 |         print(len(self.data))  # number of dialogues
55 |
56 | self.len = len(self.data)
57 |
58 | def read(self, split):
59 |
60 | # process dialogue
61 | if split=='train':
62 | dialog_ids = self.trainVid
63 | elif split=='dev':
64 | dialog_ids = self.validVid
65 | elif split=='test':
66 | dialog_ids = self.testVid
67 |
68 | dialogs = []
69 | for dialog_id in dialog_ids:
70 | utterances = self.videoSentence[dialog_id]
71 | labels = self.videoLabels[dialog_id]
72 | if self.dataset_name == 'IEMOCAP':
73 | speakers = self.videoSpeakers[dialog_id]
74 | elif self.dataset_name == 'MELD':
75 | speakers = [speaker.index(1) for speaker in self.videoSpeakers[dialog_id]]
76 | elif self.dataset_name == 'DailyDialog':
77 | speakers = [int(speaker) for speaker in self.videoSpeakers[dialog_id]]
78 |             elif self.dataset_name == 'EmoryNLP':
79 | speakers = self.videoSpeakers[dialog_id]
80 |
81 |
82 | utterance_features = [item.tolist() for item in self.utterance_feature[dialog_id]]
83 | utterance_links = self.links[dialog_id]
84 | utterance_relations = self.relations[dialog_id]
85 |
86 | dialogs.append({
87 | 'id':dialog_id,
88 | 'utterances': utterances,
89 | 'labels': labels,
90 | 'speakers': speakers,
91 | 'utterance_features': utterance_features,
92 | 'utterance_links': utterance_links,
93 | 'utterance_relations': utterance_relations
94 | })
95 |
96 |
97 |         random.shuffle(dialogs)  # shuffle the dialogues
98 | return dialogs
99 |
100 |     def __getitem__(self, index):  # fetch one sample, i.e. one dialogue
101 | '''
102 | :param index:
103 | :return:
104 | feature,
105 | label
106 | speaker
107 | length
108 | text
109 | '''
110 | return torch.FloatTensor(self.data[index]['utterance_features']), \
111 | torch.LongTensor(self.data[index]['labels']),\
112 | self.data[index]['speakers'], \
113 | self.data[index]['utterance_links'],\
114 | self.data[index]['utterance_relations'],\
115 | len(self.data[index]['labels']), \
116 | self.data[index]['utterances'],\
117 | self.data[index]['id']
118 |
119 | def __len__(self):
120 | return self.len
121 |
122 | def get_semantic_adj(self, speakers, max_dialog_len):
123 |
124 | semantic_adj = []
125 |         for speaker in speakers:  # speaker list of each dialogue (not deduplicated)
126 |             s = torch.zeros(max_dialog_len, max_dialog_len, dtype=torch.long)  # (N, N); 0 marks padding, i.e. no semantic relation
127 |             for i in range(len(speaker)):  # is utterance i's speaker the same as utterance j's?
128 |                 for j in range(len(speaker)):
129 |                     if speaker[i] == speaker[j]:
130 |                         if i == j:
131 |                             s[i, j] = 1  # diagonal: self
132 |                         elif i < j:
133 |                             s[i, j] = 2  # self-future
134 |                         else:
135 |                             s[i, j] = 3  # self-past
136 | else:
137 |                         if i < j:
138 |                             s[i, j] = 4  # inter-future
139 |                         elif i > j:
140 |                             s[i, j] = 5  # inter-past
141 |
142 |
143 | semantic_adj.append(s)
144 |
145 | return torch.stack(semantic_adj)
146 |
147 |
148 | def get_structure_adj(self, links, relations, lengths, max_dialog_len):
149 |         '''
150 |         Discourse relation ids (shifted by +1 below, so 0 is reserved for padding / no edge):
151 |         map_relations = {'Comment': 0, 'Contrast': 1, 'Correction': 2, 'Question-answer_pair': 3, 'QAP': 3, 'Parallel': 4, 'Acknowledgement': 5,
152 |                          'Elaboration': 6, 'Clarification_question': 7, 'Conditional': 8, 'Continuation': 9, 'Result': 10, 'Explanation': 11,
153 |                          'Q-Elab': 12, 'Alternation': 13, 'Narration': 14, 'Background': 15}
154 |         The diagonal additionally gets self-loop id 17, for num_relation = 18 in total.
155 |         '''
155 | structure_adj = []
156 |
157 |         for link, relation, length in zip(links, relations, lengths):
158 |             s = torch.zeros(max_dialog_len, max_dialog_len, dtype=torch.long)  # (N, N); 0 marks padding or no relation
159 | assert len(link)==len(relation)
160 |
161 | for index, (i,j) in enumerate(link):
162 |                 s[i, j] = relation[index] + 1  # +1 keeps id 0 free for padding / no edge
163 |                 s[j, i] = s[i, j]  # make the adjacency symmetric
164 |
165 |
166 |
167 |             for i in range(length):  # self-loops on the diagonal
168 |                 s[i, i] = 17  # dedicated self-loop relation id (16 discourse relations + 1 offset)
169 |
170 | structure_adj.append(s)
171 |
172 | return torch.stack(structure_adj)
173 |
174 |
175 |
176 |     def collate_fn(self, data):  # data is one batch of dialogues: fetch and pad the samples
177 |         '''
178 |         :param data: list of samples, each being
179 |             (utterance_features, labels, speakers, utterance_links, utterance_relations, length, utterances, id)
180 |         :return:
181 |             utterance_features: (B, N, D) padded
182 |             labels: (B, N) padded with -1
183 |             semantic_adj: (B, N, N) speaker-aware relation ids from get_semantic_adj
184 |             structure_adj: (B, N, N) discourse relation ids from get_structure_adj
185 |             lengths: (B, )
186 |             speakers: (B, N) padded with -1
187 |             utterances, ids: not tensors
188 |         '''
189 |         max_dialog_len = max([d[5] for d in data])  # longest dialogue length in the batch, N
190 |
191 | utterance_features = pad_sequence([d[0] for d in data], batch_first = True) # (B, N, D)
192 |
193 |         labels = pad_sequence([d[1] for d in data], batch_first = True, padding_value = -1)  # (B, N); labels padded with -1
194 |
195 | semantic_adj = self.get_semantic_adj([d[2] for d in data], max_dialog_len)
196 |
197 | structure_adj = self.get_structure_adj([d[3] for d in data], [d[4] for d in data], [d[5] for d in data],max_dialog_len)
198 |
199 |         lengths = torch.LongTensor([d[5] for d in data])  # length of each dialogue in the batch
200 |
201 |         speakers = pad_sequence([torch.LongTensor(d[2]) for d in data], batch_first = True, padding_value = -1)  # (B, N); speakers padded with -1
202 |         utterances = [d[6] for d in data]  # utterance texts of each dialogue in the batch
203 |         ids = [d[7] for d in data]  # id of each dialogue in the batch
204 |
205 | return utterance_features, labels, semantic_adj, structure_adj, lengths, speakers, utterances, ids
206 |
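207 |
208 | # Worked example for get_semantic_adj: speakers [A, B, A] (a 3-utterance dialogue,
209 | # max_dialog_len = 3) produce the relation-id matrix
210 | #   [[1, 4, 2],
211 | #    [5, 1, 4],
212 | #    [3, 5, 1]]
213 | # with ids 0 = padding, 1 = self, 2 = self-future, 3 = self-past, 4 = inter-future,
214 | # 5 = inter-past (hence num_relation = 6 for the speaker graph).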
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: MMERC
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=4.5=1_gnu
7 | - ca-certificates=2021.7.5=h06a4308_1
8 | - certifi=2021.5.30=py36h06a4308_0
9 | - ld_impl_linux-64=2.35.1=h7274673_9
10 | - libffi=3.3=he6710b0_2
11 | - libgcc-ng=9.3.0=h5101ec6_17
12 | - libgomp=9.3.0=h5101ec6_17
13 | - libstdcxx-ng=9.3.0=hd4cf53a_17
14 | - ncurses=6.2=he6710b0_1
15 | - openssl=1.1.1k=h27cfd23_0
16 | - pip=21.2.2=py36h06a4308_0
17 | - python=3.6.13=h12debd9_1
18 | - readline=8.1=h27cfd23_0
19 | - setuptools=52.0.0=py36h06a4308_0
20 | - sqlite=3.36.0=hc218d9a_0
21 | - tk=8.6.10=hbc83047_0
22 | - wheel=0.36.2=pyhd3eb1b0_0
23 | - xz=5.2.5=h7b6447c_0
24 | - zlib=1.2.11=h7b6447c_3
25 | - pip:
26 | - absl-py==0.13.0
27 | - anykeystore==0.2
28 | - apex==0.1
29 | - astunparse==1.6.3
30 | - boto3==1.18.14
31 | - botocore==1.21.14
32 | - cached-property==1.5.2
33 | - cachetools==4.2.2
34 | - cffi==1.15.0
35 | - chardet==3.0.4
36 | - click==8.0.3
37 | - cloudpickle==2.0.0
38 | - cryptacular==1.6.2
39 | - cycler==0.10.0
40 | - cython==0.29.24
41 | - dataclasses==0.8
42 | - decorator==4.4.2
43 | - defusedxml==0.7.1
44 | - easydict==1.9
45 | - fasteners==0.17.3
46 | - filelock==3.4.0
47 | - flatbuffers==1.12
48 | - gast==0.4.0
49 | - glfw==2.5.3
50 | - google-auth==1.34.0
51 | - google-auth-oauthlib==0.4.5
52 | - google-pasta==0.2.0
53 | - googledrivedownloader==0.4
54 | - greenlet==1.1.2
55 | - grpcio==1.34.1
56 | - gym==0.21.0
57 | - h5py==3.1.0
58 | - huggingface-hub==0.2.0
59 | - hupper==1.10.3
60 | - idna==2.8
61 | - imageio==2.9.0
62 | - importlib-metadata==4.8.3
63 | - inplace-abn==1.1.1.dev6+gd2728c8
64 | - isodate==0.6.1
65 | - jinja2==3.0.3
66 | - jmespath==0.10.0
67 | - joblib==1.1.0
68 | - jsonlines==2.0.0
69 | - keras-nightly==2.5.0.dev2021032900
70 | - keras-preprocessing==1.1.2
71 | - kiwisolver==1.3.1
72 | - llvmlite==0.36.0
73 | - lmdb==1.2.1
74 | - markdown==3.3.4
75 | - markupsafe==2.0.1
76 | - matplotlib==3.3.4
77 | - mujoco-py==2.1.2.14
78 | - networkx==2.5.1
79 | - nltk==3.6.7
80 | - numba==0.53.1
81 | - numpy==1.19.5
82 | - oauthlib==3.1.1
83 | - opencv-python==4.5.3.56
84 | - opt-einsum==3.3.0
85 | - packaging==21.3
86 | - pandas==1.1.5
87 | - pastedeploy==2.1.1
88 | - pbkdf2==1.3
89 | - pillow==8.3.1
90 | - plaster==1.0
91 | - plaster-pastedeploy==0.7
92 | - plyfile==0.7.4
93 | - protobuf==3.10.0
94 | - pyasn1==0.4.8
95 | - pyasn1-modules==0.2.8
96 | - pycocotools==2.0.2
97 | - pycparser==2.21
98 | - pyparsing==2.4.7
99 | - pyramid==2.0
100 | - pyramid-mailer==0.15.1
101 | - python-dateutil==2.8.2
102 | - python3-openid==3.2.0
103 | - pytz==2021.1
104 | - pywavelets==1.1.1
105 | - pyyaml==5.4.1
106 | - rdflib==5.0.0
107 | - regex==2022.6.2
108 | - repoze-sendmail==4.4.1
109 | - requests==2.22.0
110 | - requests-oauthlib==1.3.0
111 | - rsa==4.7.2
112 | - s3transfer==0.5.0
113 | - sacremoses==0.0.46
114 | - scikit-image==0.17.2
115 | - scikit-learn==0.24.2
116 | - scipy==1.5.4
117 | - setuptools-scm==6.4.2
118 | - six==1.15.0
119 | - sqlalchemy==1.4.39
120 | - tensorboard==2.5.0
121 | - tensorboard-data-server==0.6.1
122 | - tensorboard-plugin-wit==1.8.0
123 | - tensorboardx==2.4
124 | - tensorflow==2.5.0
125 | - tensorflow-estimator==2.5.0
126 | - termcolor==1.1.0
127 | - threadpoolctl==3.0.0
128 | - tifffile==2020.9.3
129 | - tokenizers==0.10.3
130 | - tomli==1.2.3
131 | - torch==1.7.1+cu110
132 | - torch-geometric==2.0.3
133 | - torch-scatter==2.0.5
134 | - torch-sparse==0.6.8
135 | - torchaudio==0.7.2
136 | - torchvision==0.8.2+cu110
137 | - tqdm==4.62.0
138 | - transaction==3.0.1
139 | - transformers==4.12.5
140 | - translationstring==1.4
141 | - typing-extensions==3.7.4.3
142 | - urllib3==1.25.11
143 | - velruse==1.1.1
144 | - venusian==3.0.0
145 | - webob==1.8.7
146 | - werkzeug==2.0.1
147 | - wrapt==1.12.1
148 | - wtforms==3.0.0
149 | - wtforms-recaptcha==0.3.2
150 | - yacs==0.1.8
151 | - zipp==3.5.0
152 | - zope-deprecation==4.4.0
153 | - zope-interface==5.4.0
154 | - zope-sqlalchemy==1.6
156 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np, argparse, time, random
3 |
4 |
5 | from model import *
6 |
7 | from trainer import train_or_eval_model
8 |
9 | from dataloader import get_data_loaders
10 | from transformers import AdamW
11 |
12 |
13 |
14 | def seed_everything(seed):
15 | random.seed(seed)
16 | np.random.seed(seed)
17 | torch.manual_seed(seed)
18 | torch.cuda.manual_seed(seed)
19 | torch.cuda.manual_seed_all(seed)
20 | torch.backends.cudnn.benchmark = False
21 | torch.backends.cudnn.deterministic = True
22 |
23 |
24 | def str2bool(v):
25 | """ Usage:
26 | parser.add_argument('--pretrained', type=str2bool, nargs='?', const=True,
27 | dest='pretrained', help='Whether to use pretrained models.')
28 | """
29 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
30 | return True
31 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
32 | return False
33 | else:
34 | raise argparse.ArgumentTypeError('Unsupported value encountered.')
35 |
36 |
37 |
38 | if __name__ == '__main__':
39 |
40 |     path = './saved_models/'  # directory for logs and saved models
41 |
42 | parser = argparse.ArgumentParser()
43 |
44 | parser.add_argument('--hidden_dim', type=int, default=300)
45 | parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')
46 |
47 | parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')
48 | parser.add_argument('--dataset_name', default='MELD', type=str, help='dataset name, IEMOCAP, MELD, DailyDialog, EmoryNLP')
49 | parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
50 | parser.add_argument('--epochs', type=int, default=60, metavar='E', help='number of epochs')
51 |
52 | parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
53 |
54 | parser.add_argument('--lr', type=float, default=5e-5, metavar='LR', help='learning rate') #####
55 | parser.add_argument('--dropout', type=float, default=0.4, metavar='dropout', help='dropout rate')
56 | parser.add_argument('--batch_size', type=int, default=64, metavar='BS', help='batch size') ##
57 |
58 |
59 | parser.add_argument('--seed', type=int, default=100, help='random seed') ##
60 |
61 |
62 | args = parser.parse_args()
63 |
64 | print(args)
65 |
66 |     # fix the random seed
67 | seed_everything(args.seed)
68 |
69 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70 | print("device:", args.device)
71 |
72 | device = args.device
73 | n_epochs = args.epochs
74 | batch_size = args.batch_size
75 |
76 |
77 | train_loader, valid_loader, test_loader = get_data_loaders(
78 | dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args=args)
79 |
80 |
81 | if 'IEMOCAP' in args.dataset_name:
82 | n_classes = 6
83 | else:
84 | n_classes = 7
85 |
86 | print('building model..')
87 |
88 | model = DualGATs(args, n_classes)
89 |
90 | if torch.cuda.device_count() > 1:
91 | print('Multi-GPU...........')
92 | model = nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
93 |
94 | model.to(device)
95 |
96 |
97 |     loss_function = nn.CrossEntropyLoss(ignore_index=-1)  # ignore padded positions (label = -1)
98 |
99 |
100 | optimizer = AdamW(model.parameters(), lr=args.lr)
101 |
102 | best_fscore, best_acc, best_loss, best_label, best_pred, best_mask = None, None, None, None, None, None
103 | all_fscore, all_acc, all_loss = [], [], []
104 | best_acc = 0.
105 | best_fscore = 0.
106 |
107 | best_model = None
108 |     for e in range(n_epochs):  # loop over epochs
109 | start_time = time.time()
110 |
111 | train_loss, train_acc, _, _, train_fscore = train_or_eval_model(model, loss_function,
112 | train_loader, device,
113 | args, optimizer, True)
114 | valid_loss, valid_acc, _, _, valid_fscore = train_or_eval_model(model, loss_function,
115 | valid_loader, device, args)
116 | test_loss, test_acc, test_label, test_pred, test_fscore = train_or_eval_model(model, loss_function,
117 | test_loader, device, args)
118 |
119 | all_fscore.append([valid_fscore, test_fscore])
120 |
121 | print(
122 | 'Epoch: {}, train_loss: {}, train_acc: {}, train_fscore: {}, valid_loss: {}, valid_acc: {}, valid_fscore: {}, test_loss: {}, test_acc: {}, test_fscore: {}, time: {} sec'. \
123 | format(e + 1, train_loss, train_acc, train_fscore, valid_loss, valid_acc, valid_fscore, test_loss,
124 | test_acc,
125 | test_fscore, round(time.time() - start_time, 2)))
126 |
127 |
128 |
129 |
130 | print('finish training!')
131 |
132 |
133 |     all_fscore = sorted(all_fscore, key=lambda x: (x[0], x[1]), reverse=True)  # sort primarily by validation F1
134 |
135 |     print('Best val F-Score:{}'.format(all_fscore[0][0]))  # best validation performance
136 |     print('Best test F-Score based on validation:{}'.format(all_fscore[0][1]))  # test performance at the epoch with the best validation F1
137 |     print('Best test F-Score based on test:{}'.format(max([f[1] for f in all_fscore])))  # best test performance over all epochs
138 |
139 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from model_utils import RGAT, DiffLoss
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class DualGATs(nn.Module):
8 |
9 | def __init__(self, args, num_class):
10 | super().__init__()
11 | self.args = args
12 |
13 | self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
14 |
15 | SpkGAT = []
16 | DisGAT = []
17 | for _ in range(args.gnn_layers):
18 | SpkGAT.append(RGAT(args, args.hidden_dim, args.hidden_dim, dropout=args.dropout, num_relation=6))
19 | DisGAT.append(RGAT(args, args.hidden_dim, args.hidden_dim, dropout=args.dropout, num_relation=18))
20 |
21 | self.SpkGAT = nn.ModuleList(SpkGAT)
22 | self.DisGAT = nn.ModuleList(DisGAT)
23 |
24 |
25 | self.affine1 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
26 | nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
27 | self.affine2 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
28 | nn.init.xavier_uniform_(self.affine2.data, gain=1.414)
29 |
30 | self.diff_loss = DiffLoss(args)
31 | self.beta = 0.3
32 |
33 | in_dim = args.hidden_dim *2 + args.emb_dim
34 | # output mlp layers
35 | layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
36 | for _ in range(args.mlp_layers - 1):
37 | layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
38 | layers += [nn.Linear(args.hidden_dim, num_class)]
39 |
40 | self.out_mlp = nn.Sequential(*layers)
41 |
42 | self.drop = nn.Dropout(args.dropout)
43 |
44 |
45 |
46 | def forward(self, utterance_features, semantic_adj, structure_adj):
47 |         '''
48 |         :param utterance_features: (B, N, emb_dim)
49 |         :param semantic_adj / structure_adj: (B, N, N)
50 |         :return: logits (B, N, num_class), weighted difference loss
51 |         '''
52 | batch_size = utterance_features.size(0)
53 | H0 = F.relu(self.fc1(utterance_features)) # (B, N, hidden_dim)
54 | H = [H0]
55 | diff_loss = 0
56 | for l in range(self.args.gnn_layers):
57 | if l==0:
58 | H1_semantic = self.SpkGAT[l](H[l], semantic_adj)
59 | H1_structure = self.DisGAT[l](H[l], structure_adj)
60 | else:
61 | H1_semantic = self.SpkGAT[l](H[2*l-1], semantic_adj)
62 | H1_structure = self.DisGAT[l](H[2*l], structure_adj)
63 |
64 |
65 | diff_loss = diff_loss + self.diff_loss(H1_semantic, H1_structure)
66 | # BiAffine
67 |
68 | A1 = F.softmax(torch.bmm(torch.matmul(H1_semantic, self.affine1), torch.transpose(H1_structure, 1, 2)), dim=-1)
69 | A2 = F.softmax(torch.bmm(torch.matmul(H1_structure, self.affine2), torch.transpose(H1_semantic, 1, 2)), dim=-1)
70 |
71 | H1_semantic_new = torch.bmm(A1, H1_structure)
72 | H1_structure_new = torch.bmm(A2, H1_semantic)
73 |
74 | H1_semantic_out = self.drop(H1_semantic_new) if l < self.args.gnn_layers - 1 else H1_semantic_new
75 |             H1_structure_out = self.drop(H1_structure_new) if l < self.args.gnn_layers - 1 else H1_structure_new
76 |
77 |             H.append(H1_semantic_out)
78 |             H.append(H1_structure_out)
79 |
80 |         H.append(utterance_features)
81 |         H_final = torch.cat([H[-3], H[-2], H[-1]], dim=-1)  # (B, N, 2*hidden_dim + emb_dim), matching in_dim of out_mlp
82 |         logits = self.out_mlp(H_final)  # (B, N, num_class)
83 |
84 |         return logits, self.beta * diff_loss
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LayerNorm(nn.Module):
7 |     def __init__(self, hidden_size, eps=1e-6):
8 |         super(LayerNorm, self).__init__()
9 |         self.gamma = nn.Parameter(torch.ones(hidden_size))
10 |         self.beta = nn.Parameter(torch.zeros(hidden_size))
11 |         self.eps = eps
12 |
13 |     def forward(self, x):
14 |         mean = x.mean(-1, keepdim=True)
15 |         std = x.std(-1, keepdim=True)
16 |         return self.gamma * (x - mean) / (std + self.eps) + self.beta
17 |
18 |
19 | class DiffLoss(nn.Module):
20 |     # soft orthogonality loss that pushes the speaker-aware and discourse-aware
21 |     # representations towards capturing different information
22 |     def __init__(self, args):
23 |         super(DiffLoss, self).__init__()
24 |         self.args = args
25 |
26 |     def forward(self, input1, input2):
27 |         batch_size = input1.size(0)
28 |         input1 = input1.contiguous().view(batch_size, -1)
29 |         input2 = input2.contiguous().view(batch_size, -1)
30 |
31 |         # zero-center each view across the batch
32 |         input1 = input1 - torch.mean(input1, dim=0, keepdim=True)
33 |         input2 = input2 - torch.mean(input2, dim=0, keepdim=True)
34 |
35 |         # L2-normalize every sample
36 |         input1 = F.normalize(input1, p=2, dim=1)
37 |         input2 = F.normalize(input2, p=2, dim=1)
38 |
39 |         # penalize the squared correlation between the two views
40 |         return torch.mean(torch.matmul(input1.transpose(0, 1), input2).pow(2))
41 |
42 |
43 | class GraphAttentionLayer(nn.Module):
44 |     """
45 |     Relation-aware graph attention layer, adapted from https://arxiv.org/abs/1710.10903
46 |     """
47 |     def __init__(self, in_features, out_features, dropout, alpha, concat=True, relation=False, num_relation=-1):
48 |         super(GraphAttentionLayer, self).__init__()
49 |         self.dropout = dropout
50 |         self.in_features = in_features
51 |         self.out_features = out_features
52 |         self.alpha = alpha
53 |         self.concat = concat
54 |         self.relation = relation
55 |
56 |         self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
57 |         nn.init.xavier_uniform_(self.W.data, gain=1.414)
58 |         if relation:
59 |             # relation ids in adj index a learnable embedding; id 0 is padding / no edge
60 |             self.relation_embedding = nn.Embedding(num_relation, out_features, padding_idx=0)
61 |             self.a = nn.Parameter(torch.empty(size=(3 * out_features, 1)))
62 |         else:
63 |             self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
64 |         nn.init.xavier_uniform_(self.a.data, gain=1.414)
65 |
66 |         self.leakyrelu = nn.LeakyReLU(self.alpha)
67 |         self.layer_norm = LayerNorm(out_features)
68 |
69 |     def forward(self, h, adj):
70 |         '''
71 |         :param h: (B, N, in_features)
72 |         :param adj: (B, N, N) relation ids, 0 for padding / no edge
73 |         :return: (B, N, out_features)
74 |         '''
75 |         Wh = torch.matmul(h, self.W)  # (B, N, out_features)
76 |         a_input = self._prepare_attentional_mechanism_input(Wh)  # (B, N, N, 2*out_features)
77 |         if self.relation:
78 |             relation_emb = self.relation_embedding(adj)  # (B, N, N, out_features)
79 |             a_input = torch.cat([a_input, relation_emb], dim=-1)  # (B, N, N, 3*out_features)
80 |         e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(-1))  # (B, N, N)
81 |
82 |         zero_vec = -9e15 * torch.ones_like(e)
83 |         attention = torch.where(adj > 0, e, zero_vec)  # keep e where adj is non-zero; padded or unconnected positions get a very large negative score
104 | attention = F.softmax(attention, dim=2) # B, N, N
105 | attention = F.dropout(attention, self.dropout, training=self.training)
106 | h_prime = torch.matmul(attention, Wh) # (B,N,N_out)
107 |
108 | h_prime = self.layer_norm(h_prime)
109 |
110 |
111 | if self.concat:
112 | return F.gelu(h_prime)
113 | else:
114 | return h_prime
115 |
116 | def _prepare_attentional_mechanism_input(self, Wh):
117 | N = Wh.size()[1] # N
118 | B = Wh.size()[0] # B
119 | # Below, two matrices are created that contain embeddings in their rows in different orders.
120 | # (e stands for embedding)
121 | # These are the rows of the first matrix (Wh_repeated_in_chunks):
122 | # e1, e1, ..., e1, e2, e2, ..., e2, ..., eN, eN, ..., eN
123 | # '-------------' -> N times '-------------' -> N times '-------------' -> N times
124 | #
125 | # These are the rows of the second matrix (Wh_repeated_alternating):
126 | # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
127 | # '----------------------------------------------------' -> N times
128 | #
129 | #print('Wh', Wh.shape)
130 | Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1)
131 | Wh_repeated_alternating = Wh.repeat(1, N, 1)
132 |         # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (B, N * N, out_features)
133 |
134 | # The all_combination_matrix, created below, will look like this (|| denotes concatenation):
135 | # e1 || e1
136 | # e1 || e2
137 | # e1 || e3
138 | # ...
139 | # e1 || eN
140 | # e2 || e1
141 | # e2 || e2
142 | # e2 || e3
143 | # ...
144 | # e2 || eN
145 | # ...
146 | # eN || e1
147 | # eN || e2
148 | # eN || e3
149 | # ...
150 | # eN || eN
151 |
152 | all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2) # (B, N*N, 2*D_out)
153 | # all_combinations_matrix.shape == (B, N * N, 2 * out_features)
154 |
155 | return all_combinations_matrix.view(B, N, N, 2 * self.out_features)
156 |
157 | def __repr__(self):
158 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
159 |
160 |
161 |
162 | class RGAT(nn.Module):
163 | def __init__(self, args, nfeat, nhid, dropout = 0.2, alpha = 0.2, nheads = 2, num_relation=-1):
164 | """Dense version of GAT."""
165 | super(RGAT, self).__init__()
166 | self.dropout = dropout
167 |
168 |         self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True, relation = True, num_relation=num_relation) for _ in range(nheads)]  # multi-head attention
169 |
170 | for i, attention in enumerate(self.attentions):
171 | self.add_module('attention_{}'.format(i), attention)
172 |         self.out_att = GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=False, relation = True, num_relation=num_relation)  # project the concatenated heads back to nhid
173 |
174 | self.fc = nn.Linear(nhid, nhid)
175 | self.layer_norm = LayerNorm(nhid)
176 |
177 | def forward(self, x, adj):
178 |         residual = x
179 | x = F.dropout(x, self.dropout, training=self.training)
180 | x = torch.cat([att(x, adj) for att in self.attentions], dim=-1) # (B,N,num_head*N_out)
181 | x = F.dropout(x, self.dropout, training=self.training)
182 | x = F.gelu(self.out_att(x, adj)) # (B, N, N_out)
183 | x = self.fc(x) # (B, N, N_out)
184 |         x = x + residual
185 | x = self.layer_norm(x)
186 | return x
187 |
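188 |
189 | # Shape walk-through for one RGAT call (default nheads = 2, nhid = args.hidden_dim):
190 | #   x: (B, N, nhid), adj: (B, N, N)
191 | #   -> concat of the 2 attention heads: (B, N, 2 * nhid)
192 | #   -> out_att projects back to:        (B, N, nhid)
193 | #   -> fc, residual add and LayerNorm keep (B, N, nhid)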
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from sklearn.metrics import f1_score, accuracy_score
5 |
6 | import json
7 |
8 |
9 | def train_or_eval_model(model, loss_function, dataloader, device, args, optimizer=None, train=False):
10 | losses, preds, labels = [], [], []
11 |
12 |     assert not train or optimizer is not None
13 |     if train:  # training mode
14 |         model.train()
15 |     else:  # evaluation mode
16 |         model.eval()
17 |
18 |     for data in dataloader:  # iterate over batches
19 | if train:
20 | optimizer.zero_grad()
21 |
22 | utterance_features, label, semantic_adj, structure_adj, lengths, speakers, utterances, ids = data
23 |
24 |
25 | utterance_features = utterance_features.to(device)
26 | label = label.to(device) # (B,N)
27 | semantic_adj = semantic_adj.to(device)
28 | structure_adj = structure_adj.to(device)
29 |
30 | log_prob, diff_loss = model(utterance_features, semantic_adj,structure_adj) # (B, N, C)
31 |
32 | loss = loss_function(log_prob.permute(0,2,1), label)
33 |
34 | loss = loss + diff_loss
35 | label = label.cpu().numpy().tolist()
36 | pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist() # (B,N)
37 | preds += pred
38 | labels += label
39 | losses.append(loss.item())
40 |
41 |
42 | if train:
43 |
44 | loss.backward()
45 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
46 | optimizer.step()
47 |
48 | if preds != []:
49 | new_preds = []
50 | new_labels = []
51 |         for i, label in enumerate(labels):  # iterate over dialogues
52 |             for j, l in enumerate(label):  # iterate over utterances
53 |                 if l != -1:  # drop padded labels (IEMOCAP dialogues contain padded utterances internally, too)
54 |                     new_labels.append(l)
55 |                     new_preds.append(preds[i][j])
56 |     else:
57 |         return float('nan'), float('nan'), [], [], float('nan')
58 |
59 | avg_loss = round(np.sum(losses) / len(losses), 4)
60 | avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
61 |
62 |     if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
63 |         avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
64 |     elif args.dataset_name == 'DailyDialog':
65 |         avg_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=[0,2,3,4,5,6]) * 100, 2)  # label 1 (neutral) is excluded
66 |
67 | return avg_loss, avg_accuracy, labels, preds, avg_fscore
68 |
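69 | # Metric conventions: weighted F1 for IEMOCAP / MELD / EmoryNLP, and micro F1 over the
70 | # non-neutral classes (label 1 excluded) for DailyDialog, following common ERC evaluation.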
--------------------------------------------------------------------------------