├── paper.pdf
├── code
│   ├── bash
│   │   ├── aste_14res.sh
│   │   ├── aste_15res.sh
│   │   ├── aste_16res.sh
│   │   ├── aste_14lap.sh
│   │   └── aste.sh
│   ├── model
│   │   ├── table.py
│   │   ├── table_encoder
│   │   │   └── resnet.py
│   │   ├── bdtf_model.py
│   │   ├── matching_layer.py
│   │   └── seq2mat.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── aste_result.py
│   │   └── aste_datamodule.py
│   └── aste_train.py
├── README.md
└── README_EN.md

/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HITSZ-HLT/BDTF-ABSA/HEAD/paper.pdf
--------------------------------------------------------------------------------
/code/bash/aste_14res.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 50 -c 0 -d V2/14res -l 5 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
--------------------------------------------------------------------------------
/code/bash/aste_15res.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 50 -c 0 -d V2/15res -l 4 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
--------------------------------------------------------------------------------
/code/bash/aste_16res.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 50 -c 0 -d V2/16res -l 5 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
--------------------------------------------------------------------------------
/code/bash/aste_14lap.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 40 -c 0 -d V2/14lap -l 3 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
2 | 
--------------------------------------------------------------------------------
/code/model/table.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .seq2mat import *
4 | from .table_encoder.resnet import ResNet
5 | 
6 | class TableEncoder(nn.Module):
7 |     def __init__(self, config):
8 |         super().__init__()
9 | 
10 |         self.config = config
11 |         if config.seq2mat == 'tensor':  # selects how word pairs are fused into table cells
12 |             self.seq2mat = TensorSeq2Mat(config)
13 |         elif config.seq2mat == 'tensorcontext':
14 |             self.seq2mat = TensorcontextSeq2Mat(config)
15 |         elif config.seq2mat == 'context':
16 |             self.seq2mat = ContextSeq2Mat(config)
17 |         else:
18 |             self.seq2mat = Seq2Mat(config)
19 | 
20 |         if config.table_encoder != 'none':
21 |             self.layer = nn.ModuleList([ResNet(config) for _ in range(config.num_table_layers)])
22 | 
23 |     def forward(self, seq, mask):
24 |         table = self.seq2mat(seq, seq)  # [B, L, H] -> [B, L, L, H]
25 | 
26 |         if self.config.table_encoder == 'none':
27 |             return table
28 | 
29 |         for layer_module in self.layer:  # refine the table with residual CNN blocks
30 |             table = layer_module(table)
31 | 
32 |         return table
33 | 
--------------------------------------------------------------------------------
/code/bash/aste.sh:
--------------------------------------------------------------------------------
1 | while getopts ':d:s:c:t:n:l:z:e:D:' opt
2 | do
3 |     case $opt in
4 |         d)
5 |             dataset="$OPTARG" ;;
6 |         s)
7 |             seed="$OPTARG" ;;
8 |         c)
9 |             CUDA_IDS="$OPTARG" ;;
10 |         t)
11 |             table_encoder="$OPTARG" ;;
12 |         n)
13 |             num_table_layers="$OPTARG" ;;
14 |         l)
15 |             learning_rate="$OPTARG" ;;
16 |         z)
17 |             span_pruning="$OPTARG" ;;
18 |         e)
19 |             seq2mat="$OPTARG" ;;
20 |         D)
21 |             num_d="$OPTARG" ;;
22 |         ?)
23 |             exit 1;;
24 |     esac
25 | done
26 | 
27 | 
28 | if [ ! "${table_encoder}" ]
29 | then
30 |     table_encoder=resnet
31 | fi
32 | 
33 | 
34 | if [ ! "${num_table_layers}" ]
"${num_table_layers}" ] 35 | then 36 | num_table_layers=2 37 | fi 38 | 39 | 40 | gradient_clip_val=1 41 | warmup_steps=100 42 | weight_decay=0.01 43 | precision=16 44 | batch_size=4 45 | data_dir="../data/aste_data_bert/${dataset}" 46 | 47 | 48 | CUDA_VISIBLE_DEVICES=${CUDA_IDS} python3 aste_train.py \ 49 | --gpus=1 \ 50 | --precision=${precision} \ 51 | --data_dir ${data_dir} \ 52 | --model_name_or_path 'bert-base-uncased' \ 53 | --output_dir ../output/ASTE/${dataset}/ \ 54 | --learning_rate ${learning_rate}e-5 \ 55 | --train_batch_size ${batch_size} \ 56 | --eval_batch_size ${batch_size} \ 57 | --seed $seed \ 58 | --warmup_steps ${warmup_steps} \ 59 | --lr_scheduler linear \ 60 | --gradient_clip_val ${gradient_clip_val} \ 61 | --weight_decay ${weight_decay} \ 62 | --max_seq_length -1 \ 63 | --max_epochs 10 \ 64 | --cuda_ids ${CUDA_IDS} \ 65 | --do_train \ 66 | --table_encoder ${table_encoder} \ 67 | --num_table_layers ${num_table_layers} \ 68 | --span_pruning ${span_pruning} \ 69 | --seq2mat ${seq2mat} \ 70 | --num_d ${num_d} -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | 5 | 6 | 7 | class NpEncoder(json.JSONEncoder): 8 | def default(self, obj): 9 | if isinstance(obj, np.integer): 10 | return int(obj) 11 | elif isinstance(obj, np.floating): 12 | return float(obj) 13 | elif isinstance(obj, np.ndarray): 14 | return obj.tolist() 15 | else: 16 | return super(NpEncoder, self).default(obj) 17 | 18 | 19 | def load_json(file_name): 20 | with open(file_name, mode='r', encoding='utf-8-sig') as f: 21 | return json.load(f) 22 | 23 | 24 | def append_json(file_name, obj, mode='a'): 25 | mkdir_if_not_exist(file_name) 26 | with open(file_name, mode=mode, encoding='utf-8') as f: 27 | if type(obj) is dict: 28 | string = json.dumps(obj) 29 | elif type(obj) is list: 30 | string = ' '.join([str(item) for item in obj]) 31 | elif type(obj) is str: 32 | string = obj 33 | else: 34 | raise Exception() 35 | 36 | string = string + '\n' 37 | f.write(string) 38 | 39 | 40 | def mkdir_if_not_exist(path): 41 | dir_name, file_name = os.path.split(path) 42 | if not os.path.exists(dir_name): 43 | os.makedirs(dir_name) 44 | 45 | 46 | def save_json(json_obj, file_name): 47 | mkdir_if_not_exist(file_name) 48 | with open(file_name, mode='w+', encoding='utf-8-sig') as f: 49 | json.dump(json_obj, f, indent=4, cls=NpEncoder) 50 | 51 | 52 | 53 | def params_count(model): 54 | """ 55 | Compute the number of parameters. 56 | Args: 57 | model (model): model to count the number of parameters. 
58 | """ 59 | return np.sum([p.numel() for p in model.parameters()]).item() 60 | 61 | 62 | 63 | def load_json(file_name): 64 | with open(file_name, mode='r', encoding='utf-8-sig') as f: 65 | return json.load(f) 66 | 67 | 68 | def yield_data_file(data_dir): 69 | for file_name in os.listdir(data_dir): 70 | yield os.path.join(data_dir, file_name) 71 | -------------------------------------------------------------------------------- /code/model/table_encoder/resnet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from transformers.models.t5.modeling_t5 import T5LayerNorm 4 | from einops import rearrange 5 | 6 | class ResNet(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | if config.model_type == 'bart': 10 | layer_norm_eps = 1e-12 11 | elif config.model_type == 't5': 12 | layer_norm_eps = 1e-12 13 | else: 14 | layer_norm_eps = config.layer_norm_eps 15 | self.conv1 = nn.Conv2d( 16 | in_channels =config.hidden_size, 17 | out_channels=config.hidden_size, 18 | kernel_size =(1, 1), 19 | padding=0 20 | ) 21 | self.norm1 = T5LayerNorm(config.hidden_size, layer_norm_eps) 22 | 23 | self.conv2 = nn.Conv2d( 24 | in_channels =config.hidden_size, 25 | out_channels=config.hidden_size, 26 | kernel_size =(3, 3), 27 | padding=1 28 | ) 29 | self.norm2 = T5LayerNorm(config.hidden_size, layer_norm_eps) 30 | 31 | self.conv3 = nn.Conv2d( 32 | in_channels =config.hidden_size, 33 | out_channels=config.hidden_size, 34 | kernel_size =(1, 1), 35 | padding=0 36 | ) 37 | self.norm3 = T5LayerNorm(config.hidden_size, layer_norm_eps) 38 | 39 | def layer_forward(self, x, conv, norm): 40 | x = conv(x) 41 | n = x.size(-1) 42 | x = rearrange(x, 'b d m n -> b (m n) d') 43 | x = norm(x) 44 | x = F.relu(x) 45 | x = rearrange(x, 'b (m n) d -> b d m n', n=n) 46 | return x 47 | 48 | def forward(self, x_input, **kwargs): 49 | x = rearrange(x_input, 'b m n d -> b d m n') 50 | x = self.layer_forward(x, self.conv1, self.norm1) 51 | x = self.layer_forward(x, self.conv2, self.norm2) 52 | x = self.layer_forward(x, self.conv3, self.norm3) 53 | x = rearrange(x, 'b d m n -> b m n d') 54 | return x + x_input 55 | -------------------------------------------------------------------------------- /code/model/bdtf_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertModel, BertPreTrainedModel 4 | from .table import TableEncoder 5 | from .matching_layer import MatchingLayer 6 | 7 | 8 | class BDTFModel(BertPreTrainedModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | 12 | self.bert = BertModel(config) 13 | self.table_encoder = TableEncoder(config) 14 | self.inference = InferenceLayer(config) 15 | self.matching = MatchingLayer(config) 16 | self.init_weights() 17 | 18 | def forward(self, input_ids, attention_mask, ids, 19 | start_label_masks, end_label_masks, 20 | t_start_labels=None, t_end_labels=None, 21 | o_start_labels=None, o_end_labels=None, 22 | table_labels_S=None, table_labels_E=None, 23 | polarity_labels=None, pairs_true=None): 24 | 25 | seq = self.bert(input_ids, attention_mask)[0] 26 | table = self.table_encoder(seq, attention_mask) 27 | 28 | output = self.inference(table, attention_mask, table_labels_S, table_labels_E) 29 | 30 | output['ids'] = ids 31 | 32 | output = self.matching(output, table, pairs_true, seq) 33 | return output 34 | 35 | 36 | class InferenceLayer(nn.Module): 37 | def 
38 |         super().__init__()
39 |         self.config = config
40 |         self.cls_linear_S = nn.Linear(config.hidden_size, 1)  # scores each cell as a region start
41 |         self.cls_linear_E = nn.Linear(config.hidden_size, 1)  # scores each cell as a region end
42 | 
43 |     def span_pruning(self, pred, z, attention_mask):
44 |         mask_length = attention_mask.sum(dim=1) - 2  # length without [CLS]/[SEP]
45 |         length = ((attention_mask.sum(dim=1) - 2) * z).long()  # keep the top z*length cells ...
46 |         length[length < 5] = 5  # ... but at least 5 ...
47 |         max_length = mask_length ** 2  # ... and at most length^2
48 |         for i in range(length.shape[0]):
49 |             if length[i] > max_length[i]:
50 |                 length[i] = max_length[i]
51 |         batch_size = attention_mask.shape[0]
52 |         pred_sort, _ = pred.view(batch_size, -1).sort(descending=True)
53 |         batchs = torch.arange(batch_size, device=pred.device)
54 |         topkth = pred_sort[batchs, length - 1].unsqueeze(1)  # score of the k-th best cell
55 |         return pred >= (topkth.view(batch_size, 1, 1))
56 | 
57 |     def forward(self, table, attention_mask, table_labels_S, table_labels_E):
58 |         outputs = {}
59 | 
60 |         logits_S = torch.squeeze(self.cls_linear_S(table), 3)
61 |         logits_E = torch.squeeze(self.cls_linear_E(table), 3)
62 | 
63 | 
64 |         loss_func = nn.BCEWithLogitsLoss(weight=(table_labels_S >= 0))  # mask out padded cells
65 | 
66 |         outputs['table_loss_S'] = loss_func(logits_S, table_labels_S.float())
67 |         outputs['table_loss_E'] = loss_func(logits_E, table_labels_E.float())
68 | 
69 |         S_pred = torch.sigmoid(logits_S) * (table_labels_S >= 0)
70 |         E_pred = torch.sigmoid(logits_E) * (table_labels_E >= 0)
71 | 
72 |         if self.config.span_pruning != 0:
73 |             table_predict_S = self.span_pruning(S_pred, self.config.span_pruning, attention_mask)
74 |             table_predict_E = self.span_pruning(E_pred, self.config.span_pruning, attention_mask)
75 |         else:
76 |             table_predict_S = S_pred > 0.5
77 |             table_predict_E = E_pred > 0.5
78 |         outputs['table_predict_S'] = table_predict_S
79 |         outputs['table_predict_E'] = table_predict_E
80 |         outputs['table_labels_S'] = table_labels_S
81 |         outputs['table_labels_E'] = table_labels_E
82 |         return outputs
83 | 
84 | 
--------------------------------------------------------------------------------
/code/model/matching_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | class MatchingLayer(nn.Module):
6 |     def __init__(self, config):
7 |         super().__init__()
8 |         self.config = config
9 |         self.linear = nn.Linear(config.hidden_size*3, 4)  # classes: {invalid, NEG, NEU, POS}
10 | 
11 |     def gene_pred(self, batch_size, S_preds, E_preds, pairs_true):  # enumerate candidate regions
12 |         all_pred = [[] for i in range(batch_size)]
13 |         pred_label = [[] for i in range(batch_size)]
14 |         pred_maxlen = 0
15 |         for i in range(batch_size):
16 |             S_pred = torch.nonzero(S_preds[i]).cpu().numpy()
17 |             E_pred = torch.nonzero(E_preds[i]).cpu().numpy()
18 | 
19 |             for (s0, s1) in S_pred:
20 |                 for (e0, e1) in E_pred:
21 |                     if s0 <= e0 and s1 <= e1:  # valid rectangle: end corner below-right of start
22 |                         sentiment = 0  # 0 = invalid unless the region matches a gold pair
23 |                         for j in range(len(pairs_true[i])):
24 |                             p = pairs_true[i][j]
25 |                             if [s0-1, e0, s1-1, e1] == p[:4]:
26 |                                 sentiment = p[4]
27 |                         pred_label[i].append(sentiment)
28 |                         all_pred[i].append([s0-1, e0, s1-1, e1])
29 |             if len(all_pred[i]) > pred_maxlen:
30 |                 pred_maxlen = len(all_pred[i])
31 | 
32 |         for i in range(batch_size):
33 |             for j in range(len(all_pred[i]), pred_maxlen):
34 |                 pred_label[i].append(-1)  # padding label, ignored by the loss
35 |         pred_label = torch.tensor(pred_label).to(S_preds.device)
36 | 
37 |         return all_pred, pred_label, pred_maxlen
38 | 
39 |     def input_encoding(self, batch_size, pairs, maxlen, table, seq):
40 |         input_ret = torch.zeros([batch_size, maxlen, self.config.hidden_size*3], device=table.device)
41 |         for i in range(batch_size):
42 |             j = 0
43 |             for (s0, e0, s1, e1) in pairs[i]:
44 |                 S = table[i, s0+1, s1+1, :]  # start-corner cell (+1 for [CLS])
45 |                 E = table[i, e0, e1, :]  # end-corner cell
46 |                 R = torch.max(torch.max(table[i, s0+1:e0+1, s1+1:e1+1, :], dim=1)[0], dim=0)[0]  # max-pooled region
47 |                 input_ret[i, j, :] = torch.cat([S, E, R])
48 |                 j += 1
49 |         return input_ret
50 | 
51 | 
52 |     def forward(self, outputs, Table, pairs_true, seq):
53 |         seq = seq.clone().detach()
54 |         table = Table.clone()
55 |         batch_size = table.size(0)
56 | 
57 |         all_pred, pred_label, pred_maxlen = self.gene_pred(batch_size, outputs['table_predict_S'], outputs['table_predict_E'], pairs_true)
58 |         pred_input = self.input_encoding(batch_size, all_pred, pred_maxlen, table, seq)
59 |         pred_output = self.linear(pred_input)
60 |         loss_func = nn.CrossEntropyLoss(ignore_index=-1)
61 | 
62 |         loss_input = pred_output
63 |         loss_label = pred_label
64 | 
65 |         if loss_input.shape[1] == 0:  # no candidate regions: dummy (fully ignored) input keeps shapes consistent
66 |             loss_input = torch.zeros([batch_size, 1, 4], device=pred_output.device)
67 |             loss_label = torch.zeros([batch_size, 1], device=pred_output.device) - 1
68 | 
69 |         outputs['pair_loss'] = loss_func(loss_input.transpose(1, 2), loss_label.long())
70 | 
71 |         pairs_logits = F.softmax(pred_output, dim=2)
72 |         if pairs_logits.shape[1] == 0:
73 |             outputs['pairs_preds'] = []
74 |             return outputs
75 | 
76 |         pairs_pred = pairs_logits.argmax(dim=2)
77 | 
78 |         outputs['pairs_preds'] = []
79 |         for i in range(batch_size):
80 |             for j in range(len(all_pred[i])):
81 |                 if pairs_pred[i][j] >= 1:  # class 0 is "invalid"; 1/2/3 are NEG/NEU/POS
82 |                     se = all_pred[i][j]
83 |                     outputs['pairs_preds'].append((i, se[0], se[1], se[2], se[3], pairs_pred[i][j].item()))
84 |         return outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [**中文说明**](https://github.com/HITSZ-HLT/BDTF-ASTE/) | [**English**](https://github.com/HITSZ-HLT/BDTF-ASTE/blob/master/README_EN.md)
2 | 
3 | 
4 | # BDTF-ASTE
5 | 
6 | This repository open-sources the code of the following paper:
7 | 
8 | - Title: [Boundary-Driven Table-Filling for Aspect Sentiment Triplet Extraction](https://aclanthology.org/2022.emnlp-main.435/)
9 | - Authors: Yice Zhang∗, Yifan Yang∗, Yihui Li, Bin Liang, Shiwei Chen, Yixue Dang, Ming Yang, and Ruifeng Xu
10 | - Conference: EMNLP-2022 Main (Long)
11 | 
12 | ## Overview
13 | 
14 | ### The ASTE Task
15 | 
16 | This paper addresses the Aspect Sentiment Triplet Extraction (ASTE) task within the Aspect-Based Sentiment Analysis (ABSA) problem.
17 | As shown in the figure below, ASTE aims to extract the aspect sentiment triplets that express opinions in user reviews (a concrete toy example follows the figure). A triplet consists of three parts:
18 | - Aspect Term: the target of the sentiment, usually an aspect of the reviewed entity (a restaurant or a product); also called the aspect word or attribute word.
19 | - Opinion Term: the word or phrase that concretely expresses the sentiment; also called the sentiment word.
20 | - Sentiment Polarity: the sentiment the user expresses toward the Aspect Term, with label space `{POS, NEG, NEU}`.
21 | 
22 | <div align="center">
ASTE
23 | 
24 | ### Problems with Previous Methods
25 | 
26 | Previous methods model this task as a table-filling problem. As shown in the figure below, each element of the 2D table encodes the relation between two words. These methods first extract aspects and opinions from the diagonal of the table, then use them to locate the corresponding relation region, and finally determine the relation between an aspect and an opinion by voting over that region (see the toy illustration after the figure). This formulation has two notable problems:
27 | 1. Relation inconsistency: decomposing the relation between an aspect and an opinion into word-to-word relations introduces potential inconsistencies among the word-level predictions.
28 | 2. Boundary insensitivity: if the boundary of an aspect or opinion is slightly wrong, the predicted relation will most likely stay unchanged, so the model can emit outputs with boundary errors, such as ('dogs', 'top notch', `POS`).
29 | 
30 | <div align="center">
GTS
31 | 
32 | 
33 | Previous work has tried span-based methods to resolve the relation-inconsistency problem. This is a feasible direction, but it discards the fine-grained word-level information that is precisely the strength of the table-filling approach.
34 | 
35 | ### The Proposed Method
36 | 
37 | To solve the two problems above, this paper proposes Boundary-Driven Table-Filling (BDTF). As shown in the figure below, BDTF turns an aspect sentiment triplet into a relation region of the 2D table, thereby recasting ASTE as the detection and classification of relation regions. Classifying a relation region as a whole resolves the relation-inconsistency problem, and regions with boundary errors can be removed by classifying them as invalid.
38 | 
39 | <div align="center">
BDTF
40 | 
41 | In addition, this paper proposes a relation learning method that learns a 2D representation. It consists of three parts (a minimal code sketch follows the framework figure below):
42 | - First, the review text is fed into `BERT` to learn word-level contextualized representations.
43 | - Then, a tensor-based operation builds relation representations from the word representations. The relation representations between all word pairs in the text form a 2D table whose elements are vectors.
44 | - Finally, a CNN models the 2D table.
45 | The learned 2D representation is then used for the detection and classification of relation regions.
46 | 
47 | Overall, the model framework of the proposed method is shown in the figure below.
48 | 
49 | <div align="center">
Model
50 | 
51 | ### Experimental Results
52 | 
53 | The main experimental results are shown in the table below; see the paper for a detailed analysis.
54 | 
55 | <div align="center">
Result
56 | 
57 | ## Running the Code
58 | ### Requirements
59 | 
60 | - transformers==4.15.0
61 | - pytorch==1.7.1
62 | - einops==0.4.0
63 | - torchmetrics==0.7.0
64 | - tntorch==1.0.1
65 | - pytorch-lightning==1.3.5
66 | 
67 | ### Code Structure
68 | 
69 | ```
70 | ├── code
71 | │   ├── utils
72 | │   │   ├── __init__.py
73 | │   │   ├── aste_datamodule.py
74 | │   │   └── aste_result.py
75 | │   ├── model
76 | │   │   ├── seq2mat.py
77 | │   │   ├── table.py
78 | │   │   ├── table_encoder
79 | │   │   │   └── resnet.py
80 | │   │   └── bdtf_model.py
81 | │   ├── aste_train.py
82 | │   └── bash
83 | │       ├── aste.sh
84 | │       ├── aste_14res.sh
85 | │       ├── aste_14lap.sh
86 | │       ├── aste_15res.sh
87 | │       └── aste_16res.sh
88 | └── data
89 |     └── aste_data_bert
90 |         ├── V1
91 |         │   ├── 14res
92 |         │   │   ├── train.json
93 |         │   │   ├── dev.json
94 |         │   │   └── test.json
95 |         │   ├── 14lap/...
96 |         │   ├── 15res/...
97 |         │   └── 16res/...
98 |         └── V2/...
99 | ```
100 | 
101 | ### Run the Code
102 | 
103 | In the `code` directory:
104 | - run `chmod +x bash/*`;
105 | - run `bash/aste_14lap.sh`.
106 | 
107 | Below is the result of running aste_14lap.sh, with the random seed set to 40 and an A100 as the computing device.
108 | 
109 | <div align="center">
Result
110 | 
111 | Running aste_14lap.sh on a V100 gives the following result.
112 | 
113 | <div align="center">
Result2
114 | 
115 | Note that the performance reported in the paper is the average over runs with 5 random seeds, which may differ somewhat from a single run.
116 | 
117 | ## If you have any questions, please open an `issue` or contact me
118 | 
119 | - email: `zhangyc_hit@163.com`
120 | 
121 | 
122 | 
--------------------------------------------------------------------------------
/code/utils/aste_result.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from . import append_json, save_json, mkdir_if_not_exist
4 | 
5 | polarity_map = {
6 |     'NEG': 0,
7 |     'NEU': 1,
8 |     'POS': 2
9 | }
10 | polarity_map_reversed = {
11 |     0: 'NEG',
12 |     1: 'NEU',
13 |     2: 'POS'
14 | }
15 | 
16 | class F1_Measure:
17 |     def __init__(self):
18 |         self.pred_set = set()
19 |         self.true_set = set()
20 | 
21 |     def pred_inc(self, idx, preds):
22 |         for pred in preds:
23 |             self.pred_set.add((idx, tuple(pred)))
24 | 
25 |     def true_inc(self, idx, trues):
26 |         for true in trues:
27 |             self.true_set.add((idx, tuple(true)))
28 | 
29 |     def report(self):
30 |         self.f1, self.p, self.r = self.cal_f1(self.pred_set, self.true_set)
31 |         return self.f1
32 | 
33 |     def __getitem__(self, key):
34 |         if hasattr(self, key):
35 |             return getattr(self, key)
36 |         else:
37 |             raise NotImplementedError
38 | 
39 |     def cal_f1(self, pred_set, true_set):
40 |         intersection = pred_set.intersection(true_set)
41 |         _p = len(intersection) / len(pred_set) if pred_set else 1
42 |         _r = len(intersection) / len(true_set) if true_set else 1
43 |         f1 = 2 * _p * _r / (_p + _r) if _p + _r else 0
44 |         return f1, _p, _r
45 | 
46 | 
47 | 
48 | class NER_F1_Measure(F1_Measure):
49 |     def __init__(self, entity_types):
50 |         super().__init__()
51 |         self.entity_types = entity_types
52 | 
53 |     def report(self):
54 |         for entity_type in self.entity_types:
55 |             name = '_'.join(entity_type)
56 |             pred_set = self.filter(self.pred_set, entity_type)
57 |             true_set = self.filter(self.true_set, entity_type)
58 | 
59 |             f1, p, r = self.cal_f1(pred_set, true_set)
60 |             setattr(self, f'{name}_f1', f1)
61 | 
62 |     def filter(self, set_, entity_type):
63 |         return_set = set()
64 |         for type_ in entity_type:
65 |             return_set.update(set([it for it in set_ if it[1][0] == type_]))
66 |         return return_set
67 | 
68 | 
69 | class Result:
70 |     def __init__(self, result_json):
71 |         self.result_json = result_json
72 |         self.detailed_metric = None
73 |         self.monitor = 0
74 | 
75 |     def __setitem__(self, key, value):
76 |         self.result_json[key] = value
77 | 
78 |     def __ge__(self, other):
79 |         return self.monitor >= other.monitor
80 | 
81 |     def __gt__(self, other):
82 |         return self.monitor > other.monitor
83 | 
84 |     @classmethod
85 |     def parse_from(cls, all_preds, examples):
86 |         result_json = {}
87 |         examples = {example['ID']: example for example in examples}
88 | 
89 |         for preds in all_preds:
90 | 
91 |             for ID in preds['ids']:
92 |                 example = examples[ID]
93 |                 pairs_true = []
94 |                 for pp in example['pairs']:
95 |                     pl = polarity_map[pp[4]]+1
96 |                     pairs_true.append([pp[0], pp[1], pp[2], pp[3], pl])
97 |                 result_json[ID] = {
98 |                     'ID': ID,
99 |                     'sentence': example['sentence'],
100 |                     'pairs': pairs_true,
101 |                     'tokens': str(example['tokens']),
102 |                     'pair_preds': set(),
103 |                 }
104 |             for (i, a_start, a_end, b_start, b_end, pol) in preds['pair_preds']:
105 |                 ID = preds['ids'][i]
106 |                 result_json[ID]['pair_preds'].add((a_start, a_end, b_start, b_end, pol))
107 |         return cls(result_json)
108 | 
109 | 
110 |     def cal_metric(self):
111 |         b_pair_f1 = F1_Measure()
112 |         for item in self.result_json.values():
113 |             for pair_f1 in (b_pair_f1, ):
114 |                 pair_f1.true_inc(item['ID'], item['pairs'])
115 | 
116 |             b_pair_f1.pred_inc(item['ID'], item['pair_preds'])
117 | 
118 |         b_pair_f1.report()
119 | 
120 |         detailed_metrics = {
121 |             'pair_f1': b_pair_f1['f1'],
122 |             'pair_p': b_pair_f1['p'],
123 |             'pair_r': b_pair_f1['r'],
124 |         }
125 | 
126 |         self.detailed_metrics = detailed_metrics
127 |         self.monitor = b_pair_f1['f1']
128 | 
129 |     def report(self):
130 |         print(f'monitor: {self.monitor:.4f}', end=" ** ")
131 |         for metric_names in (('pair_f1','pair_p','pair_r'),):
132 |             for metric_name in metric_names:
133 |                 value = self.detailed_metrics[metric_name] if metric_name in self.detailed_metrics else 0
134 |                 print(f'{metric_name}: {value:.4f}', end=' | ')
135 |         print()
136 | 
137 |     def save(self, dir_name, args):
138 |         mkdir_if_not_exist(dir_name)
139 |         current_time = time.strftime("%Y-%m-%d %H_%M_%S", time.localtime())
140 |         current_day = time.strftime("%Y-%m-%d", time.localtime())
141 | 
142 |         result_file_name = os.path.join(dir_name, f'val_results_{self.monitor*10000:4.0f}_{current_time}.txt')
143 |         performance_dir = os.path.join(os.path.dirname(os.path.dirname(dir_name)), 'performance')
144 | 
145 |         performance_dir = os.path.join(performance_dir, current_day)
146 |         performance_file_name = os.path.join(performance_dir, f'{args.cuda_ids}.txt')
147 | 
148 |         for key, item in self.result_json.items():
149 |             item['pair_preds'] = list(item['pair_preds'])
150 | 
151 |         save_json(list(self.result_json.values()), result_file_name)
152 |         print('## save result to', result_file_name)
153 | 
154 |         description = f'{args.data_dir}, lr={args.learning_rate}, seed={args.seed}, model_name_or_path={args.model_name_or_path}'
155 |         detailed_metrics = {k: (v if type(v) in (int, float) else v.item()) for k,v in self.detailed_metrics.items()}
156 | 
157 |         append_json(performance_file_name, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
158 |         append_json(performance_file_name, f'{description} {self.monitor*10000:4.0f}')
159 |         append_json(performance_file_name, f'{args.span_pruning} {args.seq2mat} {args.num_d} {args.table_encoder} {args.num_table_layers}')
160 |         append_json(performance_file_name, detailed_metrics)
161 |         append_json(performance_file_name, '')
162 | 
163 | 
164 | 
--------------------------------------------------------------------------------
/code/model/seq2mat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers.models.t5.modeling_t5 import T5LayerNorm
4 | from transformers.activations import ACT2FN
5 | 
6 | 
7 | class Seq2Mat(nn.Module):
8 |     def __init__(self, config):
9 |         super().__init__()
10 |         self.config = config
11 |         self.W = nn.Linear(config.hidden_size*2, config.hidden_size)
12 |         self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
13 |         self.activation = ACT2FN[config.hidden_act]
14 | 
15 |     def forward(self, x, y):
16 |         """
17 |         x,y: [B, L, H] => [B, L, L, H]
18 |         """
19 |         x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
20 |         t = torch.cat([x, y], dim=-1)
21 |         t = self.W(t)
22 |         t = self.activation(t)
23 |         return t
24 | 
25 | 
26 | class ContextSeq2Mat(nn.Module):
27 |     def __init__(self, config):
28 |         super().__init__()
29 |         self.config = config
30 |         self.W = nn.Linear(config.hidden_size*3, config.hidden_size)
31 |         self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
32 |         self.activation = ACT2FN[config.hidden_act]
33 | 
34 |     def forward(self, x, y):
35 |         """
36 |         x,y: [B, L, H] => [B, L, L, H]
37 |         """
38 |         xmat = x.clone()
39 |         batch_size = xmat.shape[0]
40 |         x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
41 | 
42 |         max_len = xmat.shape[1]
43 |         xmat_t = xmat.transpose(1, 2)
44 |         context = torch.ones_like(x)  # already on x's device and dtype
45 |         for i in range(max_len):  # running max over diagonals builds span-context vectors
46 |             diag = x.diagonal(dim1=1, dim2=2, offset=-i)
47 |             xmat_t = torch.max(xmat_t[:, :, :max_len-i], diag)
48 |             bb = [[b] for b in range(batch_size)]
49 |             linexup = [[j for j in range(max_len-i)] for b in range(batch_size)]
50 |             lineyup = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
51 |             linexdown = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
52 |             lineydown = [[j for j in range(max_len-i)] for b in range(batch_size)]
53 |             context[bb, linexup, lineyup, :] = xmat_t.permute(0, 2, 1)
54 |             context[bb, linexdown, lineydown, :] = xmat_t.permute(0, 2, 1)
55 | 
56 |         t = torch.cat([x, y, context], dim=-1)
57 |         t = self.W(t)
58 |         t = self.activation(t)
59 |         return t
60 | 
61 | 
62 | class TensorSeq2Mat(nn.Module):
63 |     """
64 |     reference: SOCHER R, PERELYGIN A, WU J, et al. Recursive deep models for semantic compositionality over a sentiment treebank[C]//Proceedings of the 2013 conference on empirical methods in natural language processing. 2013: 1631-1642.
65 |     """
66 |     def __init__(self, config):
67 |         super().__init__()
68 |         self.config = config
69 |         self.h = config.num_attention_heads
70 |         self.d = config.num_d
71 |         self.W = nn.Linear(2*config.hidden_size+self.d, config.hidden_size)
72 |         self.V = nn.Parameter(torch.Tensor(self.d, config.hidden_size, config.hidden_size))
73 |         self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
74 |         self.activation = ACT2FN[config.hidden_act]
75 |         self.init_weights()
76 | 
77 |     def init_weights(self):
78 |         self.V.data.normal_(mean=0.0, std=self.config.initializer_range)
79 | 
80 |     def rntn(self, x, y):
81 |         t = torch.cat([x, y], dim=-1)
82 |         xv = torch.einsum('b m n p, k p d -> b m n k d', x, self.V)
83 |         xvy = torch.einsum('b m n k d, b m n d -> b m n k', xv, y)  # bilinear term x^T V y
84 |         t = torch.cat([t, xvy], dim=-1)
85 |         tw = self.W(t)
86 |         return tw
87 | 
88 |     def forward(self, x, y):
89 |         """
90 |         x,y: [B, L, H] => [B, L, L, H]
91 |         """
92 |         seq = x
93 |         x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
94 |         t = self.rntn(x, y)
95 |         t = self.activation(t)
96 |         return t
97 | 
98 | 
99 | class TensorcontextSeq2Mat(nn.Module):
100 |     """
101 |     reference: SOCHER R, PERELYGIN A, WU J, et al. Recursive deep models for semantic compositionality over a sentiment treebank[C]//Proceedings of the 2013 conference on empirical methods in natural language processing. 2013: 1631-1642.
102 |     """
103 |     def __init__(self, config):
104 |         super().__init__()
105 |         self.config = config
106 |         self.h = config.num_attention_heads
107 |         self.d = config.num_d
108 |         self.W = nn.Linear(3*config.hidden_size+self.d, config.hidden_size)
109 |         self.V = nn.Parameter(torch.Tensor(self.d, config.hidden_size, config.hidden_size))
110 |         self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
111 |         self.activation = ACT2FN[config.hidden_act]
112 |         self.init_weights()
113 | 
114 |     def init_weights(self):
115 |         if self.config.model_type=='bart' or self.config.model_type=='t5':
116 |             self.V.data.normal_(mean=0.0, std=0.02)
117 |         else:
118 |             self.V.data.normal_(mean=0.0, std=self.config.initializer_range)
119 | 
120 |     def rntn(self, x, y, xmat):
121 |         max_len = xmat.shape[1]
122 |         xmat_t = xmat.transpose(1, 2)
123 |         batch_size = xmat.shape[0]
124 |         context = torch.ones_like(x)  # already on x's device and dtype
125 |         for i in range(max_len):  # running max over diagonals builds span-context vectors
126 |             diag = x.diagonal(dim1=1, dim2=2, offset=-i)
127 |             xmat_t = torch.max(xmat_t[:, :, :max_len-i], diag)
128 |             bb = [[b] for b in range(batch_size)]
129 |             linexup = [[j for j in range(max_len-i)] for b in range(batch_size)]
130 |             lineyup = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
131 |             linexdown = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
132 |             lineydown = [[j for j in range(max_len-i)] for b in range(batch_size)]
133 |             context[bb, linexup, lineyup, :] = xmat_t.permute(0, 2, 1)
134 |             context[bb, linexdown, lineydown, :] = xmat_t.permute(0, 2, 1)
135 | 
136 |         t = torch.cat([x, y, context], dim=-1)
137 |         xvy = torch.einsum('b m n p, k p d, b m n d -> b m n k', x, self.V, y)  # bilinear term x^T V y
138 |         t = torch.cat([t, xvy], dim=-1)
139 |         tw = self.W(t)
140 |         return tw
141 | 
142 |     def forward(self, x, y):
143 |         """
144 |         x,y: [B, L, H] => [B, L, L, H]
145 |         """
146 |         xmat = x
147 |         x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
148 |         t = self.rntn(x, y, xmat)
149 |         t = self.activation(t)
150 |         return t
151 | 
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | [**中文说明**](https://github.com/HITSZ-HLT/BDTF-ASTE/) | [**English**](https://github.com/HITSZ-HLT/BDTF-ASTE/blob/master/README_EN.md)
2 | 
3 | # BDTF-ASTE
4 | 
5 | This repository releases the code of the following paper:
6 | 
7 | - Title: [Boundary-Driven Table-Filling for Aspect Sentiment Triplet Extraction](https://aclanthology.org/2022.emnlp-main.435/)
8 | - Authors: Yice Zhang∗, Yifan Yang∗, Yihui Li, Bin Liang, Shiwei Chen, Yixue Dang, Ming Yang, and Ruifeng Xu
9 | - Conference: EMNLP-2022 Main (Long)
10 | 
11 | ## Brief Introduction of Our Paper
12 | 
13 | ### The ASTE Task
14 | 
15 | The task that this paper addresses is Aspect Sentiment Triplet Extraction (ASTE), which is an important task in Aspect-Based Sentiment Analysis (ABSA).
16 | As shown in the figure below, ASTE aims to extract the aspect terms along with the corresponding opinion terms and the expressed sentiments in the review.
17 | Specifically, a triplet is defined as (aspect term, opinion term, sentiment polarity):
18 | - **Aspect term**: the target of an opinion, usually an aspect of an entity (a restaurant or product).
19 | - **Opinion term**: the word or phrase that specifically expresses the sentiment.
20 | - **Sentiment polarity**: a specific category in `{POS, NEG, NEU}`.
21 | 
22 | <div align="center">
ASTE
23 | 
24 | ### Previous Methods and Limitations
25 | 
26 | Previous methods tackle the ASTE task through a table-filling approach, where the triplets are represented by a two-dimensional (2D) table of word-pair relations.
27 | In this approach, aspect terms and opinion terms are extracted through the diagonal elements of the table, and sentiments are treated as relation tags that are represented by the non-diagonal elements of the table.
28 | This formalization enables joint learning of different subtasks in ASTE, achieving superior performance over the pipeline approach.
29 | 
30 | However, the previous table formalization suffers from relation inconsistency and boundary insensitivity when dealing with multi-word aspect terms and opinion terms. It decomposes the relation between an aspect term and an opinion term into the relations between the corresponding aspect words and opinion words.
31 | In other words, a term-level relation is represented by several word-level relation tags. The relation tags in the table are assigned independently, which leads to potential inconsistencies in the predictions of the word-level relations.
32 | In addition, when there are minor boundary errors in the aspect term or opinion term, the voting result for the term-level relation may stay unchanged, encouraging the model to produce wrong predictions. Researchers have tried to solve this problem through a span-based method, but that method discards fine-grained word-level information, which is the advantage of the table-filling approach.
33 | 
34 | <div align="center">
GTS
35 | 36 | ### Proposed Approach 37 | 38 | This paper proposes a Boundary-Driven Table-Filling (BDTF) approach for ASTE to overcome the above issues. 39 | In BDTF, a triplet is represented as a relation region in the 2D table, which is shown in the figure below. 40 | In this way, it extracts triplets by directly detecting and classifying the relation regions in a 2D table. 41 | Classification over the entire relation region ensures relation consistency, and those relation regions with boundary errors can be removed by being classified as invalid. 42 | 43 |
BDTF
44 | 45 | 46 | In addition, this paper also develops an effective relation representation learning approach to learn the table representation. 47 | This consists of three parts: 48 | - We first learn the word-level contextualized representations of the input review through a pre-trained language model. 49 | - Then we adopt a tensor-based operation to construct the relation-level representations to fully exploit the word-to-word interactions. 50 | - Finally, we model relation-to-relation interactions through a multi-layer convolution-based encoder to enhance the relation-level representations. 51 | 52 | The relation representations of each two words in the review together form a 2D relation matrix, which serves as the table representation for BDTF. 53 | 54 | The proposed approach is briefly presented in the figure below. 55 | 56 |
Model
57 | 
58 | ### Experimental Results
59 | 
60 | The main results are listed in the table below. See the paper for a detailed analysis.
61 | 
62 | <div align="center">
Result
63 | 
64 | ## How to Run
65 | 
66 | ### Requirements
67 | 
68 | - transformers==4.15.0
69 | - pytorch==1.7.1
70 | - einops==0.4.0
71 | - torchmetrics==0.7.0
72 | - tntorch==1.0.1
73 | - pytorch-lightning==1.3.5
74 | 
75 | ### Files
76 | 
77 | ```
78 | ├── code
79 | │   ├── utils
80 | │   │   ├── __init__.py
81 | │   │   ├── aste_datamodule.py
82 | │   │   └── aste_result.py
83 | │   ├── model
84 | │   │   ├── seq2mat.py
85 | │   │   ├── table.py
86 | │   │   ├── table_encoder
87 | │   │   │   └── resnet.py
88 | │   │   └── bdtf_model.py
89 | │   ├── aste_train.py
90 | │   └── bash
91 | │       ├── aste.sh
92 | │       ├── aste_14res.sh
93 | │       ├── aste_14lap.sh
94 | │       ├── aste_15res.sh
95 | │       └── aste_16res.sh
96 | └── data
97 |     └── aste_data_bert
98 |         ├── V1
99 |         │   ├── 14res
100 |         │   │   ├── train.json
101 |         │   │   ├── dev.json
102 |         │   │   └── test.json
103 |         │   ├── 14lap/...
104 |         │   ├── 15res/...
105 |         │   └── 16res/...
106 |         └── V2/...
107 | ```
108 | 
109 | ### Run our code!
110 | 
111 | Enter `code` and
112 | - execute `chmod +x bash/*`,
113 | - execute `bash/aste_14lap.sh`.
114 | 
115 | Result of aste_14lap.sh (Random seed is set to 40 and the computing device is an A100):
116 | 
117 | <div align="center">
Result
118 | 
119 | Result of aste_14lap.sh (Random seed is set to 40 and the computing device is a V100):
120 | 
121 | <div align="center">
Result2
122 | 
123 | Note that the performance reported in the paper is the average of 5 runs with 5 different random seeds, which may differ somewhat from a single run.
124 | 
125 | ## If you have any questions, please raise an `issue` or contact me
126 | 
127 | - email: `zhangyc_hit@163.com`
128 | 
129 | 
130 | 
131 | 
132 | 
--------------------------------------------------------------------------------
/code/utils/aste_datamodule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from transformers import AutoTokenizer
5 | 
6 | import os
7 | from . import load_json
8 | 
9 | polarity_map = {
10 |     'NEG': 0,
11 |     'NEU': 1,
12 |     'POS': 2
13 | }
14 | 
15 | polarity_map_reversed = {
16 |     0: 'NEG',
17 |     1: 'NEU',
18 |     2: 'POS'
19 | }
20 | 
21 | 
22 | class Example:
23 |     def __init__(self, data, max_length=-1):
24 |         self.data = data
25 |         self.max_length = max_length
26 |         self.data['tokens'] = eval(str(self.data['tokens']))  # tokens are stored as a stringified list
27 | 
28 |     def __getitem__(self, key):
29 |         return self.data[key]
30 | 
31 |     def t_entities(self):
32 |         return [tuple(entity[:3]) for entity in self['entities'] if entity[0]=='target']
33 | 
34 |     def o_entities(self):
35 |         return [tuple(entity[:3]) for entity in self['entities'] if entity[0]=='opinion']
36 | 
37 |     def entity_label(self, target_oriented, length):
38 |         entities = self.t_entities() if target_oriented else self.o_entities()
39 |         return Example.make_start_end_labels(entities, length)
40 | 
41 |     def table_label(self, length, ty, id_len):
42 |         label = [[-1 for _ in range(length)] for _ in range(length)]  # -1 marks padded cells
43 |         id_len = id_len.item()
44 | 
45 |         for i in range(1, id_len-1):
46 |             for j in range(1, id_len-1):
47 |                 label[i][j] = 0
48 | 
49 |         for t_start, t_end, o_start, o_end, pol in self['pairs']:
50 |             if ty == 'S':
51 |                 label[t_start+1][o_start+1] = 1  # start corner (+1 for [CLS])
52 |             elif ty == 'E':
53 |                 label[t_end][o_end] = 1  # end corner (exclusive span ends cancel the +1)
54 |         return label
55 | 
56 |     @staticmethod
57 |     def make_start_end_labels(entities, length, plus_one=True):
58 |         start_label = [0] * length
59 |         end_label = [0] * length
60 | 
61 |         for (t, s, e) in entities:
62 |             if plus_one:
63 |                 s, e = s+1, e+1
64 | 
65 |             if s < length:
66 |                 start_label[s] = 1
67 | 
68 |             if e-1 < length:
69 |                 end_label[e-1] = 1
70 | 
71 |         return start_label, end_label
72 | 
73 | 
74 | class DataCollatorForASTE:
75 |     def __init__(self, tokenizer, max_seq_length):
76 |         self.tokenizer = tokenizer
77 |         self.max_seq_length = max_seq_length
78 | 
79 | 
80 |     def __call__(self, examples):
81 | 
82 |         batch = self.tokenizer_function(examples)
83 | 
84 |         length = batch['input_ids'].size(1)
85 | 
86 |         batch['t_start_labels'], batch['t_end_labels'] = self.start_end_labels(examples, True, length)
87 |         batch['o_start_labels'], batch['o_end_labels'] = self.start_end_labels(examples, False, length)
88 |         batch['example_ids'] = [example['ID'] for example in examples]
89 |         batch['table_labels_S'] = torch.tensor([examples[i].table_label(length, 'S', (batch['input_ids'][i]>0).sum()) for i in range(len(examples))], dtype=torch.long)
90 |         batch['table_labels_E'] = torch.tensor([examples[i].table_label(length, 'E', (batch['input_ids'][i]>0).sum()) for i in range(len(examples))], dtype=torch.long)
91 | 
92 |         al = [example['pairs'] for example in examples]
93 |         pairs_ret = []
94 |         for pairs in al:
95 |             pairs_chg = []
96 |             for p in pairs:
97 |                 pairs_chg.append([p[0], p[1], p[2], p[3], polarity_map[p[4]]+1])  # shift labels so 0 means "invalid"
98 |             pairs_ret.append(pairs_chg)
99 |         batch['pairs_true'] = pairs_ret
100 | 
101 | 
return { 102 | 'ids': batch['example_ids'], 103 | 'input_ids' : batch['input_ids'], 104 | 'attention_mask': batch['attention_mask'], 105 | 106 | 't_start_labels': batch['t_start_labels'], 107 | 't_end_labels' : batch['t_end_labels'], 108 | 'o_start_labels': batch['o_start_labels'], 109 | 'o_end_labels' : batch['o_end_labels'], 110 | 111 | 'start_label_masks': batch['start_label_masks'], 112 | 'end_label_masks' : batch['end_label_masks'], 113 | 114 | 'table_labels_S' : batch['table_labels_S'], 115 | 'table_labels_E' : batch['table_labels_E'], 116 | 'pairs_true' : batch['pairs_true'], 117 | } 118 | 119 | def start_end_labels(self, examples, target_oriented, length): 120 | start_labels = [] 121 | end_labels = [] 122 | 123 | for example in examples: 124 | start_label, end_label = example.entity_label(target_oriented, length) 125 | start_labels.append(start_label) 126 | end_labels.append(end_label) 127 | 128 | start_labels = torch.tensor(start_labels, dtype=torch.long) 129 | end_labels = torch.tensor(end_labels, dtype=torch.long) 130 | 131 | return start_labels, end_labels 132 | 133 | def tokenizer_function(self, examples): 134 | text = [example['sentence'] for example in examples] 135 | kwargs = { 136 | 'text': text, 137 | 'return_tensors': 'pt' 138 | } 139 | 140 | if self.max_seq_length in (-1, 'longest'): 141 | kwargs['padding'] = True 142 | else: 143 | kwargs['padding'] = 'max_length' 144 | kwargs['max_length'] = self.max_seq_length 145 | kwargs['truncation'] = True 146 | 147 | batch_encodings = self.tokenizer(**kwargs) 148 | length = batch_encodings['input_ids'].size(1) 149 | 150 | start_label_masks = [] 151 | end_label_masks = [] 152 | 153 | for i in range(len(examples)): 154 | encoding = batch_encodings[i] 155 | word_ids = encoding.word_ids 156 | type_ids = encoding.type_ids 157 | 158 | start_label_mask = [(1 if type_ids[i]==0 else 0) for i in range(length)] 159 | end_label_mask = [(1 if type_ids[i]==0 else 0) for i in range(length)] 160 | 161 | for token_idx in range(length): 162 | current_word_idx = word_ids[token_idx] 163 | prev_word_idx = word_ids[token_idx-1] if token_idx-1 > 0 else None 164 | next_word_idx = word_ids[token_idx+1] if token_idx+1 < length else None 165 | 166 | if prev_word_idx is not None and current_word_idx == prev_word_idx: 167 | start_label_mask[token_idx] = 0 168 | if next_word_idx is not None and current_word_idx == next_word_idx: 169 | end_label_mask[token_idx] = 0 170 | 171 | start_label_masks.append(start_label_mask) 172 | end_label_masks.append(end_label_mask) 173 | 174 | batch_encodings = dict(batch_encodings) 175 | batch_encodings['start_label_masks'] = torch.tensor(start_label_masks, dtype=torch.long) 176 | batch_encodings['end_label_masks'] = torch.tensor(end_label_masks, dtype=torch.long) 177 | 178 | return batch_encodings 179 | 180 | 181 | class ASTEDataModule(pl.LightningDataModule): 182 | def __init__(self, 183 | model_name_or_path: str='', 184 | max_seq_length: int = -1, 185 | train_batch_size: int = 32, 186 | eval_batch_size: int = 32, 187 | data_dir: str = '', 188 | num_workers: int = 4, 189 | cuda_ids: int = -1, 190 | ): 191 | 192 | super().__init__() 193 | 194 | self.model_name_or_path = model_name_or_path 195 | self.max_seq_length = max_seq_length if max_seq_length > 0 else 'longest' 196 | self.train_batch_size = train_batch_size 197 | self.eval_batch_size = eval_batch_size 198 | 199 | self.data_dir = data_dir 200 | self.num_workers = num_workers 201 | self.cuda_ids = cuda_ids 202 | 203 | self.table_num_labels = 6 # 4 204 | 205 | try: 206 | 
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True) 207 | except: 208 | self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True) 209 | 210 | def load_dataset(self): 211 | train_file_name = os.path.join(self.data_dir, 'train.json') 212 | dev_file_name = os.path.join(self.data_dir, 'dev.json') 213 | test_file_name = os.path.join(self.data_dir, 'test.json') 214 | 215 | if not os.path.exists(dev_file_name): 216 | dev_file_name = test_file_name 217 | 218 | train_examples = [Example(data, self.max_seq_length) for data in load_json(train_file_name)] 219 | dev_examples = [Example(data, self.max_seq_length) for data in load_json(dev_file_name)] 220 | test_examples = [Example(data, self.max_seq_length) for data in load_json(test_file_name)] 221 | 222 | self.raw_datasets = { 223 | 'train': train_examples, 224 | 'dev' : dev_examples, 225 | 'test' : test_examples 226 | } 227 | 228 | def get_dataloader(self, mode, batch_size, shuffle): 229 | dataloader = DataLoader( 230 | dataset=self.raw_datasets[mode], 231 | batch_size=batch_size, 232 | shuffle=shuffle, 233 | num_workers=self.num_workers, 234 | collate_fn=DataCollatorForASTE(tokenizer=self.tokenizer, 235 | max_seq_length=self.max_seq_length), 236 | pin_memory=True, 237 | prefetch_factor=16 238 | ) 239 | 240 | print(mode, len(dataloader)) 241 | return dataloader 242 | 243 | def train_dataloader(self): 244 | return self.get_dataloader('train', self.train_batch_size, shuffle=True) 245 | 246 | def val_dataloader(self): 247 | return self.get_dataloader("dev", self.eval_batch_size, shuffle=False) 248 | 249 | def test_dataloader(self): 250 | return self.get_dataloader("test", self.eval_batch_size, shuffle=False) 251 | -------------------------------------------------------------------------------- /code/aste_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import argparse 5 | import random 6 | import numpy as np 7 | 8 | 9 | import pytorch_lightning as pl 10 | 11 | pl.seed_everything(42) 12 | 13 | from transformers import AutoTokenizer, AutoConfig 14 | from transformers.optimization import AdamW 15 | from pytorch_lightning.utilities import rank_zero_info 16 | from transformers import ( 17 | get_linear_schedule_with_warmup, 18 | get_cosine_schedule_with_warmup, 19 | get_cosine_with_hard_restarts_schedule_with_warmup, 20 | get_polynomial_decay_schedule_with_warmup, 21 | get_constant_schedule_with_warmup 22 | ) 23 | 24 | arg_to_scheduler = { 25 | 'linear': get_linear_schedule_with_warmup, 26 | 'cosine': get_cosine_schedule_with_warmup, 27 | 'cosine_w_restarts': get_cosine_with_hard_restarts_schedule_with_warmup, 28 | 'polynomial': get_polynomial_decay_schedule_with_warmup, 29 | 'constant': get_constant_schedule_with_warmup, 30 | } 31 | 32 | 33 | 34 | from model.bdtf_model import BDTFModel 35 | from utils.aste_datamodule import ASTEDataModule 36 | from utils.aste_result import Result 37 | from utils import params_count 38 | 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | class ASTE(pl.LightningModule): 44 | def __init__(self, hparams, data_module): 45 | super().__init__() 46 | self.save_hyperparameters(hparams) 47 | self.data_module = data_module 48 | 49 | self.config = AutoConfig.from_pretrained(self.hparams.model_name_or_path) 50 | self.config.table_num_labels = self.data_module.table_num_labels 51 | self.config.table_encoder = self.hparams.table_encoder 52 | self.config.num_table_layers = 
self.hparams.num_table_layers 53 | self.config.span_pruning = self.hparams.span_pruning 54 | self.config.seq2mat = self.hparams.seq2mat 55 | self.config.num_d = self.hparams.num_d 56 | 57 | self.model = BDTFModel.from_pretrained(self.hparams.model_name_or_path, config=self.config) 58 | 59 | print(self.model.config) 60 | print('---------------------------------------') 61 | print('total params_count:', params_count(self.model)) 62 | print('---------------------------------------') 63 | 64 | @pl.utilities.rank_zero_only 65 | def save_model(self): 66 | dir_name = os.path.join(self.hparams.output_dir, str(self.hparams.cuda_ids), 'model') 67 | print(f'## save model to {dir_name}') 68 | self.model.save_pretrained(dir_name) 69 | 70 | def load_model(self): 71 | dir_name = os.path.join(self.hparams.output_dir, str(self.hparams.cuda_ids), 'model') 72 | print(f'## load model to {dir_name}') 73 | self.model = BDTFModel.from_pretrained(dir_name) 74 | 75 | def forward(self, **inputs): 76 | outputs = self.model(**inputs) 77 | return outputs 78 | 79 | def training_step(self, batch, batch_idx): 80 | outputs = self(**batch) 81 | loss = outputs['table_loss_S'] + outputs['table_loss_E'] + outputs['pair_loss'] 82 | 83 | self.log('train_loss', loss) 84 | return loss 85 | 86 | def validation_step(self, batch, batch_idx): 87 | outputs = self(**batch) 88 | loss = outputs['table_loss_S'] + outputs['table_loss_E'] + outputs['pair_loss'] 89 | 90 | self.log('valid_loss', loss) 91 | 92 | return { 93 | 'ids': outputs['ids'], 94 | 'table_predict_S': outputs['table_predict_S'], 95 | 'table_predict_E': outputs['table_predict_E'], 96 | 'table_labels_S': outputs['table_labels_S'], 97 | 'table_labels_E': outputs['table_labels_E'], 98 | 'pair_preds': outputs['pairs_preds'] 99 | } 100 | 101 | def validation_epoch_end(self, outputs): 102 | examples = self.data_module.raw_datasets['dev'] 103 | 104 | self.current_val_result = Result.parse_from(outputs, examples) 105 | self.current_val_result.cal_metric() 106 | 107 | if not hasattr(self, 'best_val_result'): 108 | self.best_val_result = self.current_val_result 109 | 110 | elif self.best_val_result < self.current_val_result: 111 | self.best_val_result = self.current_val_result 112 | self.save_model() 113 | 114 | def test_step(self, batch, batch_idx): 115 | return self.validation_step(batch,batch_idx) 116 | 117 | def test_epoch_end(self, outputs): 118 | examples = self.data_module.raw_datasets['test'] 119 | 120 | self.test_result = Result.parse_from(outputs, examples) 121 | self.test_result.cal_metric() 122 | 123 | def save_test_result(self): 124 | dir_name = os.path.join(self.hparams.output_dir, 'result') 125 | self.test_result.save(dir_name, self.hparams) 126 | 127 | def setup(self, stage): 128 | if stage == 'fit': 129 | self.train_loader = self.train_dataloader() 130 | ngpus = (len(self.hparams.gpus.split(',')) if type(self.hparams.gpus) is str else self.hparams.gpus) 131 | effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * ngpus 132 | dataset_size = len(self.train_loader.dataset) 133 | self.total_steps = (dataset_size / effective_batch_size) * self.hparams.max_epochs 134 | 135 | def get_lr_scheduler(self): 136 | get_scheduler_func = arg_to_scheduler[self.hparams.lr_scheduler] 137 | if self.hparams.lr_scheduler == 'constant': 138 | scheduler = get_scheduler_func(self.opt, num_warmup_steps=self.hparams.warmup_steps) 139 | else: 140 | scheduler = get_scheduler_func(self.opt, num_warmup_steps=self.hparams.warmup_steps, 141 | 
num_training_steps=self.total_steps)
142 |         scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}
143 |         return scheduler
144 | 
145 |     def configure_optimizers(self):
146 |         no_decay = ['bias', 'LayerNorm.weight']
147 | 
148 |         def has_keywords(n, keywords):
149 |             return any(nd in n for nd in keywords)
150 | 
151 |         optimizer_grouped_parameters = [
152 |             {
153 |                 'params': [p for n, p in self.model.named_parameters() if not has_keywords(n, no_decay)],
154 |                 'lr': self.hparams.learning_rate,
155 |                 'weight_decay': self.hparams.weight_decay
156 |             },
157 |             {
158 |                 'params': [p for n, p in self.model.named_parameters() if has_keywords(n, no_decay)],
159 |                 'lr': self.hparams.learning_rate,
160 |                 'weight_decay': 0  # biases and LayerNorm weights are excluded from weight decay
161 |             }
162 |         ]
163 | 
164 |         optimizer = AdamW(optimizer_grouped_parameters, eps=self.hparams.adam_epsilon)
165 |         self.opt = optimizer
166 |         scheduler = self.get_lr_scheduler()
167 | 
168 |         return [optimizer], [scheduler]
169 | 
170 |     @staticmethod
171 |     def add_model_specific_args(parser):
172 |         parser.add_argument("--learning_rate", default=1e-5, type=float)
173 |         parser.add_argument("--adam_epsilon", default=1e-8, type=float)
174 |         parser.add_argument("--warmup_steps", default=0, type=int)
175 |         parser.add_argument("--weight_decay", default=0., type=float)
176 |         parser.add_argument("--lr_scheduler", type=str)
177 | 
178 |         parser.add_argument("--seed", default=42, type=int)
179 |         parser.add_argument("--output_dir", type=str)
180 |         parser.add_argument("--do_train", action='store_true')
181 | 
182 |         parser.add_argument("--table_encoder", type=str, default='resnet', choices=['resnet','none'])
183 |         parser.add_argument("--num_table_layers", type=int, default=2)
184 |         parser.add_argument("--span_pruning", type=float, default=0.3)
185 |         parser.add_argument("--seq2mat", type=str, default='none', choices=['none','tensor','context','tensorcontext'])
186 |         parser.add_argument("--num_d", type=int, default=64)
187 |         return parser
188 | 
189 | 
190 | 
191 | 
192 | 
193 | 
194 | 
195 | 
196 | 
197 | class LoggingCallback(pl.Callback):
198 |     def on_validation_end(self, trainer, pl_module):
199 |         print('-------------------------------------------------------------------------------------------------------------------\n[current]\t', end='')
200 |         pl_module.current_val_result.report()
201 | 
202 |         print('[best]\t\t', end='')
203 |         pl_module.best_val_result.report()
204 |         print('-------------------------------------------------------------------------------------------------------------------\n')
205 | 
206 |     def on_test_end(self, trainer, pl_module):
207 |         pl_module.test_result.report()
208 |         pl_module.save_test_result()
209 | 
210 | 
211 | def main():
212 |     parser = argparse.ArgumentParser()
213 |     parser = pl.Trainer.add_argparse_args(parser)
214 |     parser = ASTE.add_model_specific_args(parser)
215 |     parser = ASTEDataModule.add_argparse_args(parser)
216 | 
217 |     args = parser.parse_args()
218 |     pl.seed_everything(args.seed)
219 | 
220 |     if args.learning_rate >= 1:  # allows "-l 5" style values, interpreted as 5e-5
221 |         args.learning_rate /= 1e5
222 | 
223 | 
224 |     data_module = ASTEDataModule.from_argparse_args(args)
225 |     data_module.load_dataset()
226 |     model = ASTE(args, data_module)
227 | 
228 |     logging_callback = LoggingCallback()
229 | 
230 |     kwargs = {
231 |         'weights_summary': None,
232 |         'callbacks': [logging_callback],
233 | 
'logger': True, 234 | 'checkpoint_callback': False, 235 | 'num_sanity_val_steps': 5 if args.do_train else 0, 236 | } 237 | 238 | trainer = pl.Trainer.from_argparse_args(args, **kwargs) 239 | 240 | trainer.fit(model, datamodule=data_module) 241 | model.load_model() 242 | trainer.test(model, datamodule=data_module) 243 | 244 | if __name__ == '__main__': 245 | main() --------------------------------------------------------------------------------