├── README.md
├── assets
│   ├── graph.png
│   └── model.png
├── data_utils.py
├── graph.py
├── layer
│   ├── __init__.py
│   ├── rgcn.py
│   └── supervisedcontrastiveloss.py
├── main.py
├── model
│   ├── YORO.py
│   └── __init__.py
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # **YORO**
2 | Code for the paper "You Only Read Once: Constituency-Oriented Relational Graph Convolutional Network for Multi-Aspect Multi-Sentiment Classification"
3 |
4 | *AAAI 2024*
5 |
6 | Yongqiang Zheng, Xia Li
7 |
8 |
9 | ## **Model**
10 | ![YORO model architecture](assets/model.png)
11 |
12 | ## **Requirements**
13 | - python==3.10.12
14 | - torch==1.12.1+cu113
15 | - transformers==4.30.2
16 | - scikit-learn==1.2.2
17 | - benepar==0.2.0
18 |
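`graph.py` additionally loads the `en_core_web_md` spaCy pipeline and the `benepar_en3` constituency parser. A minimal sketch for fetching both once the packages above are installed (a one-off download step):

```
import benepar
import spacy.cli

# Download the spaCy pipeline and the Berkeley parser model used by graph.py.
spacy.cli.download("en_core_web_md")
benepar.download("benepar_en3")
```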
19 | ## **Datasets**
20 | Download datasets from these links and put them in the **dataset** folder:
21 | - [MAMS](https://github.com/siat-nlp/MAMS-for-ABSA)
22 | - [Rest14](https://alt.qcri.org/semeval2014/task4)
23 | - [Lap14](https://alt.qcri.org/semeval2014/task4)
24 |
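Both `graph.py` and `data_utils.py` read each split as a plain-text file (e.g. `dataset/rest14_train`) in which every example spans three lines: the sentence, the tab-separated aspect spans in `aspect#start#end` format (word-level indices over the whitespace-tokenized sentence), and the tab-separated polarity labels (-1 negative, 0 neutral, 1 positive). A hypothetical example:

```
the food was great but the service was slow
food#1#2	service#6#7
1	-1
```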
25 | ## **Usage**
26 | 1. Download Bing Liu's opinion lexicon:
27 | ```
28 | wget http://www.cs.uic.edu/~liub/FBS/opinion-lexicon-English.rar
29 | sudo apt-get install unrar
30 | unrar x opinion-lexicon-English.rar
31 | mv opinion-lexicon-English lexicon
32 | ```
33 | 2. Generate the constituency-oriented graphs:
34 | ```
35 | python graph.py
36 | ```
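
Running `graph.py` writes three pickle files next to each split (`*_distance.pkl`, `*_relation.pkl`, `*_opinion.pkl`), keyed by the line index of each example; `data_utils.py` loads them at training time. A minimal inspection sketch (the file name assumes the Rest14 training split has already been processed):

```
import pickle

# Keys are the line indices of the examples (0, 3, 6, ...); values are
# word-piece-level adjacency matrices.
with open('dataset/rest14_train_relation.pkl', 'rb') as f:
    relation_graphs = pickle.load(f)

first = sorted(relation_graphs)[0]
print(first, relation_graphs[first].shape)  # (num_bert_tokens, num_bert_tokens)
```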
37 |
38 | An example of the graph construction in the Constituency-Oriented Relational Graph Convolutional Network (CorrGCN):
39 | ![An example of the constituency-oriented graph construction](assets/graph.png)
40 |
41 | ## **Training**
42 | ```
43 | bash train.sh
44 | ```
45 |
46 | ## **Credits**
47 | The code in this repository is based on [SEGCN-ABSA](https://github.com/gdufsnlp/SEGCN-ABSA).
48 |
49 | ## **Citation**
50 | ```bibtex
51 | @inproceedings{zheng2024you,
52 | title = {You Only Read Once: Constituency-Oriented Relational Graph Convolutional Network for Multi-Aspect Multi-Sentiment Classification},
53 | author = {Zheng, Yongqiang and Li, Xia},
54 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
55 | volume = {38},
56 | number = {17},
57 | pages = {19715--19723},
58 | year = {2024},
59 | url = {https://ojs.aaai.org/index.php/AAAI/article/view/29945},
60 | doi = {10.1609/aaai.v38i17.29945},
61 | }
62 | ```
63 |
--------------------------------------------------------------------------------
/assets/graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gdufsnlp/YORO/0c1d045ed21858731522a87349d5b59ec875502e/assets/graph.png
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gdufsnlp/YORO/0c1d045ed21858731522a87349d5b59ec875502e/assets/model.png
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import pickle
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from transformers import BertTokenizer
7 |
8 |
9 | def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
10 | x = (np.ones(maxlen) * value).astype(dtype)
11 | if truncating == 'pre':
12 | trunc = sequence[-maxlen:]
13 | else:
14 | trunc = sequence[:maxlen]
15 | trunc = np.asarray(trunc, dtype=dtype)
16 | if padding == 'post':
17 | x[:len(trunc)] = trunc
18 | else:
19 | x[-len(trunc):] = trunc
20 | return x
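# Illustrative behaviour of pad_and_truncate (made-up values):
#   pad_and_truncate([3, 5, 7], maxlen=5)                   -> array([3, 5, 7, 0, 0])
#   pad_and_truncate([3, 5, 7], maxlen=2, truncating='pre') -> array([5, 7])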
21 |
22 |
23 | def opinion_lexicon():
24 | pos_file = 'lexicon/positive-words.txt'
25 | neg_file = 'lexicon/negative-words.txt'
26 | lexicon = {}
27 | fin1 = open(pos_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
28 | fin2 = open(neg_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
29 | lines1 = fin1.readlines()
30 | lines2 = fin2.readlines()
31 | fin1.close()
32 | fin2.close()
33 | for pos_word in lines1:
34 | lexicon[pos_word.strip()] = 'positive'
35 | for neg_word in lines2:
36 | lexicon[neg_word.strip()] = 'negative'
37 | return lexicon
38 |
39 |
40 | class Tokenizer4Bert:
41 | def __init__(self, max_seq_len, pretrained_bert_name):
42 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
43 | self.max_seq_len = max_seq_len
44 |
45 | def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
46 | sequence = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
47 | if len(sequence) == 0:
48 | sequence = [0]
49 | if reverse:
50 | sequence = sequence[::-1]
51 | return pad_and_truncate(sequence, self.max_seq_len, padding=padding, truncating=truncating)
52 |
53 | def opinion_in_text(self, text, aspects, lexicon):
54 | aspect_index = []
55 | for asp_idx in aspects:
56 | _, start, end = asp_idx
57 | aspect_index.extend(list(range(start + 1, end + 1)))
58 |
59 | opinion_index = []
60 | for idx, word in enumerate(text.split()):
61 | for t in self.tokenizer.tokenize(word):
62 | if idx in aspect_index:
63 | opinion_index.append(-1) # skip aspect words
64 | elif word in lexicon.keys():
65 | if lexicon[word] == 'negative':
66 | opinion_index.append(0)
67 | elif lexicon[word] == 'positive':
68 | opinion_index.append(2)
69 | else:
70 | opinion_index.append(1)
71 | assert len(opinion_index) == len(self.tokenizer.tokenize(text))
72 | return pad_and_truncate(opinion_index, self.max_seq_len, value=-1)
73 |
74 | def map_bert_1D(self, text):
75 | words = text.split()
76 | # bert_tokens = []
77 | bert_map = []
78 | for src_i, word in enumerate(words):
79 | for subword in self.tokenizer.tokenize(word):
80 | # bert_tokens.append(subword) # * ['expand', '##able', 'highly', 'like', '##ing']
81 | bert_map.append(src_i) # * [0, 0, 1, 2, 2]
82 |
83 | return bert_map
84 |
85 |
86 | class ABSADataset(Dataset):
87 | def __init__(self, file, tokenizer):
88 | self.file = file
89 | self.tokenizer = tokenizer
90 | self.load_data()
91 |
92 | def load_data(self):
93 | fin = open(self.file, 'r', encoding='utf-8', newline='\n', errors='ignore')
94 | lines = fin.readlines()
95 | fin.close()
96 | fin = open(self.file + '_relation.pkl', 'rb')
97 | rel_matrix = pickle.load(fin)
98 | fin.close()
99 | fin = open(self.file + '_opinion.pkl', 'rb')
100 | lex_matrix = pickle.load(fin)
101 | fin.close()
102 | fin = open(self.file + '_distance.pkl', 'rb')
103 | dis_matrix = pickle.load(fin)
104 | fin.close()
105 |
106 | lexicon = opinion_lexicon()
107 | all_data = []
108 | for i in range(0, len(lines), 3):
109 | text = lines[i].lower().strip()
110 | all_aspect = lines[i + 1].lower().strip()
111 | all_polarity = lines[i + 2].strip()
112 | aspects = []
113 | for aspect_idx in all_aspect.split('\t'):
114 | aspect, start, end = aspect_idx.split('#')
115 | aspects.append([aspect, int(start), int(end)])
116 | labels = []
117 | for label in all_polarity.split('\t'):
118 | labels.append(int(label) + 1)
119 |
120 | text_len = len(self.tokenizer.tokenizer.tokenize(text))
121 | input_ids = self.tokenizer.text_to_sequence('[CLS] ' + text + ' [SEP]')
122 | token_type_ids = [0] * (text_len + 2)
123 | attention_mask = [1] * len(token_type_ids)
124 | token_type_ids = pad_and_truncate(token_type_ids, self.tokenizer.max_seq_len)
125 | attention_mask = pad_and_truncate(attention_mask, self.tokenizer.max_seq_len)
126 | opinion_indices = self.tokenizer.opinion_in_text('[CLS] ' + text + ' [SEP]', aspects, lexicon)
127 |
128 | distance_adj = np.zeros((self.tokenizer.max_seq_len, self.tokenizer.max_seq_len)).astype('float32')
129 | distance_adj[1:text_len + 1, 1:text_len + 1] = dis_matrix[i]
130 | relation_adj = np.zeros((5, self.tokenizer.max_seq_len, self.tokenizer.max_seq_len)).astype('float32')
131 | for j in range(0, 4):
132 | r_tmp = np.where(rel_matrix[i] == j + 1, 1, 0)
133 | relation_adj[j, 1:text_len + 1, 1:text_len + 1] = r_tmp
134 | for k in range(4, 5):
135 | l_tmp = np.where(lex_matrix[i] == k + 1, 1, 0)
136 | relation_adj[k, 1:text_len + 1, 1:text_len + 1] = l_tmp
137 | polarities = [-1] * self.tokenizer.max_seq_len
138 |
139 | bert_index = self.tokenizer.map_bert_1D(text)
140 | for asp_idx, pol in zip(aspects, labels):
141 | _, start, end = asp_idx
142 | # label the first token of aspect
143 | polarities[bert_index.index(start) + 1] = pol # +1 for cls
144 |
145 | polarities = np.asarray(polarities)
146 | data = {
147 | 'input_ids': input_ids,
148 | 'token_type_ids': token_type_ids,
149 | 'attention_mask': attention_mask,
150 | 'distance_adj': distance_adj,
151 | 'relation_adj': relation_adj,
152 | 'polarities': polarities,
153 | 'opinion_indices': opinion_indices,
154 | }
155 |
156 | all_data.append(data)
157 | self.data = all_data
158 |
159 | def __getitem__(self, index):
160 | return self.data[index]
161 |
162 | def __len__(self):
163 | return len(self.data)
164 |
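# Each item yielded by ABSADataset is a dict of fixed-length (max_seq_len) arrays:
# 'input_ids' / 'token_type_ids' / 'attention_mask' for BERT, 'distance_adj' of shape
# (max_seq_len, max_seq_len), 'relation_adj' of shape (5, max_seq_len, max_seq_len),
# 'polarities' (sentiment label on the first sub-token of each aspect, -1 elsewhere)
# and 'opinion_indices' (lexicon labels on non-aspect tokens, -1 for aspects/padding).
# Illustrative usage, mirroring main.py:
#   tokenizer = Tokenizer4Bert(128, 'bert-base-uncased')
#   loader = DataLoader(ABSADataset('dataset/rest14_train', tokenizer), batch_size=16)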
--------------------------------------------------------------------------------
/graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import benepar
4 | import numpy as np
5 | import pickle
6 | import spacy
7 |
8 | from spacy.tokens import Doc
9 | from tqdm import tqdm
10 | from transformers import BertTokenizer
11 |
12 |
13 | class WhitespaceTokenizer(object):
14 | def __init__(self, vocab):
15 | self.vocab = vocab
16 |
17 | def __call__(self, text):
18 | words = text.split()
19 | # All tokens 'own' a subsequent space character in this tokenizer
20 | spaces = [True] * len(words)
21 | return Doc(self.vocab, words=words, spaces=spaces)
22 |
23 |
24 | # spaCy + Berkeley
25 | nlp = spacy.load('en_core_web_md')
26 | nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
27 | nlp.add_pipe("benepar", config={"model": "benepar_en3"})
28 | # BERT
29 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
30 |
31 |
32 | def get_unique_elements(lists, aspects):
33 | unique_lists = []
34 | for i, lst in enumerate(lists):
35 | other_lists = lists[:i] + lists[i + 1:]
36 | unique = set(lst) - set.union(*map(set, other_lists))
37 | if len(unique) > 0:
38 | cons = max(list(unique), key=lambda x: len(x))
39 | unique_lists.append([cons.start, cons.end])
40 | else:
41 | start = aspects[i][1]
42 | end = aspects[i][2]
43 | unique_lists.append([start, end])
44 | return unique_lists
45 |
46 |
47 | def single_aspect(text, aspects):
48 | # https://spacy.io/docs/usage/processing-text
49 | tokens = nlp(text)
50 | words = text.split()
51 | assert len(words) == len(list(tokens))
52 |
53 | token = aspects[0]
54 | asp, start, end = token[0], token[1], token[2]
55 | aspect_specific = []
56 | all_cons = []
57 | for sent in tokens.sents:
58 | for cons in sent._.constituents:
59 | if cons.text == text:
60 | continue
61 | all_cons.append(cons)
62 | if cons.start <= start <= end <= cons.end:
63 | if len(cons._.labels) > 0: # len(cons) > 1:
64 | aspect_specific.append(cons)
65 | aspect_specific_cons = []
66 | aspect_tag = ''
67 | for cons in aspect_specific:
68 | if cons._.labels[0] != 'S':
69 | aspect_specific_cons.append([cons.start, cons.end])
70 | aspect_tag = cons._.labels[0]
71 | break
72 |
73 | for cons in all_cons:
74 | if len(cons._.labels) > 0 and cons._.labels[0] == aspect_tag: # len(cons) != 1
75 | flag = True
76 | for asp_cons in aspect_specific_cons:
77 | if cons.end <= asp_cons[0] or cons.start >= asp_cons[1]:
78 | continue
79 | else:
80 | flag = False
81 | if flag:
82 | aspect_specific_cons.append([cons.start, cons.end])
83 | if aspect_specific_cons == []:
84 | aspect_specific_cons.append([0, len(words)])
85 | return aspect_specific_cons
86 |
87 |
88 | def distance_matrix(text):
89 | # https://spacy.io/docs/usage/processing-text
90 | tokens = nlp(text)
91 | words = text.split()
92 | matrix = np.zeros((len(words), len(words))).astype('float32')
93 | assert len(words) == len(list(tokens))
94 |
95 | for sent in tokens.sents:
96 | for cons in sent._.constituents:
97 | if len(cons) == 1:
98 | continue
99 | matrix[cons.start:cons.end, cons.start:cons.end] += np.ones([len(cons), len(cons)])
100 |
101 | hops_matrix = np.amax(matrix, axis=1, keepdims=True) - matrix # hops
102 | dis_matrix = 2 - hops_matrix / (np.amax(hops_matrix, axis=1, keepdims=True) + 1)
103 |
104 | return dis_matrix
105 |
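# In distance_matrix, each raw count is the number of constituents containing both
# tokens, so hops_matrix measures how far apart two tokens sit in the constituency
# tree. The returned dis_matrix lies in (1, 2], with 2 on the diagonal and values
# approaching 1 for token pairs that only share large constituents.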
106 |
107 | def relation_matrix(text, aspects):
108 | # https://spacy.io/docs/usage/processing-text
109 | tokens = nlp(text)
110 | words = text.split()
111 | matrix = np.eye(len(words)).astype('float32')
112 | if len(words) != len(list(tokens)):
113 | print(words)
114 | print(list(tokens))
115 | assert len(words) == len(list(tokens))
116 |
117 | all_start = [aspect[1] for aspect in aspects]
118 | relations = [False] * len(tokens)
119 |
120 | if len(aspects) > 1:
121 | # intra-aspect
122 | # aspect-related collection
123 | aspect_nodes = [[] for _ in range(len(aspects))]
124 | for sent in tokens.sents:
125 | for cons in sent._.constituents:
126 | for idx, token in enumerate(aspects):
127 | asp, start, end = token[0], token[1], token[2]
128 | if cons.start <= start and end <= cons.end:
129 | aspect_nodes[idx].append(cons)
130 | # aspect-specific
131 | aspect_specific_cons = get_unique_elements(aspect_nodes, aspects)
132 | for idx, cons in enumerate(aspect_specific_cons):
133 | for i in range(cons[0], cons[1]):
134 | matrix[all_start[idx]][i] = 2
135 | matrix[i][all_start[idx]] = 2
136 | relations[i] = True
137 | # globally-shared
138 | for i in range(len(relations)):
139 | if not relations[i]:
140 | for j in all_start:
141 | matrix[i][j] = 3
142 | matrix[j][i] = 3
143 | # inter-aspect
144 | for i in range(len(all_start)):
145 | for j in range(i + 1, len(all_start)):
146 | matrix[all_start[i]][all_start[j]] = 4
147 | matrix[all_start[j]][all_start[i]] = 4
148 | else:
149 | # pseudo aspect
150 | # intra-aspect
151 | # aspect-related collection
152 | aspect_specific_cons = single_aspect(text, aspects)
153 | all_start += [aspect[0] for aspect in aspect_specific_cons[1:]]
154 |
155 | # aspect-specific
156 | for idx, cons in enumerate(aspect_specific_cons):
157 | for i in range(cons[0], cons[1]):
158 | matrix[all_start[idx]][i] = 2
159 | matrix[i][all_start[idx]] = 2
160 | relations[i] = True
161 | # globally-shared
162 | for i in range(len(relations)):
163 | if not relations[i]:
164 | for j in all_start:
165 | matrix[i][j] = 3
166 | matrix[j][i] = 3
167 |
168 | # inter-aspect
169 | for i in range(len(all_start)):
170 | for j in range(i + 1, len(all_start)):
171 | matrix[all_start[i]][all_start[j]] = 4
172 | matrix[all_start[j]][all_start[i]] = 4
173 |
174 | return matrix
175 |
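# Relation ids written by relation_matrix: 1 = self-loop (diagonal), 2 = aspect-specific
# (token inside the constituent assigned to an aspect), 3 = globally-shared (token not
# claimed by any aspect, linked to every aspect head word), 4 = inter-aspect (edges
# between aspect head words). Opinion-lexicon edges are added as id 5 by lexicon_matrix;
# data_utils.py later expands ids 1-5 into five adjacency channels.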
176 |
177 | def lexicon_matrix(text, aspects, lexicon):
178 | # https://spacy.io/docs/usage/processing-text
179 | tokens = nlp(text)
180 | words = text.lower().split()
181 | assert len(words) == len(list(tokens))
182 |
183 | aspects_index = []
184 | for aspect in aspects:
185 | start = aspect[1]
186 | end = aspect[2]
187 | aspects_index.extend(list(range(start, end)))
188 | labels = []
189 | for i in range(len(tokens)):
190 | if words[i] not in lexicon.keys() or i in aspects_index:
191 | labels.append(0)
192 | else:
193 | labels.append(5)
194 | lex_matrix = np.tile(np.array(labels), (len(tokens), 1))
195 | return lex_matrix
196 |
197 |
198 | def build_graph(text, aspects, lexicon):
199 | rel = relation_matrix(text, aspects)
200 | np.fill_diagonal(rel, 1)
201 | mask = (np.zeros_like(rel) != rel).astype('float32')
202 |
203 | lex = lexicon_matrix(text, aspects, lexicon)
204 | lex = lex * mask
205 |
206 | dis = distance_matrix(text)
207 | np.fill_diagonal(dis, 1)
208 | dis = dis * mask
209 |
210 | return dis, rel, lex
211 |
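# build_graph returns the three word-level matrices used downstream: rel carries the
# relation ids, while lex and dis are zeroed wherever rel has no relation, so the
# distance weights and lexicon edges only survive on connected positions.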
212 |
213 | def map_bert_2D(ori_adj, text):
214 | words = text.split()
215 | bert_tokens = []
216 | bert_map = []
217 | for src_i, word in enumerate(words):
218 | for subword in tokenizer.tokenize(word):
219 | bert_tokens.append(subword) # * ['expand', '##able', 'highly', 'like', '##ing']
220 | bert_map.append(src_i) # * [0, 0, 1, 2, 2]
221 |
222 | truncate_tok_len = len(bert_tokens)
223 | bert_adj = np.zeros((truncate_tok_len, truncate_tok_len), dtype='float32')
224 | for i in range(truncate_tok_len):
225 | for j in range(truncate_tok_len):
226 | bert_adj[i][j] = ori_adj[bert_map[i]][bert_map[j]]
227 | return bert_adj
228 |
229 |
230 | def opinion_lexicon():
231 | pos_file = 'lexicon/positive-words.txt'
232 | neg_file = 'lexicon/negative-words.txt'
233 | fin1 = open(pos_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
234 | fin2 = open(neg_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
235 | lines1 = fin1.readlines()
236 | lines2 = fin2.readlines()
237 | fin1.close()
238 | fin2.close()
239 | lexicon = {}
240 | for pos_word in lines1:
241 | lexicon[pos_word.strip()] = 'positive'
242 | for neg_word in lines2:
243 | lexicon[neg_word.strip()] = 'negative'
244 |
245 | return lexicon
246 |
247 |
248 | def process(filename):
249 | fin = open(filename, 'r', encoding='utf-8', newline='\n', errors='ignore')
250 | lines = fin.readlines()
251 | fin.close()
252 |
253 | idx2graph_dis, idx2graph_rel, idx2graph_lex = {}, {}, {}
254 | lexicon = opinion_lexicon()
255 |
256 | fout1 = open(filename + '_distance.pkl', 'wb')
257 | fout2 = open(filename + '_relation.pkl', 'wb')
258 | fout3 = open(filename + '_opinion.pkl', 'wb')
259 |
260 | for i in tqdm(range(0, len(lines), 3)):
261 | text = lines[i].strip()
262 | all_aspect = lines[i + 1].strip()
263 | aspects = []
264 | for aspect_index in all_aspect.split('\t'):
265 | aspect, start, end = aspect_index.split('#')
266 | aspects.append([aspect, int(start), int(end)])
267 |
268 | dis_adj, rel_adj, lex_adj = build_graph(text, aspects, lexicon)
269 | bert_dis_adj = map_bert_2D(dis_adj, text)
270 | bert_rel_adj = map_bert_2D(rel_adj, text)
271 | bert_lex_adj = map_bert_2D(lex_adj, text)
272 |
273 | idx2graph_dis[i] = bert_dis_adj
274 | idx2graph_rel[i] = bert_rel_adj
275 | idx2graph_lex[i] = bert_lex_adj
276 |
277 | pickle.dump(idx2graph_dis, fout1)
278 | pickle.dump(idx2graph_rel, fout2)
279 | pickle.dump(idx2graph_lex, fout3)
280 | fout1.close()
281 | fout2.close()
282 | fout3.close()
283 |
284 |
285 | if __name__ == '__main__':
286 | process('dataset/lap14_train')
287 | process('dataset/lap14_test')
288 | process('dataset/rest14_train')
289 | process('dataset/rest14_test')
290 | process('dataset/mams_train')
291 | process('dataset/mams_dev')
292 | process('dataset/mams_test')
293 |
--------------------------------------------------------------------------------
/layer/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
--------------------------------------------------------------------------------
/layer/rgcn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def normalize(mx):
8 | """Row-normalize sparse matrix"""
9 | rowsum = mx.sum(dim=2) # Compute row sums along the last dimension
10 | r_inv = rowsum.pow(-1)
11 | r_inv[torch.isinf(r_inv)] = 0.
12 | r_mat_inv = torch.diag_embed(r_inv) # Create a batch of diagonal matrices
13 | mx = torch.matmul(r_mat_inv, mx)
14 | return mx
15 |
16 |
17 | class RelationalGraphConvLayer(nn.Module):
18 | def __init__(self, num_rel, input_size, output_size, bias=True):
19 | super(RelationalGraphConvLayer, self).__init__()
20 | self.num_rel = num_rel
21 | self.input_size = input_size
22 | self.output_size = output_size
23 |
24 | self.weight = nn.Parameter(torch.FloatTensor(self.num_rel, self.input_size, self.output_size))
25 | if bias:
26 | self.bias = nn.Parameter(torch.FloatTensor(self.output_size))
27 | else:
28 | self.register_parameter("bias", None)
29 |
30 | def forward(self, text, adj):
31 | weights = self.weight.view(self.num_rel * self.input_size, self.output_size) # r*input_size, output_size
32 | supports = []
33 | for i in range(self.num_rel):
34 | hidden = torch.bmm(normalize(adj[:, i]), text)
35 | supports.append(hidden)
36 | tmp = torch.cat(supports, dim=-1)
37 | output = torch.matmul(tmp.float(), weights)  # (batch_size, seq_len, output_size)
38 | if self.bias is not None:
39 | return output + self.bias
40 | else:
41 | return output
42 |
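# Shapes: text is (batch, seq_len, input_size) and adj is (batch, num_rel, seq_len,
# seq_len). Each relation slice is row-normalized and propagated separately; the
# results are concatenated and projected in one matmul, which is equivalent to
# sum_r A_hat_r X W_r. The output has shape (batch, seq_len, output_size). Note that
# weight and bias are created uninitialized and are expected to be initialized by the
# caller (main.py does this in Instructor._reset_params).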
--------------------------------------------------------------------------------
/layer/supervisedcontrastiveloss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class SupervisedContrastiveLoss(nn.Module):
8 | def __init__(self, temperature=0.07):
9 | """
10 | Implementation of the loss described in the paper "Supervised Contrastive Learning":
11 | https://arxiv.org/abs/2004.11362
12 |
13 | :param temperature: float, temperature used to scale the dot products
14 | """
15 | super(SupervisedContrastiveLoss, self).__init__()
16 | self.temperature = temperature
17 |
18 | def forward(self, projections, targets, weight=None):
19 | """
20 |
21 | :param projections: torch.Tensor, shape [batch_size, projection_dim]
22 | :param targets: torch.Tensor, shape [batch_size]
23 | :return: torch.Tensor, scalar
24 | """
25 | device = torch.device("cuda") if projections.is_cuda else torch.device("cpu")
26 |
27 | dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
28 | # Minus max for numerical stability with exponential. Same done in cross entropy. Epsilon added to avoid log(0)
29 | exp_dot_tempered = (
30 | torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
31 | )
32 |
33 | mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(device)
34 | mask_anchor_out = (1 - torch.eye(exp_dot_tempered.shape[0])).to(device)
35 | mask_combined = mask_similar_class * mask_anchor_out # remove self
36 | cardinality_per_samples = torch.sum(mask_combined, dim=1) # num of positive examples
37 | if weight is not None:
38 | mask_combined = mask_combined * weight
39 | log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
40 | supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / (
41 | cardinality_per_samples + 1)
42 | supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)
43 |
44 | return supervised_contrastive_loss
45 |
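# For each anchor i with positives P(i) (other samples in the batch sharing its label),
# the implemented loss is
#   L_i = -1 / (|P(i)| + 1) * sum_{p in P(i)} log( exp(z_i . z_p / T) / sum_{a != i} exp(z_i . z_a / T) )
# averaged over the batch; the +1 guards against anchors without positives, and the
# optional weight argument reweights positive pairs. Projections are expected to be
# L2-normalized (main.py applies F.normalize before calling this loss).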
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import logging
4 | import argparse
5 | import math
6 | import os
7 | import sys
8 | import random
9 | import numpy as np
10 |
11 | from sklearn import metrics
12 | from time import strftime, localtime
13 |
14 | from transformers import BertModel
15 |
16 | import torch
17 | import torch.nn as nn
18 | from torch.utils.data import DataLoader, random_split
19 |
20 | from data_utils import Tokenizer4Bert, ABSADataset
21 | from model import YORO
22 | from layer.supervisedcontrastiveloss import SupervisedContrastiveLoss
23 |
24 | logger = logging.getLogger()
25 | logger.setLevel(logging.INFO)
26 | logger.addHandler(logging.StreamHandler(sys.stdout))
27 |
28 |
29 | class Instructor:
30 | def __init__(self, opt):
31 | self.opt = opt
32 | tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.pretrained_bert_name)
33 | bert = BertModel.from_pretrained(opt.pretrained_bert_name)
34 | self.model = opt.model_class(bert, opt).to(opt.device)
35 | self.trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
36 | self.testset = ABSADataset(opt.dataset_file['test'], tokenizer)
37 | if self.opt.dataset == 'mams':
38 | self.valset = ABSADataset(opt.dataset_file['dev'], tokenizer)
39 | else:
40 | assert 0 <= opt.valset_ratio < 1
41 | if opt.valset_ratio > 0:
42 | valset_len = int(len(self.trainset) * opt.valset_ratio)
43 | self.trainset, self.valset = random_split(self.trainset, (len(self.trainset) - valset_len, valset_len))
44 | else:
45 | self.valset = self.testset
46 |
47 | if opt.device.type == 'cuda':
48 | logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))
49 | self._print_args()
50 |
51 | def _print_args(self):
52 | n_trainable_params, n_nontrainable_params = 0, 0
53 | for p in self.model.parameters():
54 | n_params = torch.prod(torch.tensor(p.shape))
55 | if p.requires_grad:
56 | n_trainable_params += n_params
57 | else:
58 | n_nontrainable_params += n_params
59 | logger.info(
60 | '> n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
61 | logger.info('> training arguments:')
62 | for arg in vars(self.opt):
63 | logger.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
64 |
65 | def _reset_params(self):
66 | for child in self.model.children():
67 | if type(child) not in [BertModel, nn.Embedding]: # skip bert params and embedding
68 | for p in child.parameters():
69 | if p.requires_grad:
70 | if len(p.shape) > 1:
71 | self.opt.initializer(p)
72 | else:
73 | stdv = 1. / math.sqrt(p.shape[0])
74 | torch.nn.init.uniform_(p, a=-stdv, b=stdv)
75 |
76 | def _train(self, criterion, optimizer, train_data_loader, val_data_loader, test_data_loader):
77 | max_val_acc = 0
78 | max_val_f1 = 0
79 | max_val_epoch = 0
80 | global_step = 0
81 | path = None
82 | for i_epoch in range(self.opt.num_epoch):
83 | logger.info('>' * 100)
84 | logger.info('epoch: {}'.format(i_epoch))
85 | n_correct, n_total, loss_total = 0, 0, 0
86 | n_op_correct, n_op_total = 0, 0
87 | # switch model to training mode
88 | self.model.train()
89 | for i_batch, batch in enumerate(train_data_loader):
90 | global_step += 1
91 | # clear gradient accumulators
92 | optimizer.zero_grad()
93 |
94 | inputs = [batch[col].to(self.opt.device) for col in self.opt.inputs_cols]
95 | outputs, opinion_outputs = self.model(inputs)
96 |
97 | targets = batch['polarities'].to(self.opt.device)
98 | outputs = outputs.view(-1, self.opt.polarities_dim) # bz*128,3
99 | targets = targets.view(-1) # bz*128,1
100 | mask = targets != -1 # bz*128, 1 non-aspect False aspect True
101 | mask_outputs = outputs[mask]
102 | mask_targets = targets[mask]
103 | loss1 = criterion[0](mask_outputs, mask_targets)
104 |
105 | opinion_targets = batch['opinion_indices'].to(self.opt.device)
106 | opinion_outputs = opinion_outputs.view(-1, self.opt.polarities_dim)
107 | opinion_targets = opinion_targets.view(-1) # bz*128,1
108 | opinion_mask = opinion_targets != -1
109 | mask_opinion_outputs = opinion_outputs[opinion_mask]
110 | mask_opinion_targets = opinion_targets[opinion_mask]
111 | loss2 = criterion[0](mask_opinion_outputs, mask_opinion_targets)
112 |
113 | loss3 = criterion[1](nn.functional.normalize(mask_outputs, dim=1), mask_targets)
114 |
115 | loss = loss1 + loss2 + self.opt.alpha * loss3 # 0.5
116 | loss.backward()
117 | optimizer.step()
118 |
119 | n_correct += (torch.argmax(mask_outputs, -1) == mask_targets).sum().item()
120 | n_total += len(mask_outputs)
121 | n_op_correct += (torch.argmax(mask_opinion_outputs, -1) == mask_opinion_targets).sum().item()
122 | n_op_total += len(mask_opinion_outputs)
123 |
124 | loss_total += loss.item()
125 | if global_step % self.opt.log_step == 0:
126 | train_acc = n_correct / n_total
127 | train_loss = loss_total / n_total
128 | train_op_acc = n_op_correct / n_op_total
129 | logger.info('loss: {:.4f}, acc: {:.4f}, op_acc: {:.4f}, '
130 | 'loss1: {:.4f}, loss2: {:.4f}, loss3: {:.4f}'.format(train_loss, train_acc,
131 | train_op_acc, loss1, loss2, loss3))
132 |
133 | val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)
134 | logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
135 |
136 | if val_acc > max_val_acc: # acc improve
137 | max_val_acc = val_acc
138 | max_val_f1 = val_f1
139 | max_val_epoch = i_epoch
140 |
141 | if not os.path.exists(
142 | '{}/{}/{}'.format(self.opt.save_model_dir, self.opt.model_name, self.opt.dataset)):
143 | os.makedirs('{}/{}/{}'.format(self.opt.save_model_dir, self.opt.model_name, self.opt.dataset))
144 | path = '{0}/{1}/{2}/acc_{3}_f1_{4}_{5}.model'.format(self.opt.save_model_dir, self.opt.model_name,
145 | self.opt.dataset,
146 | round(val_acc, 4), round(val_f1, 4),
147 | strftime("%y%m%d-%H%M", localtime()))
148 | torch.save(self.model.state_dict(), path)
149 | logger.info('>> saved: {}'.format(path))
150 | if val_f1 > max_val_f1:
151 | max_val_f1 = val_f1
152 | if i_epoch - max_val_epoch >= self.opt.patience:
153 | logger.info('>> early stop.')
154 | break
155 |
156 | return path
157 |
158 | def _evaluate_acc_f1(self, data_loader):
159 | n_correct, n_total = 0, 0
160 | t_targets_all, t_outputs_all = None, None
161 | # switch model to evaluation mode
162 | self.model.eval()
163 | with torch.no_grad():
164 | for i_batch, t_batch in enumerate(data_loader):
165 | t_inputs = [t_batch[col].to(self.opt.device) for col in self.opt.inputs_cols]
166 | t_targets = t_batch['polarities'].to(self.opt.device)
167 | t_outputs, t_opinion_outputs = self.model(t_inputs)
168 |
169 | t_targets = t_targets.view(-1)
170 | t_outputs = t_outputs.view(-1, self.opt.polarities_dim)
171 | t_mask = t_targets.view(-1) != -1
172 | t_mask_outputs = t_outputs[t_mask]
173 | t_mask_targets = t_targets[t_mask]
174 |
175 | n_correct += (torch.argmax(t_mask_outputs, -1) == t_mask_targets).sum().item()
176 | n_total += len(t_mask_outputs)
177 |
178 | if t_targets_all is None:
179 | t_targets_all = t_mask_targets
180 | t_outputs_all = t_mask_outputs
181 | else:
182 | t_targets_all = torch.cat((t_targets_all, t_mask_targets), dim=0)
183 | t_outputs_all = torch.cat((t_outputs_all, t_mask_outputs), dim=0)
184 |
185 | acc = n_correct / n_total
186 | f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2],
187 | average='macro')
188 | return acc, f1
189 |
190 | def run(self):
191 | acc_list, f1_list = [], []
192 | # Loss and Optimizer
193 | criterion = [nn.CrossEntropyLoss(), SupervisedContrastiveLoss()]
194 | _params = filter(lambda p: p.requires_grad, self.model.parameters())
195 | optimizer = self.opt.optimizer(_params, lr=self.opt.lr, weight_decay=self.opt.l2reg)
196 |
197 | train_data_loader = DataLoader(dataset=self.trainset, batch_size=self.opt.batch_size, shuffle=True)
198 | test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
199 | val_data_loader = DataLoader(dataset=self.valset, batch_size=self.opt.batch_size, shuffle=False)
200 |
201 | for i in range(self.opt.repeat):
202 | self._reset_params()
203 | best_model_path = self._train(criterion, optimizer, train_data_loader, val_data_loader, test_data_loader)
204 | self.model.load_state_dict(torch.load(best_model_path))
205 | test_acc, test_f1 = self._evaluate_acc_f1(test_data_loader)
206 | logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))
207 | acc_list.append(test_acc)
208 | f1_list.append(test_f1)
209 | all_acc = np.asarray(acc_list)
210 | avg_acc = np.average(all_acc)
211 | all_f1 = np.asarray(f1_list)
212 | avg_f1 = np.average(all_f1)
213 | for acc, f1 in zip(acc_list, f1_list):
214 | logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(acc, f1))
215 | logger.info('>> avg_test_acc: {:.4f}, avg_test_f1: {:.4f}'.format(avg_acc, avg_f1))
216 |
217 |
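# Training objective used in Instructor._train: loss1 is cross-entropy over the
# sentiment logits at aspect positions, loss2 is cross-entropy over the
# opinion-lexicon logits at non-aspect positions, and loss3 is the supervised
# contrastive loss over the L2-normalized aspect logits; they are combined as
# loss1 + loss2 + alpha * loss3 (alpha defaults to 0.5 via --alpha).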
218 | def main():
219 | # Hyper Parameters
220 | parser = argparse.ArgumentParser()
221 | parser.add_argument('--model_name', default='YORO', type=str)
222 | parser.add_argument('--dataset', default='rest14', type=str, help='mams, rest14, lap14')
223 | parser.add_argument('--optimizer', default='adam', type=str)
224 | parser.add_argument('--initializer', default='xavier_uniform_', type=str)
225 | parser.add_argument('--repeat', default=1, type=int)
226 | parser.add_argument('--lr', default=2e-5, type=float, help='try 5e-5, 2e-5 for BERT, 1e-3 for others')
227 | parser.add_argument('--dropout', default=0.3, type=float)
228 | parser.add_argument('--l2reg', default=1e-4, type=float)
229 | parser.add_argument('--num_epoch', default=20, type=int, help='try larger number for non-BERT models')
230 | parser.add_argument('--batch_size', default=16, type=int, help='try 16, 32, 64 for BERT models')
231 | parser.add_argument('--log_step', default=10, type=int)
232 | parser.add_argument('--bert_dim', default=768, type=int)
233 | parser.add_argument('--hidden_dim', default=768, type=int)
234 | parser.add_argument('--pretrained_bert_name', default='bert-base-uncased', type=str)
235 | parser.add_argument('--max_seq_len', default=128, type=int)
236 | parser.add_argument('--polarities_dim', default=3, type=int)
237 | parser.add_argument('--alpha', default=0.5, type=float)
238 | parser.add_argument('--patience', default=5, type=int)
239 | parser.add_argument('--device', default=None, type=str, help='e.g. cuda:0')
240 | parser.add_argument('--seed', default=1, type=int, help='set seed for reproducibility')
241 | parser.add_argument('--valset_ratio', default=0, type=float, help='set a ratio between 0 and 1 to hold out part of the training set for validation')
242 | parser.add_argument('--save_model_dir', default='/Your_Path', type=str)
243 |
244 | opt = parser.parse_args()
245 |
246 | if opt.seed is not None:
247 | random.seed(opt.seed)
248 | np.random.seed(opt.seed)
249 | torch.manual_seed(opt.seed)
250 | torch.cuda.manual_seed(opt.seed)
251 | torch.backends.cudnn.deterministic = True
252 | torch.backends.cudnn.benchmark = False
253 | os.environ['PYTHONHASHSEED'] = str(opt.seed)
254 |
255 | model_classes = {
256 | 'YORO': YORO,
257 | }
258 | input_colses = {
259 | 'YORO': ['input_ids', 'token_type_ids', 'attention_mask', 'distance_adj', 'relation_adj'],
260 | }
261 | dataset_files = {
262 | 'lap14': {
263 | 'train': './dataset/lap14_train',
264 | 'test': './dataset/lap14_test'
265 | },
266 | 'rest14': {
267 | 'train': './dataset/rest14_train',
268 | 'test': './dataset/rest14_test'
269 | },
270 | 'mams': {
271 | 'train': './dataset/mams_train',
272 | 'dev': './dataset/mams_dev',
273 | 'test': './dataset/mams_test'
274 | }
275 | }
276 | initializers = {
277 | 'xavier_uniform_': torch.nn.init.xavier_uniform_,
278 | 'xavier_normal_': torch.nn.init.xavier_normal_,
279 | 'orthogonal_': torch.nn.init.orthogonal_,
280 | }
281 | optimizers = {
282 | 'adadelta': torch.optim.Adadelta, # default lr=1.0
283 | 'adagrad': torch.optim.Adagrad, # default lr=0.01
284 | 'adam': torch.optim.Adam, # default lr=0.001
285 | 'adamax': torch.optim.Adamax, # default lr=0.002
286 | 'asgd': torch.optim.ASGD, # default lr=0.01
287 | 'rmsprop': torch.optim.RMSprop, # default lr=0.01
288 | 'sgd': torch.optim.SGD,
289 | 'adamw': torch.optim.AdamW,
290 | }
291 | opt.model_class = model_classes[opt.model_name]
292 | opt.dataset_file = dataset_files[opt.dataset]
293 | opt.inputs_cols = input_colses[opt.model_name]
294 | opt.initializer = initializers[opt.initializer]
295 | opt.optimizer = optimizers[opt.optimizer]
296 | opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
297 | if opt.device is None else torch.device(opt.device)
298 |
299 | if not os.path.exists('log/{}'.format(opt.model_name)):
300 | os.makedirs('log/{}'.format(opt.model_name))
301 | log_file = 'log/{}/{}-{}.log'.format(opt.model_name, opt.dataset, strftime("%y%m%d-%H%M", localtime()))
302 | logger.addHandler(logging.FileHandler(log_file))
303 |
304 | ins = Instructor(opt)
305 | ins.run()
306 |
307 |
308 | if __name__ == '__main__':
309 | main()
310 |
--------------------------------------------------------------------------------
/model/YORO.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from layer.rgcn import RelationalGraphConvLayer
6 |
7 |
8 | class YORO(nn.Module):
9 | def __init__(self, bert, args):
10 | super(YORO, self).__init__()
11 | self.bert = bert
12 | self.rgc1 = RelationalGraphConvLayer(5, args.bert_dim, args.bert_dim)
13 | self.rgc2 = RelationalGraphConvLayer(5, args.bert_dim, args.bert_dim)
14 | self.dropout = nn.Dropout(args.dropout)
15 | self.op_dense = nn.Linear(args.bert_dim, args.polarities_dim)
16 | self.dense = nn.Linear(args.bert_dim, args.polarities_dim)
17 |
18 | def forward(self, inputs):
19 | input_ids, token_type_ids, attention_mask, distance_adj, relation_adj = inputs
20 | output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
21 | hidden = output.last_hidden_state
22 |
23 | adj = distance_adj.unsqueeze(1).expand(-1, 5, -1, -1) * relation_adj
24 | x = F.relu(self.rgc1(hidden, adj))
25 | x = self.dropout(x)
26 | x = F.relu(self.rgc2(x, adj))
27 |
28 | hidden_output = self.dropout(x)
29 | op_logits = self.op_dense(hidden_output)
30 | logits = self.dense(hidden_output)
31 | return logits, op_logits
32 |
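# Forward-pass shapes: input_ids / token_type_ids / attention_mask are (batch, seq_len),
# distance_adj is (batch, seq_len, seq_len) and relation_adj is (batch, 5, seq_len,
# seq_len). The distance matrix is broadcast over the five relation channels to weight
# each relation-specific adjacency, two R-GCN layers refine the BERT hidden states,
# and the two linear heads return sentiment logits and opinion logits, each of shape
# (batch, seq_len, polarities_dim).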
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from model.YORO import YORO
4 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python main.py --model_name YORO --dataset mams
2 | python main.py --model_name YORO --dataset rest14
3 | python main.py --model_name YORO --dataset lap14
--------------------------------------------------------------------------------