├── data └── preprocess │ ├── 14lap.pt │ ├── 14rest.pt │ ├── 15rest.pt │ ├── 16rest.pt │ ├── 14lap_standard.pt │ ├── 14rest_standard.pt │ ├── 15rest_standard.pt │ └── 16rest_standard.pt ├── README.md ├── Model.py ├── makeData_standard.py ├── utils.py ├── dataProcess.py ├── Data.py ├── makeData_dual.py └── main.py /data/preprocess/14lap.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/14lap.pt -------------------------------------------------------------------------------- /data/preprocess/14rest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/14rest.pt -------------------------------------------------------------------------------- /data/preprocess/15rest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/15rest.pt -------------------------------------------------------------------------------- /data/preprocess/16rest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/16rest.pt -------------------------------------------------------------------------------- /data/preprocess/14lap_standard.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/14lap_standard.pt -------------------------------------------------------------------------------- /data/preprocess/14rest_standard.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/14rest_standard.pt -------------------------------------------------------------------------------- 
/data/preprocess/15rest_standard.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/15rest_standard.pt -------------------------------------------------------------------------------- /data/preprocess/16rest_standard.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenshaowei57/BMRC/HEAD/data/preprocess/16rest_standard.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BMRC 2 | 3 | Code and data of the paper "Bidirectional Machine Reading Comprehension for Aspect Sentiment Triplet Extraction, AAAI 2021" (https://arxiv.org/pdf/2103.07665.pdf) 4 | 5 | Authors: Shaowei Chen, Yu Wang, Jie Liu, Yuelin Wang 6 | 7 | #### Requirements: 8 | 9 | ``` 10 | python==3.6.9 11 | torch==1.2.0 12 | transformers==2.9.0 13 | ``` 14 | 15 | #### Original Datasets: 16 | 17 | You can download the 14-Res, 14-Lap, 15-Res, 16-Res datasets from https://github.com/xuuuluuu/SemEval-Triplet-data. 
18 | 19 | #### Data Preprocess: 20 | 21 | ``` 22 | python dataProcess.py 23 | python makeData_dual.py 24 | python makeData_standard.py 25 | ``` 26 | 27 | #### How to run: 28 | 29 | ``` 30 | python main.py --mode train # For training 31 | ``` 32 | -------------------------------------------------------------------------------- /Model.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | # @Author: Shaowei Chen, Contact: chenshaowei0507@163.com 3 | # @Date: 2021-5-4 4 | 5 | from transformers import BertTokenizer, BertModel, BertConfig 6 | import torch.nn as nn 7 | 8 | 9 | class BERTModel(nn.Module): 10 | def __init__(self, args): 11 | hidden_size = args.hidden_size 12 | 13 | super(BERTModel, self).__init__() 14 | 15 | # BERT模型 16 | if args.bert_model_type == 'bert-base-uncased': 17 | self._bert = BertModel.from_pretrained(args.bert_model_type) 18 | self._tokenizer = BertTokenizer.from_pretrained(args.bert_model_type) 19 | print('Bertbase model loaded') 20 | 21 | else: 22 | raise KeyError('Config.args.bert_model_type should be bert-based-uncased. 
') 23 | 24 | self.classifier_start = nn.Linear(hidden_size, 2) 25 | 26 | self.classifier_end = nn.Linear(hidden_size, 2) 27 | 28 | self._classifier_sentiment = nn.Linear(hidden_size, 3) 29 | 30 | def forward(self, query_tensor, query_mask, query_seg, step): 31 | 32 | hidden_states = self._bert(query_tensor, attention_mask=query_mask, token_type_ids=query_seg)[0] 33 | if step == 0: # predict entity 34 | out_scores_start = self.classifier_start(hidden_states) 35 | out_scores_end = self.classifier_end(hidden_states) 36 | return out_scores_start, out_scores_end 37 | else: # predict sentiment 38 | cls_hidden_states = hidden_states[:, 0, :] 39 | cls_hidden_scores = self._classifier_sentiment(cls_hidden_states) 40 | return cls_hidden_scores 41 | -------------------------------------------------------------------------------- /makeData_standard.py: -------------------------------------------------------------------------------- 1 | # @Author: Shaowei Chen, Contact: chenshaowei0507@163.com 2 | # @Date: 2021-5-4 3 | 4 | import torch 5 | import pickle 6 | 7 | 8 | 9 | 10 | def make_standard(home_path, dataset_name, dataset_type): 11 | # read triple 12 | f = open(home_path + dataset_name + "/" + dataset_name + "_pair/" + dataset_type + "_pair.pkl", "rb") 13 | triple_data = pickle.load(f) 14 | f.close() 15 | 16 | for triplet in triple_data: 17 | 18 | aspect_temp = [] 19 | opinion_temp = [] 20 | pair_temp = [] 21 | triplet_temp = [] 22 | asp_pol_temp = [] 23 | for temp_t in triplet: 24 | triplet_temp.append([temp_t[0][0], temp_t[0][-1], temp_t[1][0], temp_t[1][-1], temp_t[2]]) 25 | ap = [temp_t[0][0], temp_t[0][-1], temp_t[2]] 26 | if ap not in asp_pol_temp: 27 | asp_pol_temp.append(ap) 28 | a = [temp_t[0][0], temp_t[0][-1]] 29 | if a not in aspect_temp: 30 | aspect_temp.append(a) 31 | o = [temp_t[1][0], temp_t[1][-1]] 32 | if o not in opinion_temp: 33 | opinion_temp.append(o) 34 | p = [temp_t[0][0], temp_t[0][-1], temp_t[1][0], temp_t[1][-1]] 35 | if p not in pair_temp: 36 | 
pair_temp.append(p) 37 | 38 | standard_list.append({'asp_target': aspect_temp, 'opi_target': opinion_temp, 'asp_opi_target': pair_temp, 39 | 'asp_pol_target': asp_pol_temp, 'triplet': triplet_temp}) 40 | 41 | return standard_list 42 | 43 | 44 | if __name__ == '__main__': 45 | home_path = "./data/original/" 46 | dataset_name_list = ["14rest", "15rest", "16rest", "14lap"] 47 | for dataset_name in dataset_name_list: 48 | output_path = "./data/preprocess/" + dataset_name + "_standard.pt" 49 | dev_standard = make_standard(home_path, dataset_name, 'dev') 50 | test_standard = make_standard(home_path, dataset_name, 'test') 51 | torch.save({'dev': dev_standard, 'test': test_standard}, output_path) 52 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | # @Author: Shaowei Chen, Contact: chenshaowei0507@163.com 3 | # @Date: 2021-5-4 4 | 5 | import torch 6 | from torch.nn import functional as F 7 | import logging 8 | 9 | 10 | def normalize_size(tensor): 11 | if len(tensor.size()) == 3: 12 | tensor = tensor.contiguous().view(-1, tensor.size(2)) 13 | elif len(tensor.size()) == 2: 14 | tensor = tensor.contiguous().view(-1) 15 | 16 | return tensor 17 | 18 | 19 | def calculate_entity_loss(pred_start, pred_end, gold_start, gold_end): 20 | pred_start = normalize_size(pred_start) 21 | pred_end = normalize_size(pred_end) 22 | gold_start = normalize_size(gold_start) 23 | gold_end = normalize_size(gold_end) 24 | 25 | weight = torch.tensor([1, 3]).float().cuda() 26 | 27 | loss_start = F.cross_entropy(pred_start, gold_start.long(), size_average=False, weight=weight, ignore_index=-1) 28 | loss_end = F.cross_entropy(pred_end, gold_end.long(), size_average=False, weight=weight, ignore_index=-1) 29 | 30 | return 0.5 * loss_start + 0.5 * loss_end 31 | 32 | 33 | def calculate_sentiment_loss(pred_sentiment, gold_sentiment): 34 | return 
F.cross_entropy(pred_sentiment, gold_sentiment.long(), size_average=False, ignore_index=-1) 35 | 36 | 37 | def get_logger(filename, verbosity=1, name=None): 38 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 39 | formatter = logging.Formatter( 40 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" 41 | ) 42 | logger = logging.getLogger(name) 43 | logger.setLevel(level_dict[verbosity]) 44 | 45 | fh = logging.FileHandler(filename, "w") 46 | fh.setFormatter(formatter) 47 | logger.addHandler(fh) 48 | 49 | sh = logging.StreamHandler() 50 | sh.setFormatter(formatter) 51 | logger.addHandler(sh) 52 | 53 | return logger 54 | 55 | def filter_prob(f_asp_prob, f_opi_prob, f_opi_start_index, f_opi_end_index, beta): 56 | filter_start = [] 57 | filter_end = [] 58 | for idx in range(len(f_opi_prob)): 59 | if f_asp_prob*f_opi_prob[idx]>=beta: 60 | filter_start.append(f_opi_start_index[idx]) 61 | filter_end.append(f_opi_end_index[idx]) 62 | return filter_start, filter_end 63 | 64 | def filter_unpaired(start_prob, end_prob, start, end): 65 | filtered_start = [] 66 | filtered_end = [] 67 | filtered_prob = [] 68 | if len(start)>0 and len(end)>0: 69 | length = start[-1]+1 if start[-1]>=end[-1] else end[-1]+1 70 | temp_seq = [0]*length 71 | for s in start: 72 | temp_seq[s]+=1 73 | for e in end: 74 | temp_seq[e]+=2 75 | last_start = -1 76 | for idx in range(len(temp_seq)): 77 | assert temp_seq[idx]<4 78 | if temp_seq[idx] == 1: 79 | last_start = idx 80 | elif temp_seq[idx] == 2: 81 | if last_start!=-1 and idx-last_start<5: 82 | filtered_start.append(last_start) 83 | filtered_end.append(idx) 84 | prob = start_prob[start.index(last_start)] * end_prob[end.index(idx)] 85 | filtered_prob.append(prob) 86 | last_start = -1 87 | elif temp_seq[idx] == 3: 88 | filtered_start.append(idx) 89 | filtered_end.append(idx) 90 | prob = start_prob[start.index(idx)] * end_prob[end.index(idx)] 91 | filtered_prob.append(prob) 92 | last_start = -1 93 | return 
filtered_start, filtered_end, filtered_prob -------------------------------------------------------------------------------- /dataProcess.py: -------------------------------------------------------------------------------- 1 | # @Author: Shaowei Chen, Contact: chenshaowei0507@163.com 2 | # @Date: 2021-5-4 3 | 4 | import pickle 5 | import torch 6 | 7 | 8 | class dual_sample(object): 9 | def __init__(self, 10 | original_sample, 11 | text, 12 | forward_querys, 13 | forward_answers, 14 | backward_querys, 15 | backward_answers, 16 | sentiment_querys, 17 | sentiment_answers): 18 | self.original_sample = original_sample # 19 | self.text = text # 20 | self.forward_querys=forward_querys 21 | self.forward_answers=forward_answers 22 | self.backward_querys=backward_querys 23 | self.backward_answers=backward_answers 24 | self.sentiment_querys=sentiment_querys 25 | self.sentiment_answers=sentiment_answers 26 | 27 | 28 | def get_text(lines): 29 | # Line sample: 30 | # It is always reliable , never bugged and responds well .####It=O is=O always=O reliable=O ,=O never=O bugged=O and=O responds=T-POS well=O .=O####It=O is=O always=O reliable=O ,=O never=O bugged=O and=O responds=O well=S .=O 31 | text_list = [] 32 | aspect_list = [] 33 | opinion_list = [] 34 | for f in lines: 35 | temp = f.split("####") 36 | assert len(temp) == 3 37 | word_list = temp[0].split() 38 | aspect_label_list = [t.split("=")[-1] for t in temp[1].split()] 39 | opinion_label_list = [t.split("=")[-1] for t in temp[2].split()] 40 | assert len(word_list) == len(aspect_label_list) == len(opinion_label_list) 41 | text_list.append(word_list) 42 | aspect_list.append(aspect_label_list) 43 | opinion_list.append(opinion_label_list) 44 | return text_list, aspect_list, opinion_list 45 | 46 | 47 | def valid_data(triplet, aspect, opinion): 48 | for t in triplet[0][0]: 49 | assert aspect[t] != ["O"] 50 | for t in triplet[0][1]: 51 | assert opinion[t] != ["O"] 52 | 53 | 54 | def fusion_dual_triplet(triplet): 55 | 
triplet_aspect = [] 56 | triplet_opinion = [] 57 | triplet_sentiment = [] 58 | dual_opinion = [] 59 | dual_aspect = [] 60 | for t in triplet: 61 | if t[0] not in triplet_aspect: 62 | triplet_aspect.append(t[0]) 63 | triplet_opinion.append([t[1]]) 64 | triplet_sentiment.append(t[2]) 65 | else: 66 | idx = triplet_aspect.index(t[0]) 67 | triplet_opinion[idx].append(t[1]) 68 | assert triplet_sentiment[idx] == t[2] 69 | if t[1] not in dual_opinion: 70 | dual_opinion.append(t[1]) 71 | dual_aspect.append([t[0]]) 72 | else: 73 | idx = dual_opinion.index(t[1]) 74 | dual_aspect[idx].append(t[0]) 75 | 76 | return triplet_aspect, triplet_opinion, triplet_sentiment, dual_opinion, dual_aspect 77 | 78 | 79 | if __name__ == '__main__': 80 | home_path = "./data/original/" 81 | dataset_name_list = ["14lap", "14rest", "15rest", "16rest"] 82 | dataset_type_list = ["train", "test", "dev"] 83 | for dataset_name in dataset_name_list: 84 | for dataset_type in dataset_type_list: 85 | output_path = "./data/preprocess/" + dataset_name + "_" + dataset_type + "_dual.pt" 86 | # read triple 87 | f = open(home_path + dataset_name + "/" + dataset_name + "_pair/" + dataset_type + "_pair.pkl", "rb") 88 | triple_data = pickle.load(f) 89 | f.close() 90 | # read text 91 | f = open(home_path + dataset_name + "/" + dataset_type + ".txt", "r", encoding="utf-8") 92 | text_lines = f.readlines() 93 | f.close() 94 | # get text 95 | text_list, aspect_list, opinion_list = get_text(text_lines) 96 | sample_list = [] 97 | for k in range(len(text_list)): 98 | triplet = triple_data[k] 99 | text = text_list[k] 100 | valid_data(triplet, aspect_list[k], opinion_list[k]) 101 | triplet_aspect, triplet_opinion, triplet_sentiment, dual_opinion, dual_aspect = fusion_dual_triplet(triplet) 102 | forward_query_list = [] 103 | backward_query_list = [] 104 | sentiment_query_list = [] 105 | forward_answer_list = [] 106 | backward_answer_list = [] 107 | sentiment_answer_list = [] 108 | forward_query_list.append(["What", "aspects", 
"?"]) 109 | start = [0] * len(text) 110 | end = [0] * len(text) 111 | for ta in triplet_aspect: 112 | start[ta[0]] = 1 113 | end[ta[-1]] = 1 114 | forward_answer_list.append([start, end]) 115 | backward_query_list.append(["What", "opinions", "?"]) 116 | start = [0] * len(text) 117 | end = [0] * len(text) 118 | for to in dual_opinion: 119 | start[to[0]] = 1 120 | end[to[-1]] = 1 121 | backward_answer_list.append([start, end]) 122 | 123 | for idx in range(len(triplet_aspect)): 124 | ta = triplet_aspect[idx] 125 | # opinion query 126 | query = ["What", "opinion", "given", "the", "aspect"] + text[ta[0]:ta[-1] + 1] + ["?"] 127 | forward_query_list.append(query) 128 | start = [0] * len(text) 129 | end = [0] * len(text) 130 | for to in triplet_opinion[idx]: 131 | start[to[0]] = 1 132 | end[to[-1]] = 1 133 | forward_answer_list.append([start, end]) 134 | # sentiment query 135 | query = ["What", "sentiment", "given", "the", "aspect"] + text[ta[0]:ta[-1] + 1] + ["and", "the", 136 | "opinion"] 137 | for idy in range(len(triplet_opinion[idx]) - 1): 138 | to = triplet_opinion[idx][idy] 139 | query += text[to[0]:to[-1] + 1] + ["/"] 140 | to = triplet_opinion[idx][-1] 141 | query += text[to[0]:to[-1] + 1] + ["?"] 142 | sentiment_query_list.append(query) 143 | sentiment_answer_list.append(triplet_sentiment[idx]) 144 | for idx in range(len(dual_opinion)): 145 | ta = dual_opinion[idx] 146 | # opinion query 147 | query = ["What", "aspect", "does", "the", "opinion"] + text[ta[0]:ta[-1] + 1] + ["describe", "?"] 148 | backward_query_list.append(query) 149 | start = [0] * len(text) 150 | end = [0] * len(text) 151 | for to in dual_aspect[idx]: 152 | start[to[0]] = 1 153 | end[to[-1]] = 1 154 | backward_answer_list.append([start, end]) 155 | 156 | temp_sample = dual_sample(text_lines[k], text, forward_query_list, forward_answer_list, backward_query_list, backward_answer_list, sentiment_query_list, sentiment_answer_list) 157 | sample_list.append(temp_sample) 158 | torch.save(sample_list, 
output_path) 159 | -------------------------------------------------------------------------------- /Data.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | # @Author: Shaowei Chen, Contact: chenshaowei0507@163.com 3 | # @Date: 2021-5-4 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | import numpy as np 7 | 8 | 9 | class OriginalDataset(Dataset): 10 | def __init__(self, pre_data): 11 | self._forward_asp_query = pre_data['_forward_asp_query'] 12 | self._forward_opi_query = pre_data['_forward_opi_query'] 13 | self._forward_asp_answer_start = pre_data['_forward_asp_answer_start'] 14 | self._forward_asp_answer_end = pre_data['_forward_asp_answer_end'] 15 | self._forward_opi_answer_start = pre_data['_forward_opi_answer_start'] 16 | self._forward_opi_answer_end = pre_data['_forward_opi_answer_end'] 17 | self._forward_asp_query_mask = pre_data['_forward_asp_query_mask'] 18 | self._forward_opi_query_mask = pre_data['_forward_opi_query_mask'] 19 | self._forward_asp_query_seg = pre_data['_forward_asp_query_seg'] 20 | self._forward_opi_query_seg = pre_data['_forward_opi_query_seg'] 21 | 22 | self._backward_asp_query = pre_data['_backward_asp_query'] 23 | self._backward_opi_query = pre_data['_backward_opi_query'] 24 | self._backward_asp_answer_start = pre_data['_backward_asp_answer_start'] 25 | self._backward_asp_answer_end = pre_data['_backward_asp_answer_end'] 26 | self._backward_opi_answer_start = pre_data['_backward_opi_answer_start'] 27 | self._backward_opi_answer_end = pre_data['_backward_opi_answer_end'] 28 | self._backward_asp_query_mask = pre_data['_backward_asp_query_mask'] 29 | self._backward_opi_query_mask = pre_data['_backward_opi_query_mask'] 30 | self._backward_asp_query_seg = pre_data['_backward_asp_query_seg'] 31 | self._backward_opi_query_seg = pre_data['_backward_opi_query_seg'] 32 | 33 | self._sentiment_query = pre_data['_sentiment_query'] 34 | self._sentiment_answer = 
pre_data['_sentiment_answer'] 35 | self._sentiment_query_mask = pre_data['_sentiment_query_mask'] 36 | self._sentiment_query_seg = pre_data['_sentiment_query_seg'] 37 | 38 | self._aspect_num = pre_data['_aspect_num'] 39 | self._opinion_num = pre_data['_opinion_num'] 40 | 41 | 42 | class ReviewDataset(Dataset): 43 | def __init__(self, train, dev, test, set): 44 | ''' 45 | 评论数据集 46 | :param train: list, training set of 14 lap, 14 res, 15 res, 16 res 47 | :param dev: list, the same 48 | :param test: list, the same 49 | ''' 50 | self._train_set = train 51 | self._dev_set = dev 52 | self._test_set = test 53 | if set == 'train': 54 | self._dataset = self._train_set 55 | elif set == 'dev': 56 | self._dataset = self._dev_set 57 | elif set == 'test': 58 | self._dataset = self._test_set 59 | 60 | self._forward_asp_query = self._dataset._forward_asp_query 61 | self._forward_opi_query = self._dataset._forward_opi_query 62 | self._forward_asp_answer_start = self._dataset._forward_asp_answer_start 63 | self._forward_asp_answer_end = self._dataset._forward_asp_answer_end 64 | self._forward_opi_answer_start = self._dataset._forward_opi_answer_start 65 | self._forward_opi_answer_end = self._dataset._forward_opi_answer_end 66 | self._forward_asp_query_mask = self._dataset._forward_asp_query_mask 67 | self._forward_opi_query_mask = self._dataset._forward_opi_query_mask 68 | self._forward_asp_query_seg = self._dataset._forward_asp_query_seg 69 | self._forward_opi_query_seg = self._dataset._forward_opi_query_seg 70 | self._backward_asp_query = self._dataset._backward_asp_query 71 | self._backward_opi_query = self._dataset._backward_opi_query 72 | self._backward_asp_answer_start = self._dataset._backward_asp_answer_start 73 | self._backward_asp_answer_end = self._dataset._backward_asp_answer_end 74 | self._backward_opi_answer_start = self._dataset._backward_opi_answer_start 75 | self._backward_opi_answer_end = self._dataset._backward_opi_answer_end 76 | self._backward_asp_query_mask = 
self._dataset._backward_asp_query_mask 77 | self._backward_opi_query_mask = self._dataset._backward_opi_query_mask 78 | self._backward_asp_query_seg = self._dataset._backward_asp_query_seg 79 | self._backward_opi_query_seg = self._dataset._backward_opi_query_seg 80 | self._sentiment_query = self._dataset._sentiment_query 81 | self._sentiment_answer = self._dataset._sentiment_answer 82 | self._sentiment_query_mask = self._dataset._sentiment_query_mask 83 | self._sentiment_query_seg = self._dataset._sentiment_query_seg 84 | self._aspect_num = self._dataset._aspect_num 85 | self._opinion_num = self._dataset._opinion_num 86 | 87 | def get_batch_num(self, batch_size): 88 | return len(self._forward_asp_query) // batch_size 89 | 90 | def __len__(self): 91 | return len(self._forward_asp_query) 92 | 93 | def __getitem__(self, item): 94 | forward_asp_query = self._forward_asp_query[item] 95 | forward_opi_query = self._forward_opi_query[item] 96 | forward_asp_answer_start = self._forward_asp_answer_start[item] 97 | forward_asp_answer_end = self._forward_asp_answer_end[item] 98 | forward_opi_answer_start = self._forward_opi_answer_start[item] 99 | forward_opi_answer_end = self._forward_opi_answer_end[item] 100 | forward_asp_query_mask = self._forward_asp_query_mask[item] 101 | forward_opi_query_mask = self._forward_opi_query_mask[item] 102 | forward_asp_query_seg = self._forward_asp_query_seg[item] 103 | forward_opi_query_seg = self._forward_opi_query_seg[item] 104 | backward_asp_query = self._backward_asp_query[item] 105 | backward_opi_query = self._backward_opi_query[item] 106 | backward_asp_answer_start = self._backward_asp_answer_start[item] 107 | backward_asp_answer_end = self._backward_asp_answer_end[item] 108 | backward_opi_answer_start = self._backward_opi_answer_start[item] 109 | backward_opi_answer_end = self._backward_opi_answer_end[item] 110 | backward_asp_query_mask = self._backward_asp_query_mask[item] 111 | backward_opi_query_mask = 
self._backward_opi_query_mask[item] 112 | backward_asp_query_seg = self._backward_asp_query_seg[item] 113 | backward_opi_query_seg = self._backward_opi_query_seg[item] 114 | sentiment_query = self._sentiment_query[item] 115 | sentiment_answer = self._sentiment_answer[item] 116 | sentiment_query_mask = self._sentiment_query_mask[item] 117 | sentiment_query_seg = self._sentiment_query_seg[item] 118 | aspect_num = self._aspect_num[item] 119 | opinion_num = self._opinion_num[item] 120 | 121 | return {"forward_asp_query": np.array(forward_asp_query), 122 | "forward_opi_query": np.array(forward_opi_query), 123 | "forward_asp_answer_start": np.array(forward_asp_answer_start), 124 | "forward_asp_answer_end": np.array(forward_asp_answer_end), 125 | "forward_opi_answer_start": np.array(forward_opi_answer_start), 126 | "forward_opi_answer_end": np.array(forward_opi_answer_end), 127 | "forward_asp_query_mask": np.array(forward_asp_query_mask), 128 | "forward_opi_query_mask": np.array(forward_opi_query_mask), 129 | "forward_asp_query_seg": np.array(forward_asp_query_seg), 130 | "forward_opi_query_seg": np.array(forward_opi_query_seg), 131 | "backward_asp_query": np.array(backward_asp_query), 132 | "backward_opi_query": np.array(backward_opi_query), 133 | "backward_asp_answer_start": np.array(backward_asp_answer_start), 134 | "backward_asp_answer_end": np.array(backward_asp_answer_end), 135 | "backward_opi_answer_start": np.array(backward_opi_answer_start), 136 | "backward_opi_answer_end": np.array(backward_opi_answer_end), 137 | "backward_asp_query_mask": np.array(backward_asp_query_mask), 138 | "backward_opi_query_mask": np.array(backward_opi_query_mask), 139 | "backward_asp_query_seg": np.array(backward_asp_query_seg), 140 | "backward_opi_query_seg": np.array(backward_opi_query_seg), 141 | "sentiment_query": np.array(sentiment_query), 142 | "sentiment_answer": np.array(sentiment_answer), 143 | "sentiment_query_mask": np.array(sentiment_query_mask), 144 | 
"sentiment_query_seg": np.array(sentiment_query_seg), 145 | "aspect_num": np.array(aspect_num), 146 | "opinion_num": np.array(opinion_num) 147 | } 148 | 149 | 150 | def generate_fi_batches(dataset, batch_size, shuffle=True, drop_last=True, ifgpu=True): 151 | dataloader = DataLoader(dataset=dataset, batch_size=batch_size, 152 | shuffle=shuffle, drop_last=drop_last) 153 | 154 | for data_dict in dataloader: 155 | out_dict = {} 156 | for name, tensor in data_dict.items(): 157 | if ifgpu: 158 | out_dict[name] = data_dict[name].cuda() 159 | else: 160 | out_dict[name] = data_dict[name] 161 | yield out_dict 162 | -------------------------------------------------------------------------------- /makeData_dual.py: -------------------------------------------------------------------------------- 1 | # @Author: Shaowei Chen, Contact: chenshaowei0507@163.com 2 | # @Date: 2021-5-4 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from transformers import BertTokenizer 7 | import numpy as np 8 | 9 | 10 | class dual_sample(object): 11 | def __init__(self, 12 | original_sample, 13 | text, 14 | forward_querys, 15 | forward_answers, 16 | backward_querys, 17 | backward_answers, 18 | sentiment_querys, 19 | sentiment_answers): 20 | self.original_sample = original_sample 21 | self.text = text # 22 | self.forward_querys = forward_querys 23 | self.forward_answers = forward_answers 24 | self.backward_querys = backward_querys 25 | self.backward_answers = backward_answers 26 | self.sentiment_querys = sentiment_querys 27 | self.sentiment_answers = sentiment_answers 28 | 29 | 30 | class sample_tokenized(object): 31 | def __init__(self, 32 | original_sample, 33 | forward_querys, 34 | forward_answers, 35 | backward_querys, 36 | backward_answers, 37 | sentiment_querys, 38 | sentiment_answers, 39 | forward_seg, 40 | backward_seg, 41 | sentiment_seg): 42 | self.original_sample = original_sample 43 | self.forward_querys = forward_querys 44 | self.forward_answers = forward_answers 45 | 
self.backward_querys = backward_querys 46 | self.backward_answers = backward_answers 47 | self.sentiment_querys = sentiment_querys 48 | self.sentiment_answers = sentiment_answers 49 | self.forward_seg = forward_seg 50 | self.backward_seg = backward_seg 51 | self.sentiment_seg = sentiment_seg 52 | 53 | 54 | class OriginalDataset(Dataset): 55 | def __init__(self, pre_data): 56 | self._forward_asp_query = pre_data['_forward_asp_query'] 57 | self._forward_opi_query = pre_data['_forward_opi_query'] # [max_aspect_num, max_opinion_query_length] 58 | self._forward_asp_answer_start = pre_data['_forward_asp_answer_start'] 59 | self._forward_asp_answer_end = pre_data['_forward_asp_answer_end'] 60 | self._forward_opi_answer_start = pre_data['_forward_opi_answer_start'] 61 | self._forward_opi_answer_end = pre_data['_forward_opi_answer_end'] 62 | self._forward_asp_query_mask = pre_data['_forward_asp_query_mask'] # [max_aspect_num, max_opinion_query_length] 63 | self._forward_opi_query_mask = pre_data['_forward_opi_query_mask'] # [max_aspect_num, max_opinion_query_length] 64 | self._forward_asp_query_seg = pre_data['_forward_asp_query_seg'] # [max_aspect_num, max_opinion_query_length] 65 | self._forward_opi_query_seg = pre_data['_forward_opi_query_seg'] # [max_aspect_num, max_opinion_query_length] 66 | 67 | self._backward_asp_query = pre_data['_backward_asp_query'] 68 | self._backward_opi_query = pre_data['_backward_opi_query'] # [max_aspect_num, max_opinion_query_length] 69 | self._backward_asp_answer_start = pre_data['_backward_asp_answer_start'] 70 | self._backward_asp_answer_end = pre_data['_backward_asp_answer_end'] 71 | self._backward_opi_answer_start = pre_data['_backward_opi_answer_start'] 72 | self._backward_opi_answer_end = pre_data['_backward_opi_answer_end'] 73 | self._backward_asp_query_mask = pre_data[ 74 | '_backward_asp_query_mask'] # [max_aspect_num, max_opinion_query_length] 75 | self._backward_opi_query_mask = pre_data[ 76 | '_backward_opi_query_mask'] # 
def pre_processing(sample_list, max_len):
    """Pad every tokenized sample to the corpus-wide maxima and pack it for training.

    Padding conventions used throughout:
      * query token ids   -> padded with 0
      * attention masks   -> padded with 0
      * answer labels     -> padded with -1 (presumably ignored by the loss;
                             TODO confirm against utils.calculate_entity_loss)
      * segment ids       -> padded with 1 inside a real query (segment B, the
                             review text, continues), but whole padded query
                             rows use all-0 rows.

    Args:
        sample_list: samples produced by tokenized_data(); each carries
            forward/backward/sentiment query token lists, answer label lists
            and segment-id lists (index 0 is the span-extraction query, the
            following indices are one query per extracted span).
        max_len: dict of corpus-wide maxima with keys 'mfor_asp_len',
            'mfor_opi_len', 'mback_asp_len', 'mback_opi_len', 'max_sent_len',
            'max_aspect_num', 'max_opinion_num' (see tokenized_data()).

    Returns:
        An OriginalDataset wrapping the fully padded id/mask/seg/answer lists.
    """
    _tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Forward direction (aspect -> opinion) accumulators.
    _forward_asp_query = []
    _forward_opi_query = []
    _forward_asp_answer_start = []
    _forward_asp_answer_end = []
    _forward_opi_answer_start = []
    _forward_opi_answer_end = []
    _forward_asp_query_mask = []
    _forward_opi_query_mask = []
    _forward_asp_query_seg = []
    _forward_opi_query_seg = []

    # Backward direction (opinion -> aspect) accumulators.
    _backward_asp_query = []
    _backward_opi_query = []
    _backward_asp_answer_start = []
    _backward_asp_answer_end = []
    _backward_opi_answer_start = []
    _backward_opi_answer_end = []
    _backward_asp_query_mask = []
    _backward_opi_query_mask = []
    _backward_asp_query_seg = []
    _backward_opi_query_seg = []

    # Sentiment-classification accumulators.
    _sentiment_query = []
    _sentiment_answer = []
    _sentiment_query_mask = []
    _sentiment_query_seg = []

    _aspect_num = []
    _opinion_num = []

    for instance in sample_list:
        f_query_list = instance.forward_querys
        f_answer_list = instance.forward_answers
        f_query_seg_list = instance.forward_seg
        b_query_list = instance.backward_querys
        b_answer_list = instance.backward_answers
        b_query_seg_list = instance.backward_seg
        s_query_list = instance.sentiment_querys
        s_answer_list = instance.sentiment_answers
        s_query_seg_list = instance.sentiment_seg

        # _aspect_num: 1/2/3/...
        # (index 0 of each query list is the extraction query, so the count
        # of gold spans is len - 1)
        _aspect_num.append(int(len(f_query_list) - 1))
        _opinion_num.append(int(len(b_query_list) - 1))

        # Forward
        # Aspect
        # query
        assert len(f_query_list[0]) == len(f_answer_list[0][0]) == len(f_answer_list[0][1])
        f_asp_pad_num = max_len['mfor_asp_len'] - len(f_query_list[0])

        # Lower-case everything except the BERT special tokens before id lookup.
        _forward_asp_query.append(_tokenizer.convert_tokens_to_ids(
            [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in f_query_list[0]]))
        _forward_asp_query[-1].extend([0] * f_asp_pad_num)

        # query_mask
        _forward_asp_query_mask.append([1 for i in range(len(f_query_list[0]))])
        _forward_asp_query_mask[-1].extend([0] * f_asp_pad_num)

        # answer
        # NOTE: extend() mutates the answer lists owned by `instance` in place.
        _forward_asp_answer_start.append(f_answer_list[0][0])
        _forward_asp_answer_start[-1].extend([-1] * f_asp_pad_num)
        _forward_asp_answer_end.append(f_answer_list[0][1])
        _forward_asp_answer_end[-1].extend([-1] * f_asp_pad_num)

        # seg
        _forward_asp_query_seg.append(f_query_seg_list[0])
        _forward_asp_query_seg[-1].extend([1] * f_asp_pad_num)

        # Opinion
        # One opinion query per gold aspect; pad each to mfor_opi_len, then pad
        # the per-sample list of rows up to max_aspect_num.
        single_opinion_query = []
        single_opinion_query_mask = []
        single_opinion_query_seg = []
        single_opinion_answer_start = []
        single_opinion_answer_end = []
        for i in range(1, len(f_query_list)):
            assert len(f_query_list[i]) == len(f_answer_list[i][0]) == len(f_answer_list[i][1])
            pad_num = max_len['mfor_opi_len'] - len(f_query_list[i])
            # query
            single_opinion_query.append(_tokenizer.convert_tokens_to_ids(
                [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in f_query_list[i]]))
            single_opinion_query[-1].extend([0] * pad_num)

            # query_mask
            single_opinion_query_mask.append([1 for i in range(len(f_query_list[i]))])
            single_opinion_query_mask[-1].extend([0] * pad_num)

            # query_seg
            single_opinion_query_seg.append(f_query_seg_list[i])
            single_opinion_query_seg[-1].extend([1] * pad_num)

            # answer
            single_opinion_answer_start.append(f_answer_list[i][0])
            single_opinion_answer_start[-1].extend([-1] * pad_num)
            single_opinion_answer_end.append(f_answer_list[i][1])
            single_opinion_answer_end[-1].extend([-1] * pad_num)

        # PAD: max_aspect_num
        # Whole padded rows: ids/masks/segs get all-0 rows, answers all -1 rows.
        _forward_opi_query.append(single_opinion_query)
        _forward_opi_query[-1].extend([[0 for i in range(max_len['mfor_opi_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        _forward_opi_query_mask.append(single_opinion_query_mask)
        _forward_opi_query_mask[-1].extend([[0 for i in range(max_len['mfor_opi_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        _forward_opi_query_seg.append(single_opinion_query_seg)
        _forward_opi_query_seg[-1].extend([[0 for i in range(max_len['mfor_opi_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        _forward_opi_answer_start.append(single_opinion_answer_start)
        _forward_opi_answer_start[-1].extend([[-1 for i in range(max_len['mfor_opi_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))
        _forward_opi_answer_end.append(single_opinion_answer_end)
        _forward_opi_answer_end[-1].extend([[-1 for i in range(max_len['mfor_opi_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        # Backward
        # opinion
        # query
        # Mirror of the forward/aspect branch with the roles swapped.
        assert len(b_query_list[0]) == len(b_answer_list[0][0]) == len(b_answer_list[0][1])
        b_opi_pad_num = max_len['mback_opi_len'] - len(b_query_list[0])

        _backward_opi_query.append(_tokenizer.convert_tokens_to_ids(
            [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in b_query_list[0]]))
        _backward_opi_query[-1].extend([0] * b_opi_pad_num)

        # mask
        _backward_opi_query_mask.append([1 for i in range(len(b_query_list[0]))])
        _backward_opi_query_mask[-1].extend([0] * b_opi_pad_num)

        # answer
        _backward_opi_answer_start.append(b_answer_list[0][0])
        _backward_opi_answer_start[-1].extend([-1] * b_opi_pad_num)
        _backward_opi_answer_end.append(b_answer_list[0][1])
        _backward_opi_answer_end[-1].extend([-1] * b_opi_pad_num)

        # seg
        _backward_opi_query_seg.append(b_query_seg_list[0])
        _backward_opi_query_seg[-1].extend([1] * b_opi_pad_num)

        # Aspect
        # One aspect query per gold opinion; mirror of the forward/opinion loop.
        single_aspect_query = []
        single_aspect_query_mask = []
        single_aspect_query_seg = []
        single_aspect_answer_start = []
        single_aspect_answer_end = []
        for i in range(1, len(b_query_list)):
            assert len(b_query_list[i]) == len(b_answer_list[i][0]) == len(b_answer_list[i][1])
            pad_num = max_len['mback_asp_len'] - len(b_query_list[i])
            # query
            single_aspect_query.append(_tokenizer.convert_tokens_to_ids(
                [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in b_query_list[i]]))
            single_aspect_query[-1].extend([0] * pad_num)

            # query_mask
            single_aspect_query_mask.append([1 for i in range(len(b_query_list[i]))])
            single_aspect_query_mask[-1].extend([0] * pad_num)

            # query_seg
            single_aspect_query_seg.append(b_query_seg_list[i])
            single_aspect_query_seg[-1].extend([1] * pad_num)

            # answer
            single_aspect_answer_start.append(b_answer_list[i][0])
            single_aspect_answer_start[-1].extend([-1] * pad_num)
            single_aspect_answer_end.append(b_answer_list[i][1])
            single_aspect_answer_end[-1].extend([-1] * pad_num)

        # PAD: max_opinion_num
        _backward_asp_query.append(single_aspect_query)
        _backward_asp_query[-1].extend([[0 for i in range(max_len['mback_asp_len'])]] * (max_len['max_opinion_num'] - _opinion_num[-1]))

        _backward_asp_query_mask.append(single_aspect_query_mask)
        _backward_asp_query_mask[-1].extend([[0 for i in range(max_len['mback_asp_len'])]] * (max_len['max_opinion_num'] - _opinion_num[-1]))

        _backward_asp_query_seg.append(single_aspect_query_seg)
        _backward_asp_query_seg[-1].extend([[0 for i in range(max_len['mback_asp_len'])]] * (max_len['max_opinion_num'] - _opinion_num[-1]))

        _backward_asp_answer_start.append(single_aspect_answer_start)
        _backward_asp_answer_start[-1].extend([[-1 for i in range(max_len['mback_asp_len'])]] * (max_len['max_opinion_num'] - _opinion_num[-1]))
        _backward_asp_answer_end.append(single_aspect_answer_end)
        _backward_asp_answer_end[-1].extend([[-1 for i in range(max_len['mback_asp_len'])]] * (max_len['max_opinion_num'] - _opinion_num[-1]))

        # Sentiment
        # One classification query per gold aspect; the answer is a single
        # class label, not a span, so it is padded with scalar -1 below.
        single_sentiment_query = []
        single_sentiment_query_mask = []
        single_sentiment_query_seg = []
        single_sentiment_answer = []
        for j in range(len(s_query_list)):
            sent_pad_num = max_len['max_sent_len'] - len(s_query_list[j])
            single_sentiment_query.append(_tokenizer.convert_tokens_to_ids(
                [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in s_query_list[j]]))
            single_sentiment_query[-1].extend([0] * sent_pad_num)

            single_sentiment_query_mask.append([1 for i in range(len(s_query_list[j]))])
            single_sentiment_query_mask[-1].extend([0] * sent_pad_num)

            # query_seg
            single_sentiment_query_seg.append(s_query_seg_list[j])
            single_sentiment_query_seg[-1].extend([1] * sent_pad_num)

            single_sentiment_answer.append(s_answer_list[j])

        _sentiment_query.append(single_sentiment_query)
        _sentiment_query[-1].extend([[0 for i in range(max_len['max_sent_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        _sentiment_query_mask.append(single_sentiment_query_mask)
        _sentiment_query_mask[-1].extend([[0 for i in range(max_len['max_sent_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        _sentiment_query_seg.append(single_sentiment_query_seg)
        _sentiment_query_seg[-1].extend([[0 for i in range(max_len['max_sent_len'])]] * (max_len['max_aspect_num'] - _aspect_num[-1]))

        _sentiment_answer.append(single_sentiment_answer)
        _sentiment_answer[-1].extend([-1] * (max_len['max_aspect_num'] - _aspect_num[-1]))

    # Bundle everything under the key names OriginalDataset expects.
    result = {"_forward_asp_query":_forward_asp_query, "_forward_opi_query":_forward_opi_query,
              "_forward_asp_answer_start":_forward_asp_answer_start, "_forward_asp_answer_end":_forward_asp_answer_end,
              "_forward_opi_answer_start":_forward_opi_answer_start, "_forward_opi_answer_end":_forward_opi_answer_end,
              "_forward_asp_query_mask":_forward_asp_query_mask, "_forward_opi_query_mask":_forward_opi_query_mask,
              "_forward_asp_query_seg":_forward_asp_query_seg, "_forward_opi_query_seg":_forward_opi_query_seg,
              "_backward_asp_query":_backward_asp_query, "_backward_opi_query":_backward_opi_query,
              "_backward_asp_answer_start":_backward_asp_answer_start, "_backward_asp_answer_end":_backward_asp_answer_end,
              "_backward_opi_answer_start":_backward_opi_answer_start, "_backward_opi_answer_end":_backward_opi_answer_end,
              "_backward_asp_query_mask":_backward_asp_query_mask, "_backward_opi_query_mask":_backward_opi_query_mask,
              "_backward_asp_query_seg":_backward_asp_query_seg, "_backward_opi_query_seg":_backward_opi_query_seg,
              "_sentiment_query":_sentiment_query, "_sentiment_answer":_sentiment_answer, "_sentiment_query_mask":_sentiment_query_mask,
              "_sentiment_query_seg":_sentiment_query_seg, "_aspect_num":_aspect_num, "_opinion_num":_opinion_num}
    return OriginalDataset(result)
def _wrap_query(query, text, answer=None):
    """Build '[CLS] query [SEP] text' plus aligned segment ids.

    Returns (tokens, segment_ids, padded_answer).  When *answer* (a
    [start_labels, end_labels] pair over *text*) is given, both label lists
    are left-padded with -1 over the '[CLS] query [SEP]' prefix so they stay
    aligned with the concatenated token sequence; new lists are returned and
    the caller's lists are never mutated.
    """
    prefix_len = len(query) + 2  # [CLS] + query tokens + [SEP]
    query_to = ['[CLS]'] + query + ['[SEP]'] + text
    query_seg = [0] * prefix_len + [1] * len(text)
    if answer is None:
        assert len(query_to) == len(query_seg)
        return query_to, query_seg, None
    # Copy instead of mutating answer[0]/answer[1] in place: the originals
    # belong to the input sample and must survive repeated calls.
    padded_answer = [[-1] * prefix_len + answer[0],
                     [-1] * prefix_len + answer[1]]
    assert len(padded_answer[0]) == len(padded_answer[1]) == len(query_to) == len(query_seg)
    return query_to, query_seg, padded_answer


def tokenized_data(data):
    """Attach BERT special tokens/segment ids to every query of every sample.

    Each forward/backward/sentiment query becomes '[CLS] query [SEP] text';
    span answers get -1 labels over the query prefix.  Also tracks the
    corpus-wide maximum query lengths and span counts that pre_processing()
    needs as padding targets.

    Bug fix vs. the previous in-place version: the input samples' answer
    lists are no longer mutated, so calling this twice on the same data no
    longer double-prepends the -1 prefix.

    Args:
        data: iterable of dual samples exposing .text, .original_sample,
            .forward_querys/.forward_answers, .backward_querys/.backward_answers,
            .sentiment_querys/.sentiment_answers (query index 0 is the
            extraction query; the rest are one query per gold span).

    Returns:
        (tokenized_sample_list, max_len) where max_len has keys
        'mfor_asp_len', 'mfor_opi_len', 'mback_asp_len', 'mback_opi_len',
        'max_sent_len', 'max_aspect_num', 'max_opinion_num'.
    """
    max_forward_asp_query_length = 0
    max_forward_opi_query_length = 0
    max_backward_asp_query_length = 0
    max_backward_opi_query_length = 0
    max_sentiment_query_length = 0
    max_aspect_num = 0
    max_opinion_num = 0
    tokenized_sample_list = []
    for sample in data:
        forward_querys = []
        forward_answers = []
        forward_querys_seg = []
        backward_querys = []
        backward_answers = []
        backward_querys_seg = []
        sentiment_querys = []
        sentiment_answers = []
        sentiment_querys_seg = []

        # Gold span counts exclude the extraction query at index 0.
        max_aspect_num = max(max_aspect_num, len(sample.forward_querys) - 1)
        max_opinion_num = max(max_opinion_num, len(sample.backward_querys) - 1)

        # Forward: index 0 extracts aspects, the rest extract opinions.
        for idx in range(len(sample.forward_querys)):
            query_to, query_seg, answer = _wrap_query(
                sample.forward_querys[idx], sample.text, sample.forward_answers[idx])
            if idx == 0:
                max_forward_asp_query_length = max(max_forward_asp_query_length, len(query_to))
            else:
                max_forward_opi_query_length = max(max_forward_opi_query_length, len(query_to))
            forward_querys.append(query_to)
            forward_answers.append(answer)
            forward_querys_seg.append(query_seg)

        # Backward: index 0 extracts opinions, the rest extract aspects.
        for idx in range(len(sample.backward_querys)):
            query_to, query_seg, answer = _wrap_query(
                sample.backward_querys[idx], sample.text, sample.backward_answers[idx])
            if idx == 0:
                max_backward_opi_query_length = max(max_backward_opi_query_length, len(query_to))
            else:
                max_backward_asp_query_length = max(max_backward_asp_query_length, len(query_to))
            backward_querys.append(query_to)
            backward_answers.append(answer)
            backward_querys_seg.append(query_seg)

        # Sentiment: answers are class labels, not spans — passed through as-is.
        for idx in range(len(sample.sentiment_querys)):
            query_to, query_seg, _ = _wrap_query(sample.sentiment_querys[idx], sample.text)
            max_sentiment_query_length = max(max_sentiment_query_length, len(query_to))
            sentiment_querys.append(query_to)
            sentiment_answers.append(sample.sentiment_answers[idx])
            sentiment_querys_seg.append(query_seg)

        temp_sample = sample_tokenized(sample.original_sample, forward_querys, forward_answers, backward_querys,
                                       backward_answers, sentiment_querys, sentiment_answers, forward_querys_seg,
                                       backward_querys_seg, sentiment_querys_seg)
        tokenized_sample_list.append(temp_sample)
    return tokenized_sample_list, {'mfor_asp_len': max_forward_asp_query_length,
                                   'mfor_opi_len': max_forward_opi_query_length,
                                   'mback_asp_len': max_backward_asp_query_length,
                                   'mback_opi_len': max_backward_opi_query_length,
                                   'max_sent_len': max_sentiment_query_length,
                                   'max_aspect_num': max_aspect_num,
                                   'max_opinion_num': max_opinion_num}
if __name__ == '__main__':
    # For every benchmark: load the dual-query splits, tokenize them, pad
    # them to corpus-wide maxima, and dump one bundled .pt file per dataset.
    for dataset_name in ['14rest', '14lap', '15rest', '16rest']:
        output_path = './data/preprocess/' + dataset_name + '.pt'
        splits = ('train', 'dev', 'test')

        # Load all three raw splits first, then tokenize each one.
        raw = {split: torch.load("./data/preprocess/" + dataset_name + "_" + split + "_dual.pt")
               for split in splits}
        tokenized = {split: tokenized_data(raw[split]) for split in splits}

        print('preprocessing_data')
        # Each tokenized entry is (sample_list, max_len); pad with its own maxima.
        processed = {split: pre_processing(*tokenized[split]) for split in splits}
        print('save_data')
        torch.save(processed, output_path)
class OriginalDataset(Dataset):
    """Thin container for the preprocessed BMRC tensors loaded from a .pt file.

    Every attribute is copied verbatim from the preprocessed dict produced by
    the data-building script; listing the key names once keeps the
    constructor free of 26 near-identical assignment statements.
    """

    # Attribute names double as the keys of the incoming dict.
    _FIELDS = (
        '_forward_asp_query', '_forward_opi_query',
        '_forward_asp_answer_start', '_forward_asp_answer_end',
        '_forward_opi_answer_start', '_forward_opi_answer_end',
        '_forward_asp_query_mask', '_forward_opi_query_mask',
        '_forward_asp_query_seg', '_forward_opi_query_seg',
        '_backward_asp_query', '_backward_opi_query',
        '_backward_asp_answer_start', '_backward_asp_answer_end',
        '_backward_opi_answer_start', '_backward_opi_answer_end',
        '_backward_asp_query_mask', '_backward_opi_query_mask',
        '_backward_asp_query_seg', '_backward_opi_query_seg',
        '_sentiment_query', '_sentiment_answer',
        '_sentiment_query_mask', '_sentiment_query_seg',
        '_aspect_num', '_opinion_num',
    )

    def __init__(self, pre_data):
        """Copy each preprocessed field from *pre_data* onto the instance."""
        for field in self._FIELDS:
            setattr(self, field, pre_data[field])
def test(model, t, batch_generator, standard, beta, logger):
    """Run bidirectional-MRC inference over the evaluation batches and log P/R/F1.

    Pipeline per review (batch size 1): forward pass extracts aspects then
    chains an opinion query per aspect; backward pass extracts opinions then
    chains an aspect query per opinion; pairs confirmed by both directions are
    kept, unconfirmed pairs are kept only if their joint probability >= beta;
    finally a sentiment query is asked per surviving aspect.

    Args:
        model: trained BERTModel; called as model(ids, mask, seg, step) where
            step 0 yields span start/end scores and step 1 sentiment scores.
        t: BERT tokenizer used to build the chained queries.
        batch_generator: yields one review per batch.
        standard: gold spans/triplets indexed in the same batch order.
        beta: probability threshold for unconfirmed pairs.
        logger: sink for the metric report.

    Returns:
        Triplet-level F1 (float).
    """
    model.eval()

    # Gold-span counters across the whole evaluation set.
    triplet_target_num = 0
    asp_target_num = 0
    opi_target_num = 0
    asp_opi_target_num = 0
    asp_pol_target_num = 0

    # Predicted-span counters.
    triplet_predict_num = 0
    asp_predict_num = 0
    opi_predict_num = 0
    asp_opi_predict_num = 0
    asp_pol_predict_num = 0

    # Exact-match counters (prediction == gold).
    triplet_match_num = 0
    asp_match_num = 0
    opi_match_num = 0
    asp_opi_match_num = 0
    asp_pol_match_num = 0

    for batch_index, batch_dict in enumerate(batch_generator):

        triplets_target = standard[batch_index]['triplet']
        asp_target = standard[batch_index]['asp_target']
        opi_target = standard[batch_index]['opi_target']
        asp_opi_target = standard[batch_index]['asp_opi_target']
        asp_pol_target = standard[batch_index]['asp_pol_target']

        # predicted triplets for this review
        triplets_predict = []
        asp_predict = []
        opi_predict = []
        asp_opi_predict = []
        asp_pol_predict = []

        forward_pair_list = []
        forward_pair_prob = []
        forward_pair_ind_list = []

        backward_pair_list = []
        backward_pair_prob = []
        backward_pair_ind_list = []

        final_asp_list = []
        final_opi_list = []
        final_asp_ind_list = []
        final_opi_ind_list = []
        # forward q_1
        # Positions with answer label > -1 mark the review-text part of the
        # query; "passenge" (sic) holds the review's token ids.
        passenge_index = batch_dict['forward_asp_answer_start'][0].gt(-1).float().nonzero()
        passenge = batch_dict['forward_asp_query'][0][passenge_index].squeeze(1)

        f_asp_start_scores, f_asp_end_scores = model(batch_dict['forward_asp_query'],
                                                     batch_dict['forward_asp_query_mask'],
                                                     batch_dict['forward_asp_query_seg'], 0)
        f_asp_start_scores = F.softmax(f_asp_start_scores[0], dim=1)
        f_asp_end_scores = F.softmax(f_asp_end_scores[0], dim=1)
        f_asp_start_prob, f_asp_start_ind = torch.max(f_asp_start_scores, dim=1)
        f_asp_end_prob, f_asp_end_ind = torch.max(f_asp_end_scores, dim=1)

        # Collect token positions predicted as span starts/ends (class 1),
        # restricted to the review-text region of the query.
        f_asp_start_prob_temp = []
        f_asp_end_prob_temp = []
        f_asp_start_index_temp = []
        f_asp_end_index_temp = []
        for i in range(f_asp_start_ind.size(0)):
            if batch_dict['forward_asp_answer_start'][0, i] != -1:
                if f_asp_start_ind[i].item() == 1:
                    f_asp_start_index_temp.append(i)
                    f_asp_start_prob_temp.append(f_asp_start_prob[i].item())
                if f_asp_end_ind[i].item() == 1:
                    f_asp_end_index_temp.append(i)
                    f_asp_end_prob_temp.append(f_asp_end_prob[i].item())

        f_asp_start_index, f_asp_end_index, f_asp_prob = utils.filter_unpaired(
            f_asp_start_prob_temp, f_asp_end_prob_temp, f_asp_start_index_temp, f_asp_end_index_temp)

        # forward q_2: one opinion query per extracted aspect.
        for i in range(len(f_asp_start_index)):
            opinion_query = t.convert_tokens_to_ids(
                [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in
                 '[CLS] What opinion given the aspect'.split(' ')])
            for j in range(f_asp_start_index[i], f_asp_end_index[i] + 1):
                opinion_query.append(batch_dict['forward_asp_query'][0][j].item())
            opinion_query.append(t.convert_tokens_to_ids('?'))
            opinion_query.append(t.convert_tokens_to_ids('[SEP]'))
            opinion_query_seg = [0] * len(opinion_query)
            f_opi_length = len(opinion_query)

            opinion_query = torch.tensor(opinion_query).long().cuda()
            opinion_query = torch.cat([opinion_query, passenge], -1).unsqueeze(0)
            opinion_query_seg += [1] * passenge.size(0)
            opinion_query_mask = torch.ones(opinion_query.size(1)).float().cuda().unsqueeze(0)
            opinion_query_seg = torch.tensor(opinion_query_seg).long().cuda().unsqueeze(0)

            f_opi_start_scores, f_opi_end_scores = model(opinion_query, opinion_query_mask, opinion_query_seg, 0)

            f_opi_start_scores = F.softmax(f_opi_start_scores[0], dim=1)
            f_opi_end_scores = F.softmax(f_opi_end_scores[0], dim=1)
            f_opi_start_prob, f_opi_start_ind = torch.max(f_opi_start_scores, dim=1)
            f_opi_end_prob, f_opi_end_ind = torch.max(f_opi_end_scores, dim=1)

            f_opi_start_prob_temp = []
            f_opi_end_prob_temp = []
            f_opi_start_index_temp = []
            f_opi_end_index_temp = []
            # seg == 1 marks the review-text region of the chained query.
            for k in range(f_opi_start_ind.size(0)):
                if opinion_query_seg[0, k] == 1:
                    if f_opi_start_ind[k].item() == 1:
                        f_opi_start_index_temp.append(k)
                        f_opi_start_prob_temp.append(f_opi_start_prob[k].item())
                    if f_opi_end_ind[k].item() == 1:
                        f_opi_end_index_temp.append(k)
                        f_opi_end_prob_temp.append(f_opi_end_prob[k].item())

            f_opi_start_index, f_opi_end_index, f_opi_prob = utils.filter_unpaired(
                f_opi_start_prob_temp, f_opi_end_prob_temp, f_opi_start_index_temp, f_opi_end_index_temp)

            for idx in range(len(f_opi_start_index)):
                asp = [batch_dict['forward_asp_query'][0][j].item() for j in range(f_asp_start_index[i], f_asp_end_index[i] + 1)]
                opi = [opinion_query[0][j].item() for j in range(f_opi_start_index[idx], f_opi_end_index[idx] + 1)]
                # -5: strips the fixed aspect-query prefix so the index is
                # review-relative — assumes that prefix is 5 tokens; TODO
                # confirm against the query built in the data script.
                asp_ind = [f_asp_start_index[i]-5, f_asp_end_index[i]-5]
                # Subtracting f_opi_length makes the opinion index review-relative.
                opi_ind = [f_opi_start_index[idx]-f_opi_length, f_opi_end_index[idx]-f_opi_length]
                temp_prob = f_asp_prob[i] * f_opi_prob[idx]
                if asp_ind + opi_ind not in forward_pair_ind_list:
                    forward_pair_list.append([asp] + [opi])
                    forward_pair_prob.append(temp_prob)
                    forward_pair_ind_list.append(asp_ind + opi_ind)
                else:
                    print('erro')
                    exit(1)

        # backward q_1
        b_opi_start_scores, b_opi_end_scores = model(batch_dict['backward_opi_query'],
                                                     batch_dict['backward_opi_query_mask'],
                                                     batch_dict['backward_opi_query_seg'], 0)
        b_opi_start_scores = F.softmax(b_opi_start_scores[0], dim=1)
        b_opi_end_scores = F.softmax(b_opi_end_scores[0], dim=1)
        b_opi_start_prob, b_opi_start_ind = torch.max(b_opi_start_scores, dim=1)
        b_opi_end_prob, b_opi_end_ind = torch.max(b_opi_end_scores, dim=1)

        b_opi_start_prob_temp = []
        b_opi_end_prob_temp = []
        b_opi_start_index_temp = []
        b_opi_end_index_temp = []
        for i in range(b_opi_start_ind.size(0)):
            if batch_dict['backward_opi_answer_start'][0, i] != -1:
                if b_opi_start_ind[i].item() == 1:
                    b_opi_start_index_temp.append(i)
                    b_opi_start_prob_temp.append(b_opi_start_prob[i].item())
                if b_opi_end_ind[i].item() == 1:
                    b_opi_end_index_temp.append(i)
                    b_opi_end_prob_temp.append(b_opi_end_prob[i].item())

        b_opi_start_index, b_opi_end_index, b_opi_prob = utils.filter_unpaired(
            b_opi_start_prob_temp, b_opi_end_prob_temp, b_opi_start_index_temp, b_opi_end_index_temp)

        # backward q_2: one aspect query per extracted opinion.
        for i in range(len(b_opi_start_index)):
            aspect_query = t.convert_tokens_to_ids(
                [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in
                 '[CLS] What aspect does the opinion'.split(' ')])
            for j in range(b_opi_start_index[i], b_opi_end_index[i] + 1):
                aspect_query.append(batch_dict['backward_opi_query'][0][j].item())
            aspect_query.append(t.convert_tokens_to_ids('describe'))
            aspect_query.append(t.convert_tokens_to_ids('?'))
            aspect_query.append(t.convert_tokens_to_ids('[SEP]'))
            aspect_query_seg = [0] * len(aspect_query)
            b_asp_length = len(aspect_query)
            aspect_query = torch.tensor(aspect_query).long().cuda()
            aspect_query = torch.cat([aspect_query, passenge], -1).unsqueeze(0)
            aspect_query_seg += [1] * passenge.size(0)
            aspect_query_mask = torch.ones(aspect_query.size(1)).float().cuda().unsqueeze(0)
            aspect_query_seg = torch.tensor(aspect_query_seg).long().cuda().unsqueeze(0)

            b_asp_start_scores, b_asp_end_scores = model(aspect_query, aspect_query_mask, aspect_query_seg, 0)

            b_asp_start_scores = F.softmax(b_asp_start_scores[0], dim=1)
            b_asp_end_scores = F.softmax(b_asp_end_scores[0], dim=1)
            b_asp_start_prob, b_asp_start_ind = torch.max(b_asp_start_scores, dim=1)
            b_asp_end_prob, b_asp_end_ind = torch.max(b_asp_end_scores, dim=1)

            b_asp_start_prob_temp = []
            b_asp_end_prob_temp = []
            b_asp_start_index_temp = []
            b_asp_end_index_temp = []
            for k in range(b_asp_start_ind.size(0)):
                if aspect_query_seg[0, k] == 1:
                    if b_asp_start_ind[k].item() == 1:
                        b_asp_start_index_temp.append(k)
                        b_asp_start_prob_temp.append(b_asp_start_prob[k].item())
                    if b_asp_end_ind[k].item() == 1:
                        b_asp_end_index_temp.append(k)
                        b_asp_end_prob_temp.append(b_asp_end_prob[k].item())

            b_asp_start_index, b_asp_end_index, b_asp_prob = utils.filter_unpaired(
                b_asp_start_prob_temp, b_asp_end_prob_temp, b_asp_start_index_temp, b_asp_end_index_temp)

            for idx in range(len(b_asp_start_index)):
                opi = [batch_dict['backward_opi_query'][0][j].item() for j in
                       range(b_opi_start_index[i], b_opi_end_index[i] + 1)]
                asp = [aspect_query[0][j].item() for j in range(b_asp_start_index[idx], b_asp_end_index[idx] + 1)]
                asp_ind = [b_asp_start_index[idx]-b_asp_length, b_asp_end_index[idx]-b_asp_length]
                # -5: same fixed-prefix assumption as the forward branch —
                # TODO confirm the backward opinion query prefix length.
                opi_ind = [b_opi_start_index[i]-5, b_opi_end_index[i]-5]
                temp_prob = b_asp_prob[idx] * b_opi_prob[i]
                if asp_ind + opi_ind not in backward_pair_ind_list:
                    backward_pair_list.append([asp] + [opi])
                    backward_pair_prob.append(temp_prob)
                    backward_pair_ind_list.append(asp_ind + opi_ind)
                else:
                    print('erro')
                    exit(1)
        # filter triplet
        # forward
        # Keep a forward pair if the backward pass confirms it, or if its
        # joint probability reaches beta.
        for idx in range(len(forward_pair_list)):
            if forward_pair_list[idx] in backward_pair_list:
                if forward_pair_list[idx][0] not in final_asp_list:
                    final_asp_list.append(forward_pair_list[idx][0])
                    final_opi_list.append([forward_pair_list[idx][1]])
                    final_asp_ind_list.append(forward_pair_ind_list[idx][:2])
                    final_opi_ind_list.append([forward_pair_ind_list[idx][2:]])
                else:
                    asp_index = final_asp_list.index(forward_pair_list[idx][0])
                    if forward_pair_list[idx][1] not in final_opi_list[asp_index]:
                        final_opi_list[asp_index].append(forward_pair_list[idx][1])
                        final_opi_ind_list[asp_index].append(forward_pair_ind_list[idx][2:])
            else:
                if forward_pair_prob[idx] >= beta:
                    if forward_pair_list[idx][0] not in final_asp_list:
                        final_asp_list.append(forward_pair_list[idx][0])
                        final_opi_list.append([forward_pair_list[idx][1]])
                        final_asp_ind_list.append(forward_pair_ind_list[idx][:2])
                        final_opi_ind_list.append([forward_pair_ind_list[idx][2:]])
                    else:
                        asp_index = final_asp_list.index(forward_pair_list[idx][0])
                        if forward_pair_list[idx][1] not in final_opi_list[asp_index]:
                            final_opi_list[asp_index].append(forward_pair_list[idx][1])
                            final_opi_ind_list[asp_index].append(forward_pair_ind_list[idx][2:])
        # backward
        # Backward-only pairs are kept only above the beta threshold
        # (confirmed pairs were already added in the forward loop).
        for idx in range(len(backward_pair_list)):
            if backward_pair_list[idx] not in forward_pair_list:
                if backward_pair_prob[idx] >= beta:
                    if backward_pair_list[idx][0] not in final_asp_list:
                        final_asp_list.append(backward_pair_list[idx][0])
                        final_opi_list.append([backward_pair_list[idx][1]])
                        final_asp_ind_list.append(backward_pair_ind_list[idx][:2])
                        final_opi_ind_list.append([backward_pair_ind_list[idx][2:]])
                    else:
                        asp_index = final_asp_list.index(backward_pair_list[idx][0])
                        if backward_pair_list[idx][1] not in final_opi_list[asp_index]:
                            final_opi_list[asp_index].append(backward_pair_list[idx][1])
                            final_opi_ind_list[asp_index].append(backward_pair_ind_list[idx][2:])
        # sentiment
        for idx in range(len(final_asp_list)):
            predict_opinion_num = len(final_opi_list[idx])
            sentiment_query = t.convert_tokens_to_ids(
                [word.lower() if word not in ['[CLS]', '[SEP]'] else word for word in
                 '[CLS] What sentiment given the aspect'.split(' ')])
            sentiment_query+=final_asp_list[idx]
            sentiment_query += t.convert_tokens_to_ids([word.lower() for word in 'and the opinion'.split(' ')])
            # concatenate all predicted opinions, separated by '/'
            for idy in range(predict_opinion_num):
                sentiment_query+=final_opi_list[idx][idy]
                if idy < predict_opinion_num - 1:
                    sentiment_query.append(t.convert_tokens_to_ids('/'))
            sentiment_query.append(t.convert_tokens_to_ids('?'))
            sentiment_query.append(t.convert_tokens_to_ids('[SEP]'))

            sentiment_query_seg = [0] * len(sentiment_query)
            sentiment_query = torch.tensor(sentiment_query).long().cuda()
            sentiment_query = torch.cat([sentiment_query, passenge], -1).unsqueeze(0)
            sentiment_query_seg += [1] * passenge.size(0)
            sentiment_query_mask = torch.ones(sentiment_query.size(1)).float().cuda().unsqueeze(0)
            sentiment_query_seg = torch.tensor(sentiment_query_seg).long().cuda().unsqueeze(0)

            sentiment_scores = model(sentiment_query, sentiment_query_mask, sentiment_query_seg, 1)
            sentiment_predicted = torch.argmax(sentiment_scores[0], dim=0).item()

            # each predicted opinion yields one triplet
            for idy in range(predict_opinion_num):
                asp_f = []
                opi_f = []
                asp_f.append(final_asp_ind_list[idx][0])
                asp_f.append(final_asp_ind_list[idx][1])
                opi_f.append(final_opi_ind_list[idx][idy][0])
                opi_f.append(final_opi_ind_list[idx][idy][1])
                triplet_predict = asp_f + opi_f + [sentiment_predicted]
                triplets_predict.append(triplet_predict)
                if opi_f not in opi_predict:
                    opi_predict.append(opi_f)
                if asp_f + opi_f not in asp_opi_predict:
                    asp_opi_predict.append(asp_f + opi_f)
                if asp_f + [sentiment_predicted] not in asp_pol_predict:
                    asp_pol_predict.append(asp_f + [sentiment_predicted])
                if asp_f not in asp_predict:
                    asp_predict.append(asp_f)

        triplet_target_num += len(triplets_target)
        asp_target_num += len(asp_target)
        opi_target_num += len(opi_target)
        asp_opi_target_num += len(asp_opi_target)
        asp_pol_target_num += len(asp_pol_target)

        triplet_predict_num += len(triplets_predict)
        asp_predict_num += len(asp_predict)
        opi_predict_num += len(opi_predict)
        asp_opi_predict_num += len(asp_opi_predict)
        asp_pol_predict_num += len(asp_pol_predict)

        # Exact-match counting against the gold standard.
        for trip in triplets_target:
            for trip_ in triplets_predict:
                if trip_ == trip:
                    triplet_match_num += 1
        for trip in asp_target:
            for trip_ in asp_predict:
                if trip_ == trip:
                    asp_match_num += 1
        for trip in opi_target:
            for trip_ in opi_predict:
                if trip_ == trip:
                    opi_match_num += 1
        for trip in asp_opi_target:
            for trip_ in asp_opi_predict:
                if trip_ == trip:
                    asp_opi_match_num += 1
        for trip in asp_pol_target:
            for trip_ in asp_pol_predict:
                if trip_ == trip:
                    asp_pol_match_num += 1

    # Metrics; +1e-6 guards against division by zero when nothing matched.
    precision = float(triplet_match_num) / float(triplet_predict_num+1e-6)
    recall = float(triplet_match_num) / float(triplet_target_num+1e-6)
    f1 = 2 * precision * recall / (precision + recall+1e-6)
    logger.info('Triplet - Precision: {}\tRecall: {}\tF1: {}'.format(precision, recall, f1))

    precision_aspect = float(asp_match_num) / float(asp_predict_num+1e-6)
    recall_aspect = float(asp_match_num) / float(asp_target_num+1e-6)
    f1_aspect = 2 * precision_aspect * recall_aspect / (precision_aspect + recall_aspect+1e-6)
    logger.info('Aspect - Precision: {}\tRecall: {}\tF1: {}'.format(precision_aspect, recall_aspect, f1_aspect))

    precision_opinion = float(opi_match_num) / float(opi_predict_num+1e-6)
    recall_opinion = float(opi_match_num) / float(opi_target_num+1e-6)
    f1_opinion = 2 * precision_opinion * recall_opinion / (precision_opinion + recall_opinion+1e-6)
    logger.info('Opinion - Precision: {}\tRecall: {}\tF1: {}'.format(precision_opinion, recall_opinion, f1_opinion))

    precision_aspect_sentiment = float(asp_pol_match_num) / float(asp_pol_predict_num+1e-6)
    recall_aspect_sentiment = float(asp_pol_match_num) / float(asp_pol_target_num+1e-6)
    f1_aspect_sentiment = 2 * precision_aspect_sentiment * recall_aspect_sentiment / (
            precision_aspect_sentiment + recall_aspect_sentiment+1e-6)
    logger.info('Aspect-Sentiment - Precision: {}\tRecall: {}\tF1: {}'.format(precision_aspect_sentiment,
                                                                              recall_aspect_sentiment,
                                                                              f1_aspect_sentiment))

    precision_aspect_opinion = float(asp_opi_match_num) / float(asp_opi_predict_num+1e-6)
    recall_aspect_opinion = float(asp_opi_match_num) / float(asp_opi_target_num+1e-6)
    f1_aspect_opinion = 2 * precision_aspect_opinion * recall_aspect_opinion / (
            precision_aspect_opinion + recall_aspect_opinion+1e-6)
    logger.info(
        'Aspect-Opinion - Precision: {}\tRecall: {}\tF1: {}'.format(precision_aspect_opinion, recall_aspect_opinion,
                                                                    f1_aspect_opinion))
    return f1
shuffle=False, 459 | ifgpu=args.ifgpu) 460 | # eval 461 | logger.info('evaluating......') 462 | f1 = test(model, tokenize, batch_generator_test, test_standard, args.beta, logger) 463 | 464 | 465 | elif args.mode == 'train': 466 | args.save_model_path = args.save_model_path + args.data_name + '_' + args.model_name + '.pth' 467 | train_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'train') 468 | dev_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'dev') 469 | test_dataset = Data.ReviewDataset(train_data, dev_data, test_data, 'test') 470 | batch_num_train = train_dataset.get_batch_num(args.batch_size) 471 | 472 | # optimizer 473 | logger.info('initial optimizer......') 474 | param_optimizer = list(model.named_parameters()) 475 | optimizer_grouped_parameters = [ 476 | {'params': [p for n, p in param_optimizer if "_bert" in n], 'weight_decay': 0.01}, 477 | {'params': [p for n, p in param_optimizer if "_bert" not in n], 478 | 'lr': args.learning_rate, 'weight_decay': 0.01}] 479 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.tuning_bert_rate, correct_bias=False) 480 | 481 | # load saved model, optimizer and epoch num 482 | if args.reload and os.path.exists(args.checkpoint_path): 483 | checkpoint = torch.load(args.checkpoint_path) 484 | model.load_state_dict(checkpoint['net']) 485 | optimizer.load_state_dict(checkpoint['optimizer']) 486 | start_epoch = checkpoint['epoch'] + 1 487 | logger.info('Reload model and optimizer after training epoch {}'.format(checkpoint['epoch'])) 488 | else: 489 | start_epoch = 1 490 | logger.info('New model and optimizer from epoch 0') 491 | 492 | # scheduler 493 | training_steps = args.epoch_num * batch_num_train 494 | warmup_steps = int(training_steps * args.warm_up) 495 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, 496 | num_training_steps=training_steps) 497 | 498 | # training 499 | logger.info('begin training......') 500 | best_dev_f1 = 0. 
501 | for epoch in range(start_epoch, args.epoch_num+1): 502 | model.train() 503 | model.zero_grad() 504 | 505 | batch_generator = Data.generate_fi_batches(dataset=train_dataset, batch_size=args.batch_size, 506 | ifgpu=args.ifgpu) 507 | 508 | for batch_index, batch_dict in enumerate(batch_generator): 509 | 510 | optimizer.zero_grad() 511 | 512 | # q1_a 513 | f_aspect_start_scores, f_aspect_end_scores = model(batch_dict['forward_asp_query'], 514 | batch_dict['forward_asp_query_mask'], 515 | batch_dict['forward_asp_query_seg'], 0) 516 | f_asp_loss = utils.calculate_entity_loss(f_aspect_start_scores, f_aspect_end_scores, 517 | batch_dict['forward_asp_answer_start'], 518 | batch_dict['forward_asp_answer_end']) 519 | # q1_b 520 | b_opi_start_scores, b_opi_end_scores = model(batch_dict['backward_opi_query'], 521 | batch_dict['backward_opi_query_mask'], 522 | batch_dict['backward_opi_query_seg'], 0) 523 | b_opi_loss = utils.calculate_entity_loss(b_opi_start_scores, b_opi_end_scores, 524 | batch_dict['backward_opi_answer_start'], 525 | batch_dict['backward_opi_answer_end']) 526 | # q2_a 527 | f_opi_start_scores, f_opi_end_scores = model( 528 | batch_dict['forward_opi_query'].view(-1, batch_dict['forward_opi_query'].size(-1)), 529 | batch_dict['forward_opi_query_mask'].view(-1, batch_dict['forward_opi_query_mask'].size(-1)), 530 | batch_dict['forward_opi_query_seg'].view(-1, batch_dict['forward_opi_query_seg'].size(-1)), 531 | 0) 532 | f_opi_loss = utils.calculate_entity_loss(f_opi_start_scores, f_opi_end_scores, 533 | batch_dict['forward_opi_answer_start'].view(-1, batch_dict['forward_opi_answer_start'].size(-1)), 534 | batch_dict['forward_opi_answer_end'].view(-1, batch_dict['forward_opi_answer_end'].size(-1))) 535 | # q2_b 536 | b_asp_start_scores, b_asp_end_scores = model( 537 | batch_dict['backward_asp_query'].view(-1, batch_dict['backward_asp_query'].size(-1)), 538 | batch_dict['backward_asp_query_mask'].view(-1, batch_dict['backward_asp_query_mask'].size(-1)), 539 | 
batch_dict['backward_asp_query_seg'].view(-1, batch_dict['backward_asp_query_seg'].size(-1)), 540 | 0) 541 | b_asp_loss = utils.calculate_entity_loss(b_asp_start_scores, b_asp_end_scores, 542 | batch_dict['backward_asp_answer_start'].view(-1, batch_dict['backward_asp_answer_start'].size(-1)), 543 | batch_dict['backward_asp_answer_end'].view(-1, batch_dict['backward_asp_answer_end'].size(-1))) 544 | # q_3 545 | sentiment_scores = model(batch_dict['sentiment_query'].view(-1, batch_dict['sentiment_query'].size(-1)), 546 | batch_dict['sentiment_query_mask'].view(-1, batch_dict['sentiment_query_mask'].size(-1)), 547 | batch_dict['sentiment_query_seg'].view(-1, batch_dict['sentiment_query_seg'].size(-1)), 548 | 1) 549 | sentiment_loss = utils.calculate_sentiment_loss(sentiment_scores, batch_dict['sentiment_answer'].view(-1)) 550 | 551 | # loss 552 | loss_sum = f_asp_loss + f_opi_loss + b_opi_loss + b_asp_loss + args.beta*sentiment_loss 553 | loss_sum.backward() 554 | optimizer.step() 555 | scheduler.step() 556 | 557 | # train logger 558 | if batch_index % 10 == 0: 559 | logger.info('Epoch:[{}/{}]\t Batch:[{}/{}]\t Loss Sum:{}\t ' 560 | 'forward Loss:{};{}\t backward Loss:{};{}\t Sentiment Loss:{}'. 
561 | format(epoch, args.epoch_num, batch_index, batch_num_train, 562 | round(loss_sum.item(), 4), 563 | round(f_asp_loss.item(), 4), round(f_opi_loss.item(), 4), 564 | round(b_asp_loss.item(), 4), round(b_opi_loss.item(), 4), 565 | round(sentiment_loss.item(), 4))) 566 | 567 | # validation 568 | batch_generator_dev = Data.generate_fi_batches(dataset=dev_dataset, batch_size=1, shuffle=False, 569 | ifgpu=args.ifgpu) 570 | f1 = test(model, tokenize, batch_generator_dev, dev_standard, args.inference_beta, logger) 571 | # save model and optimizer 572 | if f1 > best_dev_f1: 573 | best_dev_f1 = f1 574 | logger.info('Model saved after epoch {}'.format(epoch)) 575 | state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} 576 | torch.save(state, args.save_model_path) 577 | 578 | # test 579 | batch_generator_test = Data.generate_fi_batches(dataset=test_dataset, batch_size=1, shuffle=False, 580 | ifgpu=args.ifgpu) 581 | f1 = test(model, tokenize, batch_generator_test, test_standard, args.inference_beta, logger) 582 | 583 | else: 584 | logger.info('Error mode!') 585 | exit(1) 586 | 587 | 588 | if __name__ == '__main__': 589 | parser = argparse.ArgumentParser(description='Bidirectional MRC-based sentiment triplet extraction') 590 | parser.add_argument('--data_path', type=str, default="./data/preprocess/") 591 | parser.add_argument('--log_path', type=str, default="./log/") 592 | parser.add_argument('--data_name', type=str, default="14lap", choices=["14lap", "14rest", "15rest", "16rest"]) 593 | 594 | parser.add_argument('--mode', type=str, default="train", choices=["train", "test"]) 595 | 596 | parser.add_argument('--reload', type=bool, default=False) 597 | parser.add_argument('--checkpoint_path', type=str, default="./model/14lap/modelFinal.model") 598 | parser.add_argument('--save_model_path', type=str, default="./model/") 599 | parser.add_argument('--model_name', type=str, default="1") 600 | 601 | # model hyper-parameter 602 | 
parser.add_argument('--bert_model_type', type=str, default="bert-base-uncased") 603 | parser.add_argument('--hidden_size', type=int, default=768) 604 | parser.add_argument('--inference_beta', type=float, default=0.8) 605 | 606 | # training hyper-parameter 607 | parser.add_argument('--ifgpu', type=bool, default=True) 608 | parser.add_argument('--epoch_num', type=int, default=40) 609 | parser.add_argument('--batch_size', type=int, default=4) 610 | parser.add_argument('--learning_rate', type=float, default=1e-3) 611 | parser.add_argument('--tuning_bert_rate', type=float, default=1e-5) 612 | parser.add_argument('--warm_up', type=float, default=0.1) 613 | parser.add_argument('--beta', type=float, default=1) 614 | 615 | args = parser.parse_args() 616 | 617 | t = BertTokenizer.from_pretrained(args.bert_model_type) 618 | 619 | main(args, t) 620 | --------------------------------------------------------------------------------