├── README.md
├── assets
│   ├── graph.png
│   └── model.png
├── data_utils.py
├── graph.py
├── layer
│   ├── __init__.py
│   ├── rgcn.py
│   └── supervisedcontrastiveloss.py
├── main.py
├── model
│   ├── YORO.py
│   └── __init__.py
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # **YORO**
2 | Code for the paper "You Only Read Once: Constituency-Oriented Relational Graph Convolutional Network for Multi-Aspect Multi-Sentiment Classification"
3 |
4 | *AAAI 2024*
5 |
6 | Yongqiang Zheng, Xia Li
7 |
8 |
9 | ## **Model**
10 | ![YORO model architecture](assets/model.png)
11 |
12 | ## **Requirements**
13 | - python==3.10.12
14 | - torch==1.12.1+cu113
15 | - transformers==4.30.2
16 | - scikit-learn==1.2.2
17 | - benepar==0.2.0
18 |
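`graph.py` additionally loads the `en_core_web_md` spaCy pipeline and the `benepar_en3` constituency parser. A minimal sketch for fetching both once the packages above are installed (a one-off download step):

```
import benepar
import spacy.cli

# Download the spaCy pipeline and the Berkeley parser model used by graph.py.
spacy.cli.download("en_core_web_md")
benepar.download("benepar_en3")
```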
19 | ## **Datasets**
20 | Download datasets from these links and put them in the **dataset** folder:
21 | - [MAMS](https://github.com/siat-nlp/MAMS-for-ABSA)
22 | - [Rest14](https://alt.qcri.org/semeval2014/task4)
23 | - [Lap14](https://alt.qcri.org/semeval2014/task4)
24 |
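Both `graph.py` and `data_utils.py` read each split as a plain-text file (e.g. `dataset/rest14_train`) in which every example spans three lines: the sentence, the tab-separated aspect spans in `aspect#start#end` format (word-level indices over the whitespace-tokenized sentence), and the tab-separated polarity labels (-1 negative, 0 neutral, 1 positive). A hypothetical example:

```
the food was great but the service was slow
food#1#2	service#6#7
1	-1
```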
25 | ## **Usage**
26 | 1. Download Bing Liu's opinion lexicon:
27 | ```
28 | wget http://www.cs.uic.edu/~liub/FBS/opinion-lexicon-English.rar
29 | sudo apt-get install unrar
30 | unrar x opinion-lexicon-English.rar
31 | mv opinion-lexicon-English lexicon
32 | ```
33 | 2. Generate the constituency-oriented graphs:
34 | ```
35 | python graph.py
36 | ```
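
Running `graph.py` writes three pickle files next to each split (`*_distance.pkl`, `*_relation.pkl`, `*_opinion.pkl`), keyed by the line index of each example; `data_utils.py` loads them at training time. A minimal inspection sketch (the file name assumes the Rest14 training split has already been processed):

```
import pickle

# Keys are the line indices of the examples (0, 3, 6, ...); values are
# word-piece-level adjacency matrices.
with open('dataset/rest14_train_relation.pkl', 'rb') as f:
    relation_graphs = pickle.load(f)

first = sorted(relation_graphs)[0]
print(first, relation_graphs[first].shape)  # (num_bert_tokens, num_bert_tokens)
```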
37 |
38 | An example of the graph construction in the Constituency-Oriented Relational Graph Convolutional Network (CorrGCN):
39 | ![An example of the constituency-oriented graph construction](assets/graph.png)
40 |
41 | ## **Training**
42 | ```
43 | bash train.sh
44 | ```
45 |
46 | ## **Credits**
47 | The code in this repository is based on [SEGCN-ABSA](https://github.com/gdufsnlp/SEGCN-ABSA).
48 |
49 | ## **Citation**
50 | ```bibtex
51 | @inproceedings{zheng2024you,
52 | title = {You Only Read Once: Constituency-Oriented Relational Graph Convolutional Network for Multi-Aspect Multi-Sentiment Classification},
53 | author = {Zheng, Yongqiang and Li, Xia},
54 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
55 | volume = {38},
56 | number = {17},
57 | pages = {19715--19723},
58 | year = {2024},
59 | url = {https://ojs.aaai.org/index.php/AAAI/article/view/29945},
60 | doi = {10.1609/aaai.v38i17.29945},
61 | }
62 | ```
63 |
--------------------------------------------------------------------------------
/assets/graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gdufsnlp/YORO/0c1d045ed21858731522a87349d5b59ec875502e/assets/graph.png
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gdufsnlp/YORO/0c1d045ed21858731522a87349d5b59ec875502e/assets/model.png
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import pickle
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from transformers import BertTokenizer
7 |
8 |
9 | def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
10 | x = (np.ones(maxlen) * value).astype(dtype)
11 | if truncating == 'pre':
12 | trunc = sequence[-maxlen:]
13 | else:
14 | trunc = sequence[:maxlen]
15 | trunc = np.asarray(trunc, dtype=dtype)
16 | if padding == 'post':
17 | x[:len(trunc)] = trunc
18 | else:
19 | x[-len(trunc):] = trunc
20 | return x
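# Illustrative behaviour of pad_and_truncate (made-up values):
#   pad_and_truncate([3, 5, 7], maxlen=5)                   -> array([3, 5, 7, 0, 0])
#   pad_and_truncate([3, 5, 7], maxlen=2, truncating='pre') -> array([5, 7])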
21 |
22 |
23 | def opinion_lexicon():
24 | pos_file = 'lexicon/positive-words.txt'
25 | neg_file = 'lexicon/negative-words.txt'
26 | lexicon = {}
27 | fin1 = open(pos_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
28 | fin2 = open(neg_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
29 | lines1 = fin1.readlines()
30 | lines2 = fin2.readlines()
31 | fin1.close()
32 | fin2.close()
33 | for pos_word in lines1:
34 | lexicon[pos_word.strip()] = 'positive'
35 | for neg_word in lines2:
36 | lexicon[neg_word.strip()] = 'negative'
37 | return lexicon
38 |
39 |
40 | class Tokenizer4Bert:
41 | def __init__(self, max_seq_len, pretrained_bert_name):
42 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
43 | self.max_seq_len = max_seq_len
44 |
45 | def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
46 | sequence = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
47 | if len(sequence) == 0:
48 | sequence = [0]
49 | if reverse:
50 | sequence = sequence[::-1]
51 | return pad_and_truncate(sequence, self.max_seq_len, padding=padding, truncating=truncating)
52 |
53 | def opinion_in_text(self, text, aspects, lexicon):
54 | aspect_index = []
55 | for asp_idx in aspects:
56 | _, start, end = asp_idx
57 | aspect_index.extend(list(range(start + 1, end + 1)))
58 |
59 | opinion_index = []
60 | for idx, word in enumerate(text.split()):
61 | for t in self.tokenizer.tokenize(word):
62 | if idx in aspect_index:
63 | opinion_index.append(-1) # skip aspect words
64 | elif word in lexicon.keys():
65 | if lexicon[word] == 'negative':
66 | opinion_index.append(0)
67 | elif lexicon[word] == 'positive':
68 | opinion_index.append(2)
69 | else:
70 | opinion_index.append(1)
71 | assert len(opinion_index) == len(self.tokenizer.tokenize(text))
72 | return pad_and_truncate(opinion_index, self.max_seq_len, value=-1)
73 |
74 | def map_bert_1D(self, text):
75 | words = text.split()
76 | # bert_tokens = []
77 | bert_map = []
78 | for src_i, word in enumerate(words):
79 | for subword in self.tokenizer.tokenize(word):
80 | # bert_tokens.append(subword) # * ['expand', '##able', 'highly', 'like', '##ing']
81 | bert_map.append(src_i) # * [0, 0, 1, 2, 2]
82 |
83 | return bert_map
84 |
85 |
86 | class ABSADataset(Dataset):
87 | def __init__(self, file, tokenizer):
88 | self.file = file
89 | self.tokenizer = tokenizer
90 | self.load_data()
91 |
92 | def load_data(self):
93 | fin = open(self.file, 'r', encoding='utf-8', newline='\n', errors='ignore')
94 | lines = fin.readlines()
95 | fin.close()
96 | fin = open(self.file + '_relation.pkl', 'rb')
97 | rel_matrix = pickle.load(fin)
98 | fin.close()
99 | fin = open(self.file + '_opinion.pkl', 'rb')
100 | lex_matrix = pickle.load(fin)
101 | fin.close()
102 | fin = open(self.file + '_distance.pkl', 'rb')
103 | dis_matrix = pickle.load(fin)
104 | fin.close()
105 |
106 | lexicon = opinion_lexicon()
107 | all_data = []
108 | for i in range(0, len(lines), 3):
109 | text = lines[i].lower().strip()
110 | all_aspect = lines[i + 1].lower().strip()
111 | all_polarity = lines[i + 2].strip()
112 | aspects = []
113 | for aspect_idx in all_aspect.split('\t'):
114 | aspect, start, end = aspect_idx.split('#')
115 | aspects.append([aspect, int(start), int(end)])
116 | labels = []
117 | for label in all_polarity.split('\t'):
118 | labels.append(int(label) + 1)
119 |
120 | text_len = len(self.tokenizer.tokenizer.tokenize(text))
121 | input_ids = self.tokenizer.text_to_sequence('[CLS] ' + text + ' [SEP]')
122 | token_type_ids = [0] * (text_len + 2)
123 | attention_mask = [1] * len(token_type_ids)
124 | token_type_ids = pad_and_truncate(token_type_ids, self.tokenizer.max_seq_len)
125 | attention_mask = pad_and_truncate(attention_mask, self.tokenizer.max_seq_len)
126 | opinion_indices = self.tokenizer.opinion_in_text('[CLS] ' + text + ' [SEP]', aspects, lexicon)
127 |
128 | distance_adj = np.zeros((self.tokenizer.max_seq_len, self.tokenizer.max_seq_len)).astype('float32')
129 | distance_adj[1:text_len + 1, 1:text_len + 1] = dis_matrix[i]
130 | relation_adj = np.zeros((5, self.tokenizer.max_seq_len, self.tokenizer.max_seq_len)).astype('float32')
131 | for j in range(0, 4):
132 | r_tmp = np.where(rel_matrix[i] == j + 1, 1, 0)
133 | relation_adj[j, 1:text_len + 1, 1:text_len + 1] = r_tmp
134 | for k in range(4, 5):
135 | l_tmp = np.where(lex_matrix[i] == k + 1, 1, 0)
136 | relation_adj[k, 1:text_len + 1, 1:text_len + 1] = l_tmp
137 | polarities = [-1] * self.tokenizer.max_seq_len
138 |
139 | bert_index = self.tokenizer.map_bert_1D(text)
140 | for asp_idx, pol in zip(aspects, labels):
141 | _, start, end = asp_idx
142 | # label the first token of aspect
143 | polarities[bert_index.index(start) + 1] = pol # +1 for cls
144 |
145 | polarities = np.asarray(polarities)
146 | data = {
147 | 'input_ids': input_ids,
148 | 'token_type_ids': token_type_ids,
149 | 'attention_mask': attention_mask,
150 | 'distance_adj': distance_adj,
151 | 'relation_adj': relation_adj,
152 | 'polarities': polarities,
153 | 'opinion_indices': opinion_indices,
154 | }
155 |
156 | all_data.append(data)
157 | self.data = all_data
158 |
159 | def __getitem__(self, index):
160 | return self.data[index]
161 |
162 | def __len__(self):
163 | return len(self.data)
164 |
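# Each item yielded by ABSADataset is a dict of fixed-length (max_seq_len) arrays:
# 'input_ids' / 'token_type_ids' / 'attention_mask' for BERT, 'distance_adj' of shape
# (max_seq_len, max_seq_len), 'relation_adj' of shape (5, max_seq_len, max_seq_len),
# 'polarities' (sentiment label on the first sub-token of each aspect, -1 elsewhere)
# and 'opinion_indices' (lexicon labels on non-aspect tokens, -1 for aspects/padding).
# Illustrative usage, mirroring main.py:
#   tokenizer = Tokenizer4Bert(128, 'bert-base-uncased')
#   loader = DataLoader(ABSADataset('dataset/rest14_train', tokenizer), batch_size=16)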
--------------------------------------------------------------------------------
/graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import benepar
4 | import numpy as np
5 | import pickle
6 | import spacy
7 |
8 | from spacy.tokens import Doc
9 | from tqdm import tqdm
10 | from transformers import BertTokenizer
11 |
12 |
13 | class WhitespaceTokenizer(object):
14 | def __init__(self, vocab):
15 | self.vocab = vocab
16 |
17 | def __call__(self, text):
18 | words = text.split()
19 | # All tokens 'own' a subsequent space character in this tokenizer
20 | spaces = [True] * len(words)
21 | return Doc(self.vocab, words=words, spaces=spaces)
22 |
23 |
24 | # spaCy + Berkeley
25 | nlp = spacy.load('en_core_web_md')
26 | nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
27 | nlp.add_pipe("benepar", config={"model": "benepar_en3"})
28 | # BERT
29 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
30 |
31 |
32 | def get_unique_elements(lists, aspects):
33 | unique_lists = []
34 | for i, lst in enumerate(lists):
35 | other_lists = lists[:i] + lists[i + 1:]
36 | unique = set(lst) - set.union(*map(set, other_lists))
37 | if len(unique) > 0:
38 | cons = max(list(unique), key=lambda x: len(x))
39 | unique_lists.append([cons.start, cons.end])
40 | else:
41 | start = aspects[i][1]
42 | end = aspects[i][2]
43 | unique_lists.append([start, end])
44 | return unique_lists
45 |
46 |
47 | def single_aspect(text, aspects):
48 | # https://spacy.io/docs/usage/processing-text
49 | tokens = nlp(text)
50 | words = text.split()
51 | assert len(words) == len(list(tokens))
52 |
53 | token = aspects[0]
54 | asp, start, end = token[0], token[1], token[2]
55 | aspect_specific = []
56 | all_cons = []
57 | for sent in tokens.sents:
58 | for cons in sent._.constituents:
59 | if cons.text == text:
60 | continue
61 | all_cons.append(cons)
62 | if cons.start <= start <= end <= cons.end:
63 | if len(cons._.labels) > 0: # len(cons) > 1:
64 | aspect_specific.append(cons)
65 | aspect_specific_cons = []
66 | aspect_tag = ''
67 | for cons in aspect_specific:
68 | if cons._.labels[0] != 'S':
69 | aspect_specific_cons.append([cons.start, cons.end])
70 | aspect_tag = cons._.labels[0]
71 | break
72 |
73 | for cons in all_cons:
74 | if len(cons._.labels) > 0 and cons._.labels[0] == aspect_tag: # len(cons) != 1
75 | flag = True
76 | for asp_cons in aspect_specific_cons:
77 | if cons.end <= asp_cons[0] or cons.start >= asp_cons[1]:
78 | continue
79 | else:
80 | flag = False
81 | if flag:
82 | aspect_specific_cons.append([cons.start, cons.end])
83 | if aspect_specific_cons == []:
84 | aspect_specific_cons.append([0, len(words)])
85 | return aspect_specific_cons
86 |
87 |
88 | def distance_matrix(text):
89 | # https://spacy.io/docs/usage/processing-text
90 | tokens = nlp(text)
91 | words = text.split()
92 | matrix = np.zeros((len(words), len(words))).astype('float32')
93 | assert len(words) == len(list(tokens))
94 |
95 | for sent in tokens.sents:
96 | for cons in sent._.constituents:
97 | if len(cons) == 1:
98 | continue
99 | matrix[cons.start:cons.end, cons.start:cons.end] += np.ones([len(cons), len(cons)])
100 |
101 | hops_matrix = np.amax(matrix, axis=1, keepdims=True) - matrix # hops
102 | dis_matrix = 2 - hops_matrix / (np.amax(hops_matrix, axis=1, keepdims=True) + 1)
103 |
104 | return dis_matrix
105 |
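# In distance_matrix, each raw count is the number of constituents containing both
# tokens, so hops_matrix measures how far apart two tokens sit in the constituency
# tree. The returned dis_matrix lies in (1, 2], with 2 on the diagonal and values
# approaching 1 for token pairs that only share large constituents.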
106 |
107 | def relation_matrix(text, aspects):
108 | # https://spacy.io/docs/usage/processing-text
109 | tokens = nlp(text)
110 | words = text.split()
111 | matrix = np.eye(len(words)).astype('float32')
112 | if len(words) != len(list(tokens)):
113 | print(words)
114 | print(list(tokens))
115 | assert len(words) == len(list(tokens))
116 |
117 | all_start = [aspect[1] for aspect in aspects]
118 | relations = [False] * len(tokens)
119 |
120 | if len(aspects) > 1:
121 | # intra-aspect
122 | # aspect-related collection
123 | aspect_nodes = [[] for _ in range(len(aspects))]
124 | for sent in tokens.sents:
125 | for cons in sent._.constituents:
126 | for idx, token in enumerate(aspects):
127 | asp, start, end = token[0], token[1], token[2]
128 | if cons.start <= start and end <= cons.end:
129 | aspect_nodes[idx].append(cons)
130 | # aspect-specific
131 | aspect_specific_cons = get_unique_elements(aspect_nodes, aspects)
132 | for idx, cons in enumerate(aspect_specific_cons):
133 | for i in range(cons[0], cons[1]):
134 | matrix[all_start[idx]][i] = 2
135 | matrix[i][all_start[idx]] = 2
136 | relations[i] = True
137 | # globally-shared
138 | for i in range(len(relations)):
139 | if not relations[i]:
140 | for j in all_start:
141 | matrix[i][j] = 3
142 | matrix[j][i] = 3
143 | # inter-aspect
144 | for i in range(len(all_start)):
145 | for j in range(i + 1, len(all_start)):
146 | matrix[all_start[i]][all_start[j]] = 4
147 | matrix[all_start[j]][all_start[i]] = 4
148 | else:
149 | # pseudo aspect
150 | # intra-aspect
151 | # aspect-related collection
152 | aspect_specific_cons = single_aspect(text, aspects)
153 | all_start += [aspect[0] for aspect in aspect_specific_cons[1:]]
154 |
155 | # aspect-specific
156 | for idx, cons in enumerate(aspect_specific_cons):
157 | for i in range(cons[0], cons[1]):
158 | matrix[all_start[idx]][i] = 2
159 | matrix[i][all_start[idx]] = 2
160 | relations[i] = True
161 | # globally-shared
162 | for i in range(len(relations)):
163 | if not relations[i]:
164 | for j in all_start:
165 | matrix[i][j] = 3
166 | matrix[j][i] = 3
167 |
168 | # inter-aspect
169 | for i in range(len(all_start)):
170 | for j in range(i + 1, len(all_start)):
171 | matrix[all_start[i]][all_start[j]] = 4
172 | matrix[all_start[j]][all_start[i]] = 4
173 |
174 | return matrix
175 |
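# Relation ids written by relation_matrix: 1 = self-loop (diagonal), 2 = aspect-specific
# (token inside the constituent assigned to an aspect), 3 = globally-shared (token not
# claimed by any aspect, linked to every aspect head word), 4 = inter-aspect (edges
# between aspect head words). Opinion-lexicon edges are added as id 5 by lexicon_matrix;
# data_utils.py later expands ids 1-5 into five adjacency channels.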
176 |
177 | def lexicon_matrix(text, aspects, lexicon):
178 | # https://spacy.io/docs/usage/processing-text
179 | tokens = nlp(text)
180 | words = text.lower().split()
181 | assert len(words) == len(list(tokens))
182 |
183 | aspects_index = []
184 | for aspect in aspects:
185 | start = aspect[1]
186 | end = aspect[2]
187 | aspects_index.extend(list(range(start, end)))
188 | labels = []
189 | for i in range(len(tokens)):
190 | if words[i] not in lexicon.keys() or i in aspects_index:
191 | labels.append(0)
192 | else:
193 | labels.append(5)
194 | lex_matrix = np.tile(np.array(labels), (len(tokens), 1))
195 | return lex_matrix
196 |
197 |
198 | def build_graph(text, aspects, lexicon):
199 | rel = relation_matrix(text, aspects)
200 | np.fill_diagonal(rel, 1)
201 | mask = (np.zeros_like(rel) != rel).astype('float32')
202 |
203 | lex = lexicon_matrix(text, aspects, lexicon)
204 | lex = lex * mask
205 |
206 | dis = distance_matrix(text)
207 | np.fill_diagonal(dis, 1)
208 | dis = dis * mask
209 |
210 | return dis, rel, lex
211 |
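# build_graph returns the three word-level matrices used downstream: rel carries the
# relation ids, while lex and dis are zeroed wherever rel has no relation, so the
# distance weights and lexicon edges only survive on connected positions.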
212 |
213 | def map_bert_2D(ori_adj, text):
214 | words = text.split()
215 | bert_tokens = []
216 | bert_map = []
217 | for src_i, word in enumerate(words):
218 | for subword in tokenizer.tokenize(word):
219 | bert_tokens.append(subword) # * ['expand', '##able', 'highly', 'like', '##ing']
220 | bert_map.append(src_i) # * [0, 0, 1, 2, 2]
221 |
222 | truncate_tok_len = len(bert_tokens)
223 | bert_adj = np.zeros((truncate_tok_len, truncate_tok_len), dtype='float32')
224 | for i in range(truncate_tok_len):
225 | for j in range(truncate_tok_len):
226 | bert_adj[i][j] = ori_adj[bert_map[i]][bert_map[j]]
227 | return bert_adj
228 |
229 |
230 | def opinion_lexicon():
231 | pos_file = 'lexicon/positive-words.txt'
232 | neg_file = 'lexicon/negative-words.txt'
233 | fin1 = open(pos_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
234 | fin2 = open(neg_file, 'r', encoding='utf-8', newline='\n', errors='ignore')
235 | lines1 = fin1.readlines()
236 | lines2 = fin2.readlines()
237 | fin1.close()
238 | fin2.close()
239 | lexicon = {}
240 | for pos_word in lines1:
241 | lexicon[pos_word.strip()] = 'positive'
242 | for neg_word in lines2:
243 | lexicon[neg_word.strip()] = 'negative'
244 |
245 | return lexicon
246 |
247 |
248 | def process(filename):
249 | fin = open(filename, 'r', encoding='utf-8', newline='\n', errors='ignore')
250 | lines = fin.readlines()
251 | fin.close()
252 |
253 | idx2graph_dis, idx2graph_rel, idx2graph_lex = {}, {}, {}
254 | lexicon = opinion_lexicon()
255 |
256 | fout1 = open(filename + '_distance.pkl', 'wb')
257 | fout2 = open(filename + '_relation.pkl', 'wb')
258 | fout3 = open(filename + '_opinion.pkl', 'wb')
259 |
260 | for i in tqdm(range(0, len(lines), 3)):
261 | text = lines[i].strip()
262 | all_aspect = lines[i + 1].strip()
263 | aspects = []
264 | for aspect_index in all_aspect.split('\t'):
265 | aspect, start, end = aspect_index.split('#')
266 | aspects.append([aspect, int(start), int(end)])
267 |
268 | dis_adj, rel_adj, lex_adj = build_graph(text, aspects, lexicon)
269 | bert_dis_adj = map_bert_2D(dis_adj, text)
270 | bert_rel_adj = map_bert_2D(rel_adj, text)
271 | bert_lex_adj = map_bert_2D(lex_adj, text)
272 |
273 | idx2graph_dis[i] = bert_dis_adj
274 | idx2graph_rel[i] = bert_rel_adj
275 | idx2graph_lex[i] = bert_lex_adj
276 |
277 | pickle.dump(idx2graph_dis, fout1)
278 | pickle.dump(idx2graph_rel, fout2)
279 | pickle.dump(idx2graph_lex, fout3)
280 | fout1.close()
281 | fout2.close()
282 | fout3.close()
283 |
284 |
285 | if __name__ == '__main__':
286 | process('dataset/lap14_train')
287 | process('dataset/lap14_test')
288 | process('dataset/rest14_train')
289 | process('dataset/rest14_test')
290 | process('dataset/mams_train')
291 | process('dataset/mams_dev')
292 | process('dataset/mams_test')
293 |
--------------------------------------------------------------------------------
/layer/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
--------------------------------------------------------------------------------
/layer/rgcn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def normalize(mx):
8 | """Row-normalize sparse matrix"""
9 | rowsum = mx.sum(dim=2) # Compute row sums along the last dimension
10 | r_inv = rowsum.pow(-1)
11 | r_inv[torch.isinf(r_inv)] = 0.
12 | r_mat_inv = torch.diag_embed(r_inv) # Create a batch of diagonal matrices
13 | mx = torch.matmul(r_mat_inv, mx)
14 | return mx
15 |
16 |
17 | class RelationalGraphConvLayer(nn.Module):
18 | def __init__(self, num_rel, input_size, output_size, bias=True):
19 | super(RelationalGraphConvLayer, self).__init__()
20 | self.num_rel = num_rel
21 | self.input_size = input_size
22 | self.output_size = output_size
23 |
24 | self.weight = nn.Parameter(torch.FloatTensor(self.num_rel, self.input_size, self.output_size))
25 | if bias:
26 | self.bias = nn.Parameter(torch.FloatTensor(self.output_size))
27 | else:
28 | self.register_parameter("bias", None)
29 |
30 | def forward(self, text, adj):
31 | weights = self.weight.view(self.num_rel * self.input_size, self.output_size) # r*input_size, output_size
32 | supports = []
33 | for i in range(self.num_rel):
34 | hidden = torch.bmm(normalize(adj[:, i]), text)
35 | supports.append(hidden)
36 | tmp = torch.cat(supports, dim=-1)
37 | output = torch.matmul(tmp.float(), weights)  # (batch_size, seq_len, output_size)
38 | if self.bias is not None:
39 | return output + self.bias
40 | else:
41 | return output
42 |
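# Shapes: text is (batch, seq_len, input_size) and adj is (batch, num_rel, seq_len,
# seq_len). Each relation slice is row-normalized and propagated separately; the
# results are concatenated and projected in one matmul, which is equivalent to
# sum_r A_hat_r X W_r. The output has shape (batch, seq_len, output_size). Note that
# weight and bias are created uninitialized and are expected to be initialized by the
# caller (main.py does this in Instructor._reset_params).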
--------------------------------------------------------------------------------
/layer/supervisedcontrastiveloss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class SupervisedContrastiveLoss(nn.Module):
8 | def __init__(self, temperature=0.07):
9 | """
10 | Implementation of the loss described in the paper "Supervised Contrastive Learning":
11 | https://arxiv.org/abs/2004.11362
12 |
13 | :param temperature: float, temperature used to scale the dot products
14 | """
15 | super(SupervisedContrastiveLoss, self).__init__()
16 | self.temperature = temperature
17 |
18 | def forward(self, projections, targets, weight=None):
19 | """
20 |
21 | :param projections: torch.Tensor, shape [batch_size, projection_dim]
22 | :param targets: torch.Tensor, shape [batch_size]
23 | :return: torch.Tensor, scalar
24 | """
25 | device = torch.device("cuda") if projections.is_cuda else torch.device("cpu")
26 |
27 | dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
28 | # Minus max for numerical stability with exponential. Same done in cross entropy. Epsilon added to avoid log(0)
29 | exp_dot_tempered = (
30 | torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
31 | )
32 |
33 | mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(device)
34 | mask_anchor_out = (1 - torch.eye(exp_dot_tempered.shape[0])).to(device)
35 | mask_combined = mask_similar_class * mask_anchor_out # remove self
36 | cardinality_per_samples = torch.sum(mask_combined, dim=1) # num of positive examples
37 | if weight is not None:
38 | mask_combined = mask_combined * weight
39 | log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
40 | supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / (
41 | cardinality_per_samples + 1)
42 | supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)
43 |
44 | return supervised_contrastive_loss
45 |
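# For each anchor i with positives P(i) (other samples in the batch sharing its label),
# the implemented loss is
#   L_i = -1 / (|P(i)| + 1) * sum_{p in P(i)} log( exp(z_i . z_p / T) / sum_{a != i} exp(z_i . z_a / T) )
# averaged over the batch; the +1 guards against anchors without positives, and the
# optional weight argument reweights positive pairs. Projections are expected to be
# L2-normalized (main.py applies F.normalize before calling this loss).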
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import logging
4 | import argparse
5 | import math
6 | import os
7 | import sys
8 | import random
9 | import numpy as np
10 |
11 | from sklearn import metrics
12 | from time import strftime, localtime
13 |
14 | from transformers import BertModel
15 |
16 | import torch
17 | import torch.nn as nn
18 | from torch.utils.data import DataLoader, random_split
19 |
20 | from data_utils import Tokenizer4Bert, ABSADataset
21 | from model import YORO
22 | from layer.supervisedcontrastiveloss import SupervisedContrastiveLoss
23 |
24 | logger = logging.getLogger()
25 | logger.setLevel(logging.INFO)
26 | logger.addHandler(logging.StreamHandler(sys.stdout))
27 |
28 |
29 | class Instructor:
30 | def __init__(self, opt):
31 | self.opt = opt
32 | tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.pretrained_bert_name)
33 | bert = BertModel.from_pretrained(opt.pretrained_bert_name)
34 | self.model = opt.model_class(bert, opt).to(opt.device)
35 | self.trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
36 | self.testset = ABSADataset(opt.dataset_file['test'], tokenizer)
37 | if self.opt.dataset == 'mams':
38 | self.valset = ABSADataset(opt.dataset_file['dev'], tokenizer)
39 | else:
40 | assert 0 <= opt.valset_ratio < 1
41 | if opt.valset_ratio > 0:
42 | valset_len = int(len(self.trainset) * opt.valset_ratio)
43 | self.trainset, self.valset = random_split(self.trainset, (len(self.trainset) - valset_len, valset_len))
44 | else:
45 | self.valset = self.testset
46 |
47 | if opt.device.type == 'cuda':
48 | logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))
49 | self._print_args()
50 |
51 | def _print_args(self):
52 | n_trainable_params, n_nontrainable_params = 0, 0
53 | for p in self.model.parameters():
54 | n_params = torch.prod(torch.tensor(p.shape))
55 | if p.requires_grad:
56 | n_trainable_params += n_params
57 | else:
58 | n_nontrainable_params += n_params
59 | logger.info(
60 | '> n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
61 | logger.info('> training arguments:')
62 | for arg in vars(self.opt):
63 | logger.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
64 |
65 | def _reset_params(self):
66 | for child in self.model.children():
67 | if type(child) not in [BertModel, nn.Embedding]: # skip bert params and embedding
68 | for p in child.parameters():
69 | if p.requires_grad:
70 | if len(p.shape) > 1:
71 | self.opt.initializer(p)
72 | else:
73 | stdv = 1. / math.sqrt(p.shape[0])
74 | torch.nn.init.uniform_(p, a=-stdv, b=stdv)
75 |
76 | def _train(self, criterion, optimizer, train_data_loader, val_data_loader, test_data_loader):
77 | max_val_acc = 0
78 | max_val_f1 = 0
79 | max_val_epoch = 0
80 | global_step = 0
81 | path = None
82 | for i_epoch in range(self.opt.num_epoch):
83 | logger.info('>' * 100)
84 | logger.info('epoch: {}'.format(i_epoch))
85 | n_correct, n_total, loss_total = 0, 0, 0
86 | n_op_correct, n_op_total = 0, 0
87 | # switch model to training mode
88 | self.model.train()
89 | for i_batch, batch in enumerate(train_data_loader):
90 | global_step += 1
91 | # clear gradient accumulators
92 | optimizer.zero_grad()
93 |
94 | inputs = [batch[col].to(self.opt.device) for col in self.opt.inputs_cols]
95 | outputs, opinion_outputs = self.model(inputs)
96 |
97 | targets = batch['polarities'].to(self.opt.device)
98 | outputs = outputs.view(-1, self.opt.polarities_dim) # bz*128,3
99 | targets = targets.view(-1) # bz*128,1
100 | mask = targets != -1 # bz*128, 1 non-aspect False aspect True
101 | mask_outputs = outputs[mask]
102 | mask_targets = targets[mask]
103 | loss1 = criterion[0](mask_outputs, mask_targets)
104 |
105 | opinion_targets = batch['opinion_indices'].to(self.opt.device)
106 | opinion_outputs = opinion_outputs.view(-1, self.opt.polarities_dim)
107 | opinion_targets = opinion_targets.view(-1) # bz*128,1
108 | opinion_mask = opinion_targets != -1
109 | mask_opinion_outputs = opinion_outputs[opinion_mask]
110 | mask_opinion_targets = opinion_targets[opinion_mask]
111 | loss2 = criterion[0](mask_opinion_outputs, mask_opinion_targets)
112 |
113 | loss3 = criterion[1](nn.functional.normalize(mask_outputs, dim=1), mask_targets)
114 |
115 | loss = loss1 + loss2 + self.opt.alpha * loss3 # 0.5
116 | loss.backward()
117 | optimizer.step()
118 |
119 | n_correct += (torch.argmax(mask_outputs, -1) == mask_targets).sum().item()
120 | n_total += len(mask_outputs)
121 | n_op_correct += (torch.argmax(mask_opinion_outputs, -1) == mask_opinion_targets).sum().item()
122 | n_op_total += len(mask_opinion_outputs)
123 |
124 | loss_total += loss.item()
125 | if global_step % self.opt.log_step == 0:
126 | train_acc = n_correct / n_total
127 | train_loss = loss_total / n_total
128 | train_op_acc = n_op_correct / n_op_total
129 | logger.info('loss: {:.4f}, acc: {:.4f}, op_acc: {:.4f}, '
130 | 'loss1: {:.4f}, loss2: {:.4f}, loss3: {:.4f}'.format(train_loss, train_acc,
131 | train_op_acc, loss1, loss2, loss3))
132 |
133 | val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)
134 | logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
135 |
136 | if val_acc > max_val_acc: # acc improve
137 | max_val_acc = val_acc
138 | max_val_f1 = val_f1
139 | max_val_epoch = i_epoch
140 |
141 | if not os.path.exists(
142 | '{}/{}/{}'.format(self.opt.save_model_dir, self.opt.model_name, self.opt.dataset)):
143 | os.makedirs('{}/{}/{}'.format(self.opt.save_model_dir, self.opt.model_name, self.opt.dataset))
144 | path = '{0}/{1}/{2}/acc_{3}_f1_{4}_{5}.model'.format(self.opt.save_model_dir, self.opt.model_name,
145 | self.opt.dataset,
146 | round(val_acc, 4), round(val_f1, 4),
147 | strftime("%y%m%d-%H%M", localtime()))
148 | torch.save(self.model.state_dict(), path)
149 | logger.info('>> saved: {}'.format(path))
150 | if val_f1 > max_val_f1:
151 | max_val_f1 = val_f1
152 | if i_epoch - max_val_epoch >= self.opt.patience:
153 | logger.info('>> early stop.')
154 | break
155 |
156 | return path
157 |
158 | def _evaluate_acc_f1(self, data_loader):
159 | n_correct, n_total = 0, 0
160 | t_targets_all, t_outputs_all = None, None
161 | # switch model to evaluation mode
162 | self.model.eval()
163 | with torch.no_grad():
164 | for i_batch, t_batch in enumerate(data_loader):
165 | t_inputs = [t_batch[col].to(self.opt.device) for col in self.opt.inputs_cols]
166 | t_targets = t_batch['polarities'].to(self.opt.device)
167 | t_outputs, t_opinion_outputs = self.model(t_inputs)
168 |
169 | t_targets = t_targets.view(-1)
170 | t_outputs = t_outputs.view(-1, self.opt.polarities_dim)
171 | t_mask = t_targets.view(-1) != -1
172 | t_mask_outputs = t_outputs[t_mask]
173 | t_mask_targets = t_targets[t_mask]
174 |
175 | n_correct += (torch.argmax(t_mask_outputs, -1) == t_mask_targets).sum().item()
176 | n_total += len(t_mask_outputs)
177 |
178 | if t_targets_all is None:
179 | t_targets_all = t_mask_targets
180 | t_outputs_all = t_mask_outputs
181 | else:
182 | t_targets_all = torch.cat((t_targets_all, t_mask_targets), dim=0)
183 | t_outputs_all = torch.cat((t_outputs_all, t_mask_outputs), dim=0)
184 |
185 | acc = n_correct / n_total
186 | f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2],
187 | average='macro')
188 | return acc, f1
189 |
190 | def run(self):
191 | acc_list, f1_list = [], []
192 | # Loss and Optimizer
193 | criterion = [nn.CrossEntropyLoss(), SupervisedContrastiveLoss()]
194 | _params = filter(lambda p: p.requires_grad, self.model.parameters())
195 | optimizer = self.opt.optimizer(_params, lr=self.opt.lr, weight_decay=self.opt.l2reg)
196 |
197 | train_data_loader = DataLoader(dataset=self.trainset, batch_size=self.opt.batch_size, shuffle=True)
198 | test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
199 | val_data_loader = DataLoader(dataset=self.valset, batch_size=self.opt.batch_size, shuffle=False)
200 |
201 | for i in range(self.opt.repeat):
202 | self._reset_params()
203 | best_model_path = self._train(criterion, optimizer, train_data_loader, val_data_loader, test_data_loader)
204 | self.model.load_state_dict(torch.load(best_model_path))
205 | test_acc, test_f1 = self._evaluate_acc_f1(test_data_loader)
206 | logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))
207 | acc_list.append(test_acc)
208 | f1_list.append(test_f1)
209 | all_acc = np.asarray(acc_list)
210 | avg_acc = np.average(all_acc)
211 | all_f1 = np.asarray(f1_list)
212 | avg_f1 = np.average(all_f1)
213 | for acc, f1 in zip(acc_list, f1_list):
214 | logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(acc, f1))
215 | logger.info('>> avg_test_acc: {:.4f}, avg_test_f1: {:.4f}'.format(avg_acc, avg_f1))
216 |
217 |
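# Training objective used in Instructor._train: loss1 is cross-entropy over the
# sentiment logits at aspect positions, loss2 is cross-entropy over the
# opinion-lexicon logits at non-aspect positions, and loss3 is the supervised
# contrastive loss over the L2-normalized aspect logits; they are combined as
# loss1 + loss2 + alpha * loss3 (alpha defaults to 0.5 via --alpha).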
218 | def main():
219 | # Hyper Parameters
220 | parser = argparse.ArgumentParser()
221 | parser.add_argument('--model_name', default='YORO', type=str)
222 | parser.add_argument('--dataset', default='rest14', type=str, help='mams, rest14, lap14')
223 | parser.add_argument('--optimizer', default='adam', type=str)
224 | parser.add_argument('--initializer', default='xavier_uniform_', type=str)
225 | parser.add_argument('--repeat', default=1, type=int)
226 | parser.add_argument('--lr', default=2e-5, type=float, help='try 5e-5, 2e-5 for BERT, 1e-3 for others')
227 | parser.add_argument('--dropout', default=0.3, type=float)
228 | parser.add_argument('--l2reg', default=1e-4, type=float)
229 | parser.add_argument('--num_epoch', default=20, type=int, help='try larger number for non-BERT models')
230 | parser.add_argument('--batch_size', default=16, type=int, help='try 16, 32, 64 for BERT models')
231 | parser.add_argument('--log_step', default=10, type=int)
232 | parser.add_argument('--bert_dim', default=768, type=int)
233 | parser.add_argument('--hidden_dim', default=768, type=int)
234 | parser.add_argument('--pretrained_bert_name', default='bert-base-uncased', type=str)
235 | parser.add_argument('--max_seq_len', default=128, type=int)
236 | parser.add_argument('--polarities_dim', default=3, type=int)
237 | parser.add_argument('--alpha', default=0.5, type=float)
238 | parser.add_argument('--patience', default=5, type=int)
239 | parser.add_argument('--device', default=None, type=str, help='e.g. cuda:0')
240 | parser.add_argument('--seed', default=1, type=int, help='set seed for reproducibility')
241 | parser.add_argument('--valset_ratio', default=0, type=float, help='set a ratio between 0 and 1 to hold out part of the training set for validation')
242 | parser.add_argument('--save_model_dir', default='/Your_Path', type=str)
243 |
244 | opt = parser.parse_args()
245 |
246 | if opt.seed is not None:
247 | random.seed(opt.seed)
248 | np.random.seed(opt.seed)
249 | torch.manual_seed(opt.seed)
250 | torch.cuda.manual_seed(opt.seed)
251 | torch.backends.cudnn.deterministic = True
252 | torch.backends.cudnn.benchmark = False
253 | os.environ['PYTHONHASHSEED'] = str(opt.seed)
254 |
255 | model_classes = {
256 | 'YORO': YORO,
257 | }
258 | input_colses = {
259 | 'YORO': ['input_ids', 'token_type_ids', 'attention_mask', 'distance_adj', 'relation_adj'],
260 | }
261 | dataset_files = {
262 | 'lap14': {
263 | 'train': './dataset/lap14_train',
264 | 'test': './dataset/lap14_test'
265 | },
266 | 'rest14': {
267 | 'train': './dataset/rest14_train',
268 | 'test': './dataset/rest14_test'
269 | },
270 | 'mams': {
271 | 'train': './dataset/mams_train',
272 | 'dev': './dataset/mams_dev',
273 | 'test': './dataset/mams_test'
274 | }
275 | }
276 | initializers = {
277 | 'xavier_uniform_': torch.nn.init.xavier_uniform_,
278 | 'xavier_normal_': torch.nn.init.xavier_normal_,
279 | 'orthogonal_': torch.nn.init.orthogonal_,
280 | }
281 | optimizers = {
282 | 'adadelta': torch.optim.Adadelta, # default lr=1.0
283 | 'adagrad': torch.optim.Adagrad, # default lr=0.01
284 | 'adam': torch.optim.Adam, # default lr=0.001
285 | 'adamax': torch.optim.Adamax, # default lr=0.002
286 | 'asgd': torch.optim.ASGD, # default lr=0.01
287 | 'rmsprop': torch.optim.RMSprop, # default lr=0.01
288 | 'sgd': torch.optim.SGD,
289 | 'adamw': torch.optim.AdamW,
290 | }
291 | opt.model_class = model_classes[opt.model_name]
292 | opt.dataset_file = dataset_files[opt.dataset]
293 | opt.inputs_cols = input_colses[opt.model_name]
294 | opt.initializer = initializers[opt.initializer]
295 | opt.optimizer = optimizers[opt.optimizer]
296 | opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
297 | if opt.device is None else torch.device(opt.device)
298 |
299 | if not os.path.exists('log/{}'.format(opt.model_name)):
300 | os.makedirs('log/{}'.format(opt.model_name))
301 | log_file = 'log/{}/{}-{}.log'.format(opt.model_name, opt.dataset, strftime("%y%m%d-%H%M", localtime()))
302 | logger.addHandler(logging.FileHandler(log_file))
303 |
304 | ins = Instructor(opt)
305 | ins.run()
306 |
307 |
308 | if __name__ == '__main__':
309 | main()
310 |
--------------------------------------------------------------------------------
/model/YORO.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from layer.rgcn import RelationalGraphConvLayer
6 |
7 |
8 | class YORO(nn.Module):
9 | def __init__(self, bert, args):
10 | super(YORO, self).__init__()
11 | self.bert = bert
12 | self.rgc1 = RelationalGraphConvLayer(5, args.bert_dim, args.bert_dim)
13 | self.rgc2 = RelationalGraphConvLayer(5, args.bert_dim, args.bert_dim)
14 | self.dropout = nn.Dropout(args.dropout)
15 | self.op_dense = nn.Linear(args.bert_dim, args.polarities_dim)
16 | self.dense = nn.Linear(args.bert_dim, args.polarities_dim)
17 |
18 | def forward(self, inputs):
19 | input_ids, token_type_ids, attention_mask, distance_adj, relation_adj = inputs
20 | output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
21 | hidden = output.last_hidden_state
22 |
23 | adj = distance_adj.unsqueeze(1).expand(-1, 5, -1, -1) * relation_adj
24 | x = F.relu(self.rgc1(hidden, adj))
25 | x = self.dropout(x)
26 | x = F.relu(self.rgc2(x, adj))
27 |
28 | hidden_output = self.dropout(x)
29 | op_logits = self.op_dense(hidden_output)
30 | logits = self.dense(hidden_output)
31 | return logits, op_logits
32 |
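# Forward-pass shapes: input_ids / token_type_ids / attention_mask are (batch, seq_len),
# distance_adj is (batch, seq_len, seq_len) and relation_adj is (batch, 5, seq_len,
# seq_len). The distance matrix is broadcast over the five relation channels to weight
# each relation-specific adjacency, two R-GCN layers refine the BERT hidden states,
# and the two linear heads return sentiment logits and opinion logits, each of shape
# (batch, seq_len, polarities_dim).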
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from model.YORO import YORO
4 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python main.py --model_name YORO --dataset mams
2 | python main.py --model_name YORO --dataset rest14
3 | python main.py --model_name YORO --dataset lap14
--------------------------------------------------------------------------------